feat: concurrent tool execution with ThreadPoolExecutor
When the model returns multiple tool calls in a single response, they are now executed concurrently using a thread pool instead of sequentially. This significantly reduces wall-clock time when multiple independent tools are batched (e.g. parallel web_search, read_file, terminal calls). Architecture: - _execute_tool_calls() dispatches to sequential or concurrent path - Single tool calls and batches containing 'clarify' use sequential path - Multiple non-interactive tools use ThreadPoolExecutor (max 8 workers) - Results are collected and appended to messages in original order - _invoke_tool() extracted as shared tool invocation helper Safety: - Pre-flight interrupt check skips all tools if interrupted - Per-tool exception handling: one failure doesn't crash the batch - Result truncation (100k char limit) applied per tool - Budget pressure injection after all tools complete - Checkpoints taken before file-mutating tools - CLI spinner shows batch progress, then per-tool completion messages Tests: 10 new tests covering dispatch logic, ordering, error handling, interrupt behavior, truncation, and _invoke_tool routing.
This commit is contained in:
parent
f562d97f13
commit
5d0d5b191c
3 changed files with 429 additions and 2 deletions
264
run_agent.py
264
run_agent.py
|
|
@ -21,6 +21,7 @@ Usage:
|
|||
"""
|
||||
|
||||
import atexit
|
||||
import concurrent.futures
|
||||
import copy
|
||||
import hashlib
|
||||
import json
|
||||
|
|
@ -193,6 +194,14 @@ class IterationBudget:
|
|||
return max(0, self.max_total - self._used)
|
||||
|
||||
|
||||
# Tools that must never run concurrently (interactive / user-facing).
# When any of these appear in a batch, we fall back to sequential execution.
_NEVER_PARALLEL_TOOLS = frozenset({"clarify"})

# Maximum number of concurrent worker threads for parallel tool execution.
# The effective pool size is min(len(batch), _MAX_TOOL_WORKERS).
_MAX_TOOL_WORKERS = 8
|
||||
|
||||
|
||||
class AIAgent:
|
||||
"""
|
||||
AI Agent with tool calling capabilities.
|
||||
|
|
@ -3119,7 +3128,260 @@ class AIAgent:
|
|||
return compressed, new_system_prompt
|
||||
|
||||
def _execute_tool_calls(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
    """Execute the assistant's tool calls and append their results to messages.

    Routing policy: a batch of two or more tool calls is fanned out on the
    concurrent path, unless the batch contains an interactive tool
    (e.g. clarify) — in that case, as with a single call, execution stays
    on the sequential path.
    """
    calls = assistant_message.tool_calls
    batch_has_interactive = any(
        c.function.name in _NEVER_PARALLEL_TOOLS for c in calls
    )

    if len(calls) > 1 and not batch_has_interactive:
        # Multiple independent tools → run them across worker threads.
        return self._execute_tool_calls_concurrent(
            assistant_message, messages, effective_task_id, api_call_count
        )

    # Single call, or an interactive tool present → original sequential path.
    return self._execute_tool_calls_sequential(
        assistant_message, messages, effective_task_id, api_call_count
    )
|
||||
|
||||
def _invoke_tool(self, function_name: str, function_args: dict, effective_task_id: str) -> str:
    """Run a single tool and return its raw result string. No display logic.

    Covers the agent-level tools (todo, session_search, memory, clarify,
    delegate_task) and falls through to the registry dispatcher for
    everything else. The concurrent execution path calls this from worker
    threads; the sequential path keeps its own inline invocation so its
    display handling stays backward compatible.
    """
    arg = function_args.get  # every branch reads arguments the same way

    if function_name == "todo":
        from tools.todo_tool import todo_tool as run_todo
        return run_todo(
            todos=arg("todos"),
            merge=arg("merge", False),
            store=self._todo_store,
        )

    if function_name == "session_search":
        if not self._session_db:
            return json.dumps({"success": False, "error": "Session database not available."})
        from tools.session_search_tool import session_search as run_search
        return run_search(
            query=arg("query", ""),
            role_filter=arg("role_filter"),
            limit=arg("limit", 3),
            db=self._session_db,
            current_session_id=self.session_id,
        )

    if function_name == "memory":
        from tools.memory_tool import memory_tool as run_memory
        memory_target = arg("target", "memory")
        outcome = run_memory(
            action=arg("action"),
            target=memory_target,
            content=arg("content"),
            old_text=arg("old_text"),
            store=self._memory_store,
        )
        # Mirror user-scoped additions to Honcho when that integration is active.
        if self._honcho and memory_target == "user" and arg("action") == "add":
            self._honcho_save_user_observation(arg("content", ""))
        return outcome

    if function_name == "clarify":
        from tools.clarify_tool import clarify_tool as run_clarify
        return run_clarify(
            question=arg("question", ""),
            choices=arg("choices"),
            callback=self.clarify_callback,
        )

    if function_name == "delegate_task":
        from tools.delegate_tool import delegate_task as run_delegate
        return run_delegate(
            goal=arg("goal"),
            context=arg("context"),
            toolsets=arg("toolsets"),
            tasks=arg("tasks"),
            max_iterations=arg("max_iterations"),
            parent_agent=self,
        )

    # Everything else is dispatched through the shared tool registry.
    return handle_function_call(
        function_name, function_args, effective_task_id,
        enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None,
    )
|
||||
|
||||
def _execute_tool_calls_concurrent(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
    """Execute multiple tool calls concurrently using a thread pool.

    Results are collected in the original tool-call order and appended to
    messages so the API sees them in the expected sequence.

    Args:
        assistant_message: Model response object carrying ``.tool_calls``.
        messages: Conversation list; exactly one ``role="tool"`` entry is
            appended per tool call, in the original call order.
        effective_task_id: Task identifier forwarded to tool invocations.
        api_call_count: Iterations consumed so far; drives budget warnings.
    """
    tool_calls = assistant_message.tool_calls
    num_tools = len(tool_calls)

    # ── Pre-flight: interrupt check ──────────────────────────────────
    # Even when skipping, each call still gets a "tool" message so the
    # API-visible transcript stays well-formed.
    if self._interrupt_requested:
        print(f"{self.log_prefix}⚡ Interrupt: skipping {num_tools} tool call(s)")
        for tc in tool_calls:
            messages.append({
                "role": "tool",
                "content": f"[Tool execution cancelled — {tc.function.name} was skipped due to user interrupt]",
                "tool_call_id": tc.id,
            })
        return

    # ── Parse args + pre-execution bookkeeping ───────────────────────
    parsed_calls = []  # list of (tool_call, function_name, function_args)
    for tool_call in tool_calls:
        function_name = tool_call.function.name

        # Reset nudge counters
        if function_name == "memory":
            self._turns_since_memory = 0
        elif function_name == "skill_manage":
            self._iters_since_skill = 0

        # Malformed or non-dict arguments degrade to an empty dict rather
        # than failing the whole batch.
        try:
            function_args = json.loads(tool_call.function.arguments)
        except json.JSONDecodeError:
            function_args = {}
        if not isinstance(function_args, dict):
            function_args = {}

        # Checkpoint for file-mutating tools (best effort — never blocks execution)
        if function_name in ("write_file", "patch") and self._checkpoint_mgr.enabled:
            try:
                file_path = function_args.get("path", "")
                if file_path:
                    work_dir = self._checkpoint_mgr.get_working_dir_for_path(file_path)
                    self._checkpoint_mgr.ensure_checkpoint(work_dir, f"before {function_name}")
            except Exception:
                pass

        parsed_calls.append((tool_call, function_name, function_args))

    # ── Logging / callbacks ──────────────────────────────────────────
    tool_names_str = ", ".join(name for _, name, _ in parsed_calls)
    if not self.quiet_mode:
        print(f" ⚡ Concurrent: {num_tools} tool calls — {tool_names_str}")
        for i, (tc, name, args) in enumerate(parsed_calls, 1):
            args_str = json.dumps(args, ensure_ascii=False)
            args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
            print(f" 📞 Tool {i}: {name}({list(args.keys())}) - {args_preview}")

    for _, name, args in parsed_calls:
        if self.tool_progress_callback:
            try:
                preview = _build_tool_preview(name, args)
                self.tool_progress_callback(name, preview, args)
            except Exception as cb_err:
                logging.debug(f"Tool progress callback error: {cb_err}")

    # ── Concurrent execution ─────────────────────────────────────────
    # Each slot holds (function_name, function_args, function_result, duration, error_flag)
    results = [None] * num_tools

    def _run_tool(index, tool_call, function_name, function_args):
        """Worker function executed in a thread."""
        start = time.time()
        try:
            result = self._invoke_tool(function_name, function_args, effective_task_id)
        except Exception as tool_error:
            # Per-tool containment: one failure must not crash the batch.
            result = f"Error executing tool '{function_name}': {tool_error}"
            logger.error("_invoke_tool raised for %s: %s", function_name, tool_error, exc_info=True)
        duration = time.time() - start
        is_error, _ = _detect_tool_failure(function_name, result)
        # Each worker writes only its own slot; single-item list assignment
        # is atomic under the GIL, so no lock is needed.
        results[index] = (function_name, function_args, result, duration, is_error)

    # Start spinner for CLI mode
    spinner = None
    if self.quiet_mode:
        face = random.choice(KawaiiSpinner.KAWAII_WAITING)
        spinner = KawaiiSpinner(f"{face} ⚡ running {num_tools} tools concurrently", spinner_type='dots')
        spinner.start()

    try:
        max_workers = min(num_tools, _MAX_TOOL_WORKERS)
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(_run_tool, i, tc, name, args)
                for i, (tc, name, args) in enumerate(parsed_calls)
            ]
            # Wait for all to complete (exceptions are captured inside _run_tool)
            concurrent.futures.wait(futures)
    finally:
        if spinner:
            # Build a summary message for the spinner stop
            completed = sum(1 for r in results if r is not None)
            total_dur = sum(r[3] for r in results if r is not None)
            spinner.stop(f"⚡ {completed}/{num_tools} tools completed in {total_dur:.1f}s total")

    # ── Post-execution: display per-tool results ─────────────────────
    # Per-tool cap on content sent back to the API (hoisted out of the loop).
    MAX_TOOL_RESULT_CHARS = 100_000
    for i, (tc, name, args) in enumerate(parsed_calls):
        r = results[i]
        if r is None:
            # Shouldn't happen, but safety fallback.
            # BUGFIX: this branch previously left function_name,
            # function_args, and is_error unbound — a NameError on the
            # first iteration (or stale values from a prior one) when the
            # code below used them. Bind all of them explicitly.
            function_name, function_args = name, args
            function_result = f"Error executing tool '{name}': thread did not return a result"
            tool_duration = 0.0
            is_error = True
        else:
            function_name, function_args, function_result, tool_duration, is_error = r

        if is_error:
            result_preview = function_result[:200] if len(function_result) > 200 else function_result
            logger.warning("Tool %s returned error (%.2fs): %s", function_name, tool_duration, result_preview)

        if self.verbose_logging:
            result_preview = function_result[:200] if len(function_result) > 200 else function_result
            logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s")
            logging.debug(f"Tool result preview: {result_preview}...")

        # Print cute message per tool
        if self.quiet_mode:
            cute_msg = _get_cute_tool_message_impl(name, args, tool_duration, result=function_result)
            print(f" {cute_msg}")
        else:
            # FIX: was `elif not self.quiet_mode:` — a redundant re-test of
            # the condition already known to be false.
            response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
            print(f" ✅ Tool {i+1} completed in {tool_duration:.2f}s - {response_preview}")

        # Truncate oversized results so a single tool cannot blow the context.
        if len(function_result) > MAX_TOOL_RESULT_CHARS:
            original_len = len(function_result)
            function_result = (
                function_result[:MAX_TOOL_RESULT_CHARS]
                + f"\n\n[Truncated: tool response was {original_len:,} chars, "
                f"exceeding the {MAX_TOOL_RESULT_CHARS:,} char limit]"
            )

        # Append tool result message in order
        messages.append({
            "role": "tool",
            "content": function_result,
            "tool_call_id": tc.id,
        })

    # ── Budget pressure injection ────────────────────────────────────
    # The warning is folded into the last tool message: merged as a key
    # into JSON results, appended as plain text otherwise.
    budget_warning = self._get_budget_warning(api_call_count)
    if budget_warning and messages and messages[-1].get("role") == "tool":
        last_content = messages[-1]["content"]
        try:
            parsed = json.loads(last_content)
            if isinstance(parsed, dict):
                parsed["_budget_warning"] = budget_warning
                messages[-1]["content"] = json.dumps(parsed, ensure_ascii=False)
            else:
                messages[-1]["content"] = last_content + f"\n\n{budget_warning}"
        except (json.JSONDecodeError, TypeError):
            messages[-1]["content"] = last_content + f"\n\n{budget_warning}"
        if not self.quiet_mode:
            remaining = self.max_iterations - api_call_count
            tier = "⚠️ WARNING" if remaining <= self.max_iterations * 0.1 else "💡 CAUTION"
            print(f"{self.log_prefix}{tier}: {remaining} iterations remaining")
|
||||
def _execute_tool_calls_sequential(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
|
||||
"""Execute tool calls sequentially (original behavior). Used for single calls or interactive tools."""
|
||||
for i, tool_call in enumerate(assistant_message.tool_calls, 1):
|
||||
# SAFETY: check interrupt BEFORE starting each tool.
|
||||
# If the user sent "stop" during a previous tool's execution,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue