fix: make concurrent tool batching path-aware for file mutations (#1914)
* Improve tool batching independence checks * fix: address review feedback on path-aware batching - Log malformed/non-dict tool arguments at debug level before falling back to sequential, instead of silently swallowing the error into an empty dict - Guard empty paths in _paths_overlap (unreachable in practice due to upstream filtering, but makes the invariant explicit) - Add tests: malformed JSON args, non-dict args, _paths_overlap unit tests including empty path edge cases - web_crawl is not a registered tool (only web_search/web_extract are); no addition needed to _PARALLEL_SAFE_TOOLS --------- Co-authored-by: kshitij <82637225+kshitijk4poor@users.noreply.github.com>
This commit is contained in:
parent
050b43108c
commit
c0c14e60b4
2 changed files with 215 additions and 8 deletions
|
|
@ -806,7 +806,7 @@ class TestConcurrentToolExecution:
|
|||
mock_con.assert_not_called()
|
||||
|
||||
def test_multiple_tools_uses_concurrent_path(self, agent):
|
||||
"""Multiple non-interactive tools should use concurrent path."""
|
||||
"""Multiple read-only tools should use concurrent path."""
|
||||
tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
|
||||
tc2 = _mock_tool_call(name="read_file", arguments='{"path":"x.py"}', call_id="c2")
|
||||
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
|
||||
|
|
@ -817,6 +817,94 @@ class TestConcurrentToolExecution:
|
|||
mock_con.assert_called_once()
|
||||
mock_seq.assert_not_called()
|
||||
|
||||
def test_terminal_batch_forces_sequential(self, agent):
|
||||
"""Stateful tools should not share the concurrent execution path."""
|
||||
tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
|
||||
tc2 = _mock_tool_call(name="terminal", arguments='{"command":"pwd"}', call_id="c2")
|
||||
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
|
||||
messages = []
|
||||
with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
|
||||
with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
|
||||
agent._execute_tool_calls(mock_msg, messages, "task-1")
|
||||
mock_seq.assert_called_once()
|
||||
mock_con.assert_not_called()
|
||||
|
||||
def test_write_batch_forces_sequential(self, agent):
|
||||
"""File mutations should stay ordered within a turn."""
|
||||
tc1 = _mock_tool_call(name="read_file", arguments='{"path":"x.py"}', call_id="c1")
|
||||
tc2 = _mock_tool_call(name="write_file", arguments='{"path":"x.py","content":"print(1)"}', call_id="c2")
|
||||
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
|
||||
messages = []
|
||||
with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
|
||||
with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
|
||||
agent._execute_tool_calls(mock_msg, messages, "task-1")
|
||||
mock_seq.assert_called_once()
|
||||
mock_con.assert_not_called()
|
||||
|
||||
def test_disjoint_write_batch_uses_concurrent_path(self, agent):
|
||||
"""Independent file writes should still run concurrently."""
|
||||
tc1 = _mock_tool_call(
|
||||
name="write_file",
|
||||
arguments='{"path":"src/a.py","content":"print(1)"}',
|
||||
call_id="c1",
|
||||
)
|
||||
tc2 = _mock_tool_call(
|
||||
name="write_file",
|
||||
arguments='{"path":"src/b.py","content":"print(2)"}',
|
||||
call_id="c2",
|
||||
)
|
||||
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
|
||||
messages = []
|
||||
with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
|
||||
with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
|
||||
agent._execute_tool_calls(mock_msg, messages, "task-1")
|
||||
mock_con.assert_called_once()
|
||||
mock_seq.assert_not_called()
|
||||
|
||||
def test_overlapping_write_batch_forces_sequential(self, agent):
|
||||
"""Writes to the same file must stay ordered."""
|
||||
tc1 = _mock_tool_call(
|
||||
name="write_file",
|
||||
arguments='{"path":"src/a.py","content":"print(1)"}',
|
||||
call_id="c1",
|
||||
)
|
||||
tc2 = _mock_tool_call(
|
||||
name="patch",
|
||||
arguments='{"path":"src/a.py","old_string":"1","new_string":"2"}',
|
||||
call_id="c2",
|
||||
)
|
||||
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
|
||||
messages = []
|
||||
with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
|
||||
with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
|
||||
agent._execute_tool_calls(mock_msg, messages, "task-1")
|
||||
mock_seq.assert_called_once()
|
||||
mock_con.assert_not_called()
|
||||
|
||||
def test_malformed_json_args_forces_sequential(self, agent):
|
||||
"""Unparseable tool arguments should fall back to sequential."""
|
||||
tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
|
||||
tc2 = _mock_tool_call(name="web_search", arguments="NOT JSON {{{", call_id="c2")
|
||||
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
|
||||
messages = []
|
||||
with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
|
||||
with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
|
||||
agent._execute_tool_calls(mock_msg, messages, "task-1")
|
||||
mock_seq.assert_called_once()
|
||||
mock_con.assert_not_called()
|
||||
|
||||
def test_non_dict_args_forces_sequential(self, agent):
|
||||
"""Tool arguments that parse to a non-dict type should fall back to sequential."""
|
||||
tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
|
||||
tc2 = _mock_tool_call(name="web_search", arguments='"just a string"', call_id="c2")
|
||||
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
|
||||
messages = []
|
||||
with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
|
||||
with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
|
||||
agent._execute_tool_calls(mock_msg, messages, "task-1")
|
||||
mock_seq.assert_called_once()
|
||||
mock_con.assert_not_called()
|
||||
|
||||
def test_concurrent_executes_all_tools(self, agent):
|
||||
"""Concurrent path should execute all tools and append results in order."""
|
||||
tc1 = _mock_tool_call(name="web_search", arguments='{"q":"alpha"}', call_id="c1")
|
||||
|
|
@ -943,6 +1031,39 @@ class TestConcurrentToolExecution:
|
|||
assert "ok" in result
|
||||
|
||||
|
||||
class TestPathsOverlap:
|
||||
"""Unit tests for the _paths_overlap helper."""
|
||||
|
||||
def test_same_path_overlaps(self):
|
||||
from run_agent import _paths_overlap
|
||||
assert _paths_overlap(Path("src/a.py"), Path("src/a.py"))
|
||||
|
||||
def test_siblings_do_not_overlap(self):
|
||||
from run_agent import _paths_overlap
|
||||
assert not _paths_overlap(Path("src/a.py"), Path("src/b.py"))
|
||||
|
||||
def test_parent_child_overlap(self):
|
||||
from run_agent import _paths_overlap
|
||||
assert _paths_overlap(Path("src"), Path("src/sub/a.py"))
|
||||
|
||||
def test_different_roots_do_not_overlap(self):
|
||||
from run_agent import _paths_overlap
|
||||
assert not _paths_overlap(Path("src/a.py"), Path("other/a.py"))
|
||||
|
||||
def test_nested_vs_flat_do_not_overlap(self):
|
||||
from run_agent import _paths_overlap
|
||||
assert not _paths_overlap(Path("src/sub/a.py"), Path("src/a.py"))
|
||||
|
||||
def test_empty_paths_do_not_overlap(self):
|
||||
from run_agent import _paths_overlap
|
||||
assert not _paths_overlap(Path(""), Path(""))
|
||||
|
||||
def test_one_empty_path_does_not_overlap(self):
|
||||
from run_agent import _paths_overlap
|
||||
assert not _paths_overlap(Path(""), Path("src/a.py"))
|
||||
assert not _paths_overlap(Path("src/a.py"), Path(""))
|
||||
|
||||
|
||||
class TestHandleMaxIterations:
|
||||
def test_returns_summary(self, agent):
|
||||
resp = _mock_response(content="Here is a summary of what I did.")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue