fix: make concurrent tool batching path-aware for file mutations (#1914)

* Improve tool batching independence checks

* fix: address review feedback on path-aware batching

- Log malformed/non-dict tool arguments at debug level before
  falling back to sequential, instead of silently swallowing
  the error into an empty dict
- Guard empty paths in _paths_overlap (unreachable in practice
  due to upstream filtering, but makes the invariant explicit)
- Add tests: malformed JSON args, non-dict args, _paths_overlap
  unit tests including empty path edge cases
- web_crawl is not a registered tool (only web_search/web_extract
  are); no addition needed to _PARALLEL_SAFE_TOOLS

---------

Co-authored-by: kshitij <82637225+kshitijk4poor@users.noreply.github.com>
This commit is contained in:
Teknium 2026-03-18 03:25:38 -07:00 committed by GitHub
parent 050b43108c
commit c0c14e60b4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 215 additions and 8 deletions

View file

@ -806,7 +806,7 @@ class TestConcurrentToolExecution:
mock_con.assert_not_called()
def test_multiple_tools_uses_concurrent_path(self, agent):
"""Multiple non-interactive tools should use concurrent path."""
"""Multiple read-only tools should use concurrent path."""
tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
tc2 = _mock_tool_call(name="read_file", arguments='{"path":"x.py"}', call_id="c2")
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
@ -817,6 +817,94 @@ class TestConcurrentToolExecution:
mock_con.assert_called_once()
mock_seq.assert_not_called()
def test_terminal_batch_forces_sequential(self, agent):
"""Stateful tools should not share the concurrent execution path."""
tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
tc2 = _mock_tool_call(name="terminal", arguments='{"command":"pwd"}', call_id="c2")
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
messages = []
with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
agent._execute_tool_calls(mock_msg, messages, "task-1")
mock_seq.assert_called_once()
mock_con.assert_not_called()
def test_write_batch_forces_sequential(self, agent):
"""File mutations should stay ordered within a turn."""
tc1 = _mock_tool_call(name="read_file", arguments='{"path":"x.py"}', call_id="c1")
tc2 = _mock_tool_call(name="write_file", arguments='{"path":"x.py","content":"print(1)"}', call_id="c2")
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
messages = []
with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
agent._execute_tool_calls(mock_msg, messages, "task-1")
mock_seq.assert_called_once()
mock_con.assert_not_called()
def test_disjoint_write_batch_uses_concurrent_path(self, agent):
"""Independent file writes should still run concurrently."""
tc1 = _mock_tool_call(
name="write_file",
arguments='{"path":"src/a.py","content":"print(1)"}',
call_id="c1",
)
tc2 = _mock_tool_call(
name="write_file",
arguments='{"path":"src/b.py","content":"print(2)"}',
call_id="c2",
)
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
messages = []
with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
agent._execute_tool_calls(mock_msg, messages, "task-1")
mock_con.assert_called_once()
mock_seq.assert_not_called()
def test_overlapping_write_batch_forces_sequential(self, agent):
"""Writes to the same file must stay ordered."""
tc1 = _mock_tool_call(
name="write_file",
arguments='{"path":"src/a.py","content":"print(1)"}',
call_id="c1",
)
tc2 = _mock_tool_call(
name="patch",
arguments='{"path":"src/a.py","old_string":"1","new_string":"2"}',
call_id="c2",
)
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
messages = []
with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
agent._execute_tool_calls(mock_msg, messages, "task-1")
mock_seq.assert_called_once()
mock_con.assert_not_called()
def test_malformed_json_args_forces_sequential(self, agent):
"""Unparseable tool arguments should fall back to sequential."""
tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
tc2 = _mock_tool_call(name="web_search", arguments="NOT JSON {{{", call_id="c2")
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
messages = []
with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
agent._execute_tool_calls(mock_msg, messages, "task-1")
mock_seq.assert_called_once()
mock_con.assert_not_called()
def test_non_dict_args_forces_sequential(self, agent):
"""Tool arguments that parse to a non-dict type should fall back to sequential."""
tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
tc2 = _mock_tool_call(name="web_search", arguments='"just a string"', call_id="c2")
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
messages = []
with patch.object(agent, "_execute_tool_calls_sequential") as mock_seq:
with patch.object(agent, "_execute_tool_calls_concurrent") as mock_con:
agent._execute_tool_calls(mock_msg, messages, "task-1")
mock_seq.assert_called_once()
mock_con.assert_not_called()
def test_concurrent_executes_all_tools(self, agent):
"""Concurrent path should execute all tools and append results in order."""
tc1 = _mock_tool_call(name="web_search", arguments='{"q":"alpha"}', call_id="c1")
@ -943,6 +1031,39 @@ class TestConcurrentToolExecution:
assert "ok" in result
class TestPathsOverlap:
"""Unit tests for the _paths_overlap helper."""
def test_same_path_overlaps(self):
from run_agent import _paths_overlap
assert _paths_overlap(Path("src/a.py"), Path("src/a.py"))
def test_siblings_do_not_overlap(self):
from run_agent import _paths_overlap
assert not _paths_overlap(Path("src/a.py"), Path("src/b.py"))
def test_parent_child_overlap(self):
from run_agent import _paths_overlap
assert _paths_overlap(Path("src"), Path("src/sub/a.py"))
def test_different_roots_do_not_overlap(self):
from run_agent import _paths_overlap
assert not _paths_overlap(Path("src/a.py"), Path("other/a.py"))
def test_nested_vs_flat_do_not_overlap(self):
from run_agent import _paths_overlap
assert not _paths_overlap(Path("src/sub/a.py"), Path("src/a.py"))
def test_empty_paths_do_not_overlap(self):
from run_agent import _paths_overlap
assert not _paths_overlap(Path(""), Path(""))
def test_one_empty_path_does_not_overlap(self):
from run_agent import _paths_overlap
assert not _paths_overlap(Path(""), Path("src/a.py"))
assert not _paths_overlap(Path("src/a.py"), Path(""))
class TestHandleMaxIterations:
def test_returns_summary(self, agent):
resp = _mock_response(content="Here is a summary of what I did.")