Merge pull request #1152 from NousResearch/hermes/hermes-f47f71c0
feat: concurrent tool execution with ThreadPoolExecutor
This commit is contained in:
commit
0157253145
3 changed files with 429 additions and 2 deletions
|
|
@@ -702,6 +702,168 @@ class TestExecuteToolCalls:
|
|||
assert "Truncated" in messages[0]["content"]
|
||||
|
||||
|
||||
class TestConcurrentToolExecution:
    """Tests for _execute_tool_calls_concurrent and dispatch logic.

    The first three tests exercise the dispatcher (``_execute_tool_calls``)
    with both execution paths mocked out, so they only check *which* path is
    chosen. The remaining tests call ``_execute_tool_calls_concurrent`` or
    ``_invoke_tool`` directly with ``run_agent.handle_function_call`` patched.
    """

    def test_single_tool_uses_sequential_path(self, agent):
        """Single tool call should use sequential path, not concurrent."""
        tc = _mock_tool_call(name="web_search", arguments='{"q":"test"}', call_id="c1")
        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc])
        messages = []
        with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
            with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
                agent._execute_tool_calls(mock_msg, messages, "task-1")
                mock_seq.assert_called_once()
                mock_con.assert_not_called()

    def test_clarify_forces_sequential(self, agent):
        """Batch containing clarify should use sequential path."""
        tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
        tc2 = _mock_tool_call(name="clarify", arguments='{"question":"ok?"}', call_id="c2")
        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
        messages = []
        with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
            with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
                agent._execute_tool_calls(mock_msg, messages, "task-1")
                mock_seq.assert_called_once()
                mock_con.assert_not_called()

    def test_multiple_tools_uses_concurrent_path(self, agent):
        """Multiple non-interactive tools should use concurrent path."""
        tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
        tc2 = _mock_tool_call(name="read_file", arguments='{"path":"x.py"}', call_id="c2")
        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
        messages = []
        with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
            with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
                agent._execute_tool_calls(mock_msg, messages, "task-1")
                mock_con.assert_called_once()
                mock_seq.assert_not_called()

    def test_concurrent_executes_all_tools(self, agent):
        """Concurrent path should execute all tools and append results in order."""
        tc1 = _mock_tool_call(name="web_search", arguments='{"q":"alpha"}', call_id="c1")
        tc2 = _mock_tool_call(name="web_search", arguments='{"q":"beta"}', call_id="c2")
        tc3 = _mock_tool_call(name="web_search", arguments='{"q":"gamma"}', call_id="c3")
        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2, tc3])
        messages = []

        call_log = []

        def fake_handle(name, args, task_id, **kwargs):
            call_log.append(name)
            return json.dumps({"result": args.get("q", "")})

        with patch("run_agent.handle_function_call", side_effect=fake_handle):
            agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")

        # The handler must have run exactly once per tool call.
        assert call_log == ["web_search"] * 3
        assert len(messages) == 3
        # Results must be in original order
        assert messages[0]["tool_call_id"] == "c1"
        assert messages[1]["tool_call_id"] == "c2"
        assert messages[2]["tool_call_id"] == "c3"
        # All should be tool messages
        assert all(m["role"] == "tool" for m in messages)
        # Content should contain the query results
        assert "alpha" in messages[0]["content"]
        assert "beta" in messages[1]["content"]
        assert "gamma" in messages[2]["content"]

    def test_concurrent_preserves_order_despite_timing(self, agent):
        """Even if tools finish in different order, messages should be in original order."""
        import time as _time

        tc1 = _mock_tool_call(name="web_search", arguments='{"q":"slow"}', call_id="c1")
        tc2 = _mock_tool_call(name="web_search", arguments='{"q":"fast"}', call_id="c2")
        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
        messages = []

        def fake_handle(name, args, task_id, **kwargs):
            q = args.get("q", "")
            if q == "slow":
                _time.sleep(0.1)  # Slow tool
            return f"result_{q}"

        with patch("run_agent.handle_function_call", side_effect=fake_handle):
            agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")

        assert messages[0]["tool_call_id"] == "c1"
        assert "result_slow" in messages[0]["content"]
        assert messages[1]["tool_call_id"] == "c2"
        assert "result_fast" in messages[1]["content"]

    def test_concurrent_handles_tool_error(self, agent):
        """If one tool raises, others should still complete."""
        tc1 = _mock_tool_call(name="web_search", arguments='{"q":"fail"}', call_id="c1")
        tc2 = _mock_tool_call(name="web_search", arguments='{"q":"ok"}', call_id="c2")
        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
        messages = []

        def fake_handle(name, args, task_id, **kwargs):
            # Decide failure from the call's own arguments rather than from a
            # shared call counter: the handler runs on worker threads, so the
            # order in which the two calls reach it is not guaranteed.
            if args.get("q") == "fail":
                raise RuntimeError("boom")
            return "success"

        with patch("run_agent.handle_function_call", side_effect=fake_handle):
            agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")

        assert len(messages) == 2
        # First tool should have error
        assert "Error" in messages[0]["content"] or "boom" in messages[0]["content"]
        # Second tool should succeed
        assert "success" in messages[1]["content"]

    def test_concurrent_interrupt_before_start(self, agent):
        """If interrupt is requested before concurrent execution, all tools are skipped."""
        tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
        tc2 = _mock_tool_call(name="read_file", arguments='{}', call_id="c2")
        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
        messages = []

        # _set_interrupt is patched only while the interrupt is requested.
        with patch("run_agent._set_interrupt"):
            agent.interrupt()

        agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
        assert len(messages) == 2
        assert "cancelled" in messages[0]["content"].lower() or "skipped" in messages[0]["content"].lower()
        assert "cancelled" in messages[1]["content"].lower() or "skipped" in messages[1]["content"].lower()

    def test_concurrent_truncates_large_results(self, agent):
        """Concurrent path should truncate results over 100k chars."""
        tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
        tc2 = _mock_tool_call(name="web_search", arguments='{}', call_id="c2")
        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
        messages = []
        big_result = "x" * 150_000

        with patch("run_agent.handle_function_call", return_value=big_result):
            agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")

        assert len(messages) == 2
        for m in messages:
            assert len(m["content"]) < 150_000
            assert "Truncated" in m["content"]

    def test_invoke_tool_dispatches_to_handle_function_call(self, agent):
        """_invoke_tool should route regular tools through handle_function_call."""
        with patch("run_agent.handle_function_call", return_value="result") as mock_hfc:
            result = agent._invoke_tool("web_search", {"q": "test"}, "task-1")
            mock_hfc.assert_called_once_with(
                "web_search", {"q": "test"}, "task-1",
                enabled_tools=list(agent.valid_tool_names),
            )
            assert result == "result"

    def test_invoke_tool_handles_agent_level_tools(self, agent):
        """_invoke_tool should handle todo tool directly."""
        with patch("tools.todo_tool.todo_tool", return_value='{"ok":true}') as mock_todo:
            result = agent._invoke_tool("todo", {"todos": []}, "task-1")
            mock_todo.assert_called_once()
            assert "ok" in result
|
||||
|
||||
|
||||
class TestHandleMaxIterations:
|
||||
def test_returns_summary(self, agent):
|
||||
resp = _mock_response(content="Here is a summary of what I did.")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue