fix: thread safety for concurrent subagent delegation (#1672)

* fix: thread safety for concurrent subagent delegation

Four thread-safety fixes that prevent crashes and data races when
running multiple subagents concurrently via delegate_task:

1. Remove redirect_stdout/stderr from delegate_tool — mutating global
   sys.stdout races with the spinner thread when multiple children start
   concurrently, causing segfaults. Children already run with
   quiet_mode=True so the redirect was redundant.

2. Split _run_single_child into _build_child_agent (main thread) +
   _run_single_child (worker thread). AIAgent construction creates
   httpx/SSL clients which are not thread-safe to initialize
   concurrently.

3. Add threading.Lock to SessionDB — subagents share the parent's
   SessionDB and call create_session/append_message from worker threads
   with no synchronization.

4. Add _active_children_lock to AIAgent — interrupt() iterates
   _active_children while worker threads append/remove children.

5. Add _client_cache_lock to auxiliary_client — multiple subagent
   threads may resolve clients concurrently via call_llm().

Based on PR #1471 by peteromallet.

* feat: Honcho base_url override via config.yaml + quick command alias type

Two features salvaged from PR #1576:

1. Honcho base_url override: allows pointing Hermes at a remote
   self-hosted Honcho deployment via config.yaml:

     honcho:
       base_url: "http://192.168.x.x:8000"

   When set, this overrides the Honcho SDK's environment mapping
   (production/local), enabling LAN/VPN Honcho deployments without
   requiring the server to live on localhost. Uses config.yaml instead
   of env var (HONCHO_URL) per project convention.

2. Quick command alias type: adds a new 'alias' quick command type
   that rewrites to another slash command before normal dispatch:

     quick_commands:
       sc:
         type: alias
         target: /context

   Supports both CLI and gateway. Arguments are forwarded to the
   target command.

Based on PR #1576 by redhelix.

---------

Co-authored-by: peteromallet <peteromallet@users.noreply.github.com>
Co-authored-by: redhelix <redhelix@users.noreply.github.com>
This commit is contained in:
Teknium 2026-03-17 02:53:33 -07:00 committed by GitHub
parent fd61ae13e5
commit 1d5a39e002
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 397 additions and 272 deletions

View file

@ -39,6 +39,7 @@ custom OpenAI-compatible endpoint without touching the main model settings.
import json import json
import logging import logging
import os import os
import threading
from pathlib import Path from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
@ -1171,6 +1172,7 @@ def auxiliary_max_tokens_param(value: int) -> dict:
# Client cache: (provider, async_mode, base_url, api_key) -> (client, default_model) # Client cache: (provider, async_mode, base_url, api_key) -> (client, default_model)
_client_cache: Dict[tuple, tuple] = {} _client_cache: Dict[tuple, tuple] = {}
_client_cache_lock = threading.Lock()
def _get_cached_client( def _get_cached_client(
@ -1182,9 +1184,11 @@ def _get_cached_client(
) -> Tuple[Optional[Any], Optional[str]]: ) -> Tuple[Optional[Any], Optional[str]]:
"""Get or create a cached client for the given provider.""" """Get or create a cached client for the given provider."""
cache_key = (provider, async_mode, base_url or "", api_key or "") cache_key = (provider, async_mode, base_url or "", api_key or "")
with _client_cache_lock:
if cache_key in _client_cache: if cache_key in _client_cache:
cached_client, cached_default = _client_cache[cache_key] cached_client, cached_default = _client_cache[cache_key]
return cached_client, model or cached_default return cached_client, model or cached_default
# Build outside the lock
client, default_model = resolve_provider_client( client, default_model = resolve_provider_client(
provider, provider,
model, model,
@ -1193,7 +1197,11 @@ def _get_cached_client(
explicit_api_key=api_key, explicit_api_key=api_key,
) )
if client is not None: if client is not None:
with _client_cache_lock:
if cache_key not in _client_cache:
_client_cache[cache_key] = (client, default_model) _client_cache[cache_key] = (client, default_model)
else:
client, default_model = _client_cache[cache_key]
return client, model or default_model return client, model or default_model

11
cli.py
View file

@ -3652,8 +3652,17 @@ class HermesCLI:
self.console.print(f"[bold red]Quick command error: {e}[/]") self.console.print(f"[bold red]Quick command error: {e}[/]")
else: else:
self.console.print(f"[bold red]Quick command '{base_cmd}' has no command defined[/]") self.console.print(f"[bold red]Quick command '{base_cmd}' has no command defined[/]")
elif qcmd.get("type") == "alias":
target = qcmd.get("target", "").strip()
if target:
target = target if target.startswith("/") else f"/{target}"
user_args = cmd_original[len(base_cmd):].strip()
aliased_command = f"{target} {user_args}".strip()
return self.process_command(aliased_command)
else: else:
self.console.print(f"[bold red]Quick command '{base_cmd}' has unsupported type (only 'exec' is supported)[/]") self.console.print(f"[bold red]Quick command '{base_cmd}' has no target defined[/]")
else:
self.console.print(f"[bold red]Quick command '{base_cmd}' has unsupported type (supported: 'exec', 'alias')[/]")
# Check for skill slash commands (/gif-search, /axolotl, etc.) # Check for skill slash commands (/gif-search, /axolotl, etc.)
elif base_cmd in _skill_commands: elif base_cmd in _skill_commands:
user_instruction = cmd_original[len(base_cmd):].strip() user_instruction = cmd_original[len(base_cmd):].strip()

View file

@ -1421,8 +1421,19 @@ class GatewayRunner:
return f"Quick command error: {e}" return f"Quick command error: {e}"
else: else:
return f"Quick command '/{command}' has no command defined." return f"Quick command '/{command}' has no command defined."
elif qcmd.get("type") == "alias":
target = qcmd.get("target", "").strip()
if target:
target = target if target.startswith("/") else f"/{target}"
target_command = target.lstrip("/")
user_args = event.get_command_args().strip()
event.text = f"{target} {user_args}".strip()
command = target_command
# Fall through to normal command dispatch below
else: else:
return f"Quick command '/{command}' has unsupported type (only 'exec' is supported)." return f"Quick command '/{command}' has no target defined."
else:
return f"Quick command '/{command}' has unsupported type (supported: 'exec', 'alias')."
# Skill slash commands: /skill-name loads the skill and sends to agent # Skill slash commands: /skill-name loads the skill and sends to agent
if command: if command:

View file

@ -18,6 +18,7 @@ import json
import os import os
import re import re
import sqlite3 import sqlite3
import threading
import time import time
from pathlib import Path from pathlib import Path
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
@ -104,6 +105,7 @@ class SessionDB:
self.db_path = db_path or DEFAULT_DB_PATH self.db_path = db_path or DEFAULT_DB_PATH
self.db_path.parent.mkdir(parents=True, exist_ok=True) self.db_path.parent.mkdir(parents=True, exist_ok=True)
self._lock = threading.Lock()
self._conn = sqlite3.connect( self._conn = sqlite3.connect(
str(self.db_path), str(self.db_path),
check_same_thread=False, check_same_thread=False,
@ -173,6 +175,7 @@ class SessionDB:
def close(self): def close(self):
"""Close the database connection.""" """Close the database connection."""
with self._lock:
if self._conn: if self._conn:
self._conn.close() self._conn.close()
self._conn = None self._conn = None
@ -192,6 +195,7 @@ class SessionDB:
parent_session_id: str = None, parent_session_id: str = None,
) -> str: ) -> str:
"""Create a new session record. Returns the session_id.""" """Create a new session record. Returns the session_id."""
with self._lock:
self._conn.execute( self._conn.execute(
"""INSERT INTO sessions (id, source, user_id, model, model_config, """INSERT INTO sessions (id, source, user_id, model, model_config,
system_prompt, parent_session_id, started_at) system_prompt, parent_session_id, started_at)
@ -212,6 +216,7 @@ class SessionDB:
def end_session(self, session_id: str, end_reason: str) -> None: def end_session(self, session_id: str, end_reason: str) -> None:
"""Mark a session as ended.""" """Mark a session as ended."""
with self._lock:
self._conn.execute( self._conn.execute(
"UPDATE sessions SET ended_at = ?, end_reason = ? WHERE id = ?", "UPDATE sessions SET ended_at = ?, end_reason = ? WHERE id = ?",
(time.time(), end_reason, session_id), (time.time(), end_reason, session_id),
@ -220,6 +225,7 @@ class SessionDB:
def update_system_prompt(self, session_id: str, system_prompt: str) -> None: def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
"""Store the full assembled system prompt snapshot.""" """Store the full assembled system prompt snapshot."""
with self._lock:
self._conn.execute( self._conn.execute(
"UPDATE sessions SET system_prompt = ? WHERE id = ?", "UPDATE sessions SET system_prompt = ? WHERE id = ?",
(system_prompt, session_id), (system_prompt, session_id),
@ -231,6 +237,7 @@ class SessionDB:
model: str = None, model: str = None,
) -> None: ) -> None:
"""Increment token counters and backfill model if not already set.""" """Increment token counters and backfill model if not already set."""
with self._lock:
self._conn.execute( self._conn.execute(
"""UPDATE sessions SET """UPDATE sessions SET
input_tokens = input_tokens + ?, input_tokens = input_tokens + ?,
@ -243,6 +250,7 @@ class SessionDB:
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]: def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
"""Get a session by ID.""" """Get a session by ID."""
with self._lock:
cursor = self._conn.execute( cursor = self._conn.execute(
"SELECT * FROM sessions WHERE id = ?", (session_id,) "SELECT * FROM sessions WHERE id = ?", (session_id,)
) )
@ -331,6 +339,7 @@ class SessionDB:
Empty/whitespace-only strings are normalized to None (clearing the title). Empty/whitespace-only strings are normalized to None (clearing the title).
""" """
title = self.sanitize_title(title) title = self.sanitize_title(title)
with self._lock:
if title: if title:
# Check uniqueness (allow the same session to keep its own title) # Check uniqueness (allow the same session to keep its own title)
cursor = self._conn.execute( cursor = self._conn.execute(
@ -347,10 +356,12 @@ class SessionDB:
(title, session_id), (title, session_id),
) )
self._conn.commit() self._conn.commit()
return cursor.rowcount > 0 rowcount = cursor.rowcount
return rowcount > 0
def get_session_title(self, session_id: str) -> Optional[str]: def get_session_title(self, session_id: str) -> Optional[str]:
"""Get the title for a session, or None.""" """Get the title for a session, or None."""
with self._lock:
cursor = self._conn.execute( cursor = self._conn.execute(
"SELECT title FROM sessions WHERE id = ?", (session_id,) "SELECT title FROM sessions WHERE id = ?", (session_id,)
) )
@ -359,6 +370,7 @@ class SessionDB:
def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]: def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]:
"""Look up a session by exact title. Returns session dict or None.""" """Look up a session by exact title. Returns session dict or None."""
with self._lock:
cursor = self._conn.execute( cursor = self._conn.execute(
"SELECT * FROM sessions WHERE title = ?", (title,) "SELECT * FROM sessions WHERE title = ?", (title,)
) )
@ -379,6 +391,7 @@ class SessionDB:
# Also search for numbered variants: "title #2", "title #3", etc. # Also search for numbered variants: "title #2", "title #3", etc.
# Escape SQL LIKE wildcards (%, _) in the title to prevent false matches # Escape SQL LIKE wildcards (%, _) in the title to prevent false matches
escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
with self._lock:
cursor = self._conn.execute( cursor = self._conn.execute(
"SELECT id, title, started_at FROM sessions " "SELECT id, title, started_at FROM sessions "
"WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC", "WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC",
@ -409,6 +422,7 @@ class SessionDB:
# Find all existing numbered variants # Find all existing numbered variants
# Escape SQL LIKE wildcards (%, _) in the base to prevent false matches # Escape SQL LIKE wildcards (%, _) in the base to prevent false matches
escaped = base.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") escaped = base.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
with self._lock:
cursor = self._conn.execute( cursor = self._conn.execute(
"SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'", "SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'",
(base, f"{escaped} #%"), (base, f"{escaped} #%"),
@ -461,9 +475,11 @@ class SessionDB:
LIMIT ? OFFSET ? LIMIT ? OFFSET ?
""" """
params = (source, limit, offset) if source else (limit, offset) params = (source, limit, offset) if source else (limit, offset)
with self._lock:
cursor = self._conn.execute(query, params) cursor = self._conn.execute(query, params)
rows = cursor.fetchall()
sessions = [] sessions = []
for row in cursor.fetchall(): for row in rows:
s = dict(row) s = dict(row)
# Build the preview from the raw substring # Build the preview from the raw substring
raw = s.pop("_preview_raw", "").strip() raw = s.pop("_preview_raw", "").strip()
@ -497,6 +513,7 @@ class SessionDB:
Also increments the session's message_count (and tool_call_count Also increments the session's message_count (and tool_call_count
if role is 'tool' or tool_calls is present). if role is 'tool' or tool_calls is present).
""" """
with self._lock:
cursor = self._conn.execute( cursor = self._conn.execute(
"""INSERT INTO messages (session_id, role, content, tool_call_id, """INSERT INTO messages (session_id, role, content, tool_call_id,
tool_calls, tool_name, timestamp, token_count, finish_reason) tool_calls, tool_name, timestamp, token_count, finish_reason)
@ -538,6 +555,7 @@ class SessionDB:
def get_messages(self, session_id: str) -> List[Dict[str, Any]]: def get_messages(self, session_id: str) -> List[Dict[str, Any]]:
"""Load all messages for a session, ordered by timestamp.""" """Load all messages for a session, ordered by timestamp."""
with self._lock:
cursor = self._conn.execute( cursor = self._conn.execute(
"SELECT * FROM messages WHERE session_id = ? ORDER BY timestamp, id", "SELECT * FROM messages WHERE session_id = ? ORDER BY timestamp, id",
(session_id,), (session_id,),
@ -559,13 +577,15 @@ class SessionDB:
Load messages in the OpenAI conversation format (role + content dicts). Load messages in the OpenAI conversation format (role + content dicts).
Used by the gateway to restore conversation history. Used by the gateway to restore conversation history.
""" """
with self._lock:
cursor = self._conn.execute( cursor = self._conn.execute(
"SELECT role, content, tool_call_id, tool_calls, tool_name " "SELECT role, content, tool_call_id, tool_calls, tool_name "
"FROM messages WHERE session_id = ? ORDER BY timestamp, id", "FROM messages WHERE session_id = ? ORDER BY timestamp, id",
(session_id,), (session_id,),
) )
rows = cursor.fetchall()
messages = [] messages = []
for row in cursor.fetchall(): for row in rows:
msg = {"role": row["role"], "content": row["content"]} msg = {"role": row["role"], "content": row["content"]}
if row["tool_call_id"]: if row["tool_call_id"]:
msg["tool_call_id"] = row["tool_call_id"] msg["tool_call_id"] = row["tool_call_id"]
@ -675,6 +695,7 @@ class SessionDB:
LIMIT ? OFFSET ? LIMIT ? OFFSET ?
""" """
with self._lock:
try: try:
cursor = self._conn.execute(sql, params) cursor = self._conn.execute(sql, params)
except sqlite3.OperationalError: except sqlite3.OperationalError:
@ -700,6 +721,7 @@ class SessionDB:
match["context"] = [] match["context"] = []
# Remove full content from result (snippet is enough, saves tokens) # Remove full content from result (snippet is enough, saves tokens)
for match in matches:
match.pop("content", None) match.pop("content", None)
return matches return matches

View file

@ -69,6 +69,8 @@ class HonchoClientConfig:
workspace_id: str = "hermes" workspace_id: str = "hermes"
api_key: str | None = None api_key: str | None = None
environment: str = "production" environment: str = "production"
# Optional base URL for self-hosted Honcho (overrides environment mapping)
base_url: str | None = None
# Identity # Identity
peer_name: str | None = None peer_name: str | None = None
ai_peer: str = "hermes" ai_peer: str = "hermes"
@ -361,13 +363,34 @@ def get_honcho_client(config: HonchoClientConfig | None = None) -> Honcho:
"Install it with: pip install honcho-ai" "Install it with: pip install honcho-ai"
) )
# Allow config.yaml honcho.base_url to override the SDK's environment
# mapping, enabling remote self-hosted Honcho deployments without
# requiring the server to live on localhost.
resolved_base_url = config.base_url
if not resolved_base_url:
try:
from hermes_cli.config import load_config
hermes_cfg = load_config()
honcho_cfg = hermes_cfg.get("honcho", {})
if isinstance(honcho_cfg, dict):
resolved_base_url = honcho_cfg.get("base_url", "").strip() or None
except Exception:
pass
if resolved_base_url:
logger.info("Initializing Honcho client (base_url: %s, workspace: %s)", resolved_base_url, config.workspace_id)
else:
logger.info("Initializing Honcho client (host: %s, workspace: %s)", config.host, config.workspace_id) logger.info("Initializing Honcho client (host: %s, workspace: %s)", config.host, config.workspace_id)
_honcho_client = Honcho( kwargs: dict = {
workspace_id=config.workspace_id, "workspace_id": config.workspace_id,
api_key=config.api_key, "api_key": config.api_key,
environment=config.environment, "environment": config.environment,
) }
if resolved_base_url:
kwargs["base_url"] = resolved_base_url
_honcho_client = Honcho(**kwargs)
return _honcho_client return _honcho_client

View file

@ -407,6 +407,7 @@ class AIAgent:
# Subagent delegation state # Subagent delegation state
self._delegate_depth = 0 # 0 = top-level agent, incremented for children self._delegate_depth = 0 # 0 = top-level agent, incremented for children
self._active_children = [] # Running child AIAgents (for interrupt propagation) self._active_children = [] # Running child AIAgents (for interrupt propagation)
self._active_children_lock = threading.Lock()
# Store OpenRouter provider preferences # Store OpenRouter provider preferences
self.providers_allowed = providers_allowed self.providers_allowed = providers_allowed
@ -1526,7 +1527,9 @@ class AIAgent:
# Signal all tools to abort any in-flight operations immediately # Signal all tools to abort any in-flight operations immediately
_set_interrupt(True) _set_interrupt(True)
# Propagate interrupt to any running child agents (subagent delegation) # Propagate interrupt to any running child agents (subagent delegation)
for child in self._active_children: with self._active_children_lock:
children_copy = list(self._active_children)
for child in children_copy:
try: try:
child.interrupt(message) child.interrupt(message)
except Exception as e: except Exception as e:

View file

@ -24,6 +24,7 @@ def main() -> int:
parent._interrupt_requested = False parent._interrupt_requested = False
parent._interrupt_message = None parent._interrupt_message = None
parent._active_children = [] parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True parent.quiet_mode = True
parent.model = "test/model" parent.model = "test/model"
parent.base_url = "http://localhost:1" parent.base_url = "http://localhost:1"

View file

@ -43,6 +43,7 @@ class TestCLISubagentInterrupt(unittest.TestCase):
parent._interrupt_requested = False parent._interrupt_requested = False
parent._interrupt_message = None parent._interrupt_message = None
parent._active_children = [] parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True parent.quiet_mode = True
parent.model = "test/model" parent.model = "test/model"
parent.base_url = "http://localhost:1" parent.base_url = "http://localhost:1"
@ -112,21 +113,21 @@ class TestCLISubagentInterrupt(unittest.TestCase):
mock_instance._interrupt_requested = False mock_instance._interrupt_requested = False
mock_instance._interrupt_message = None mock_instance._interrupt_message = None
mock_instance._active_children = [] mock_instance._active_children = []
mock_instance._active_children_lock = threading.Lock()
mock_instance.quiet_mode = True mock_instance.quiet_mode = True
mock_instance.run_conversation = mock_child_run_conversation mock_instance.run_conversation = mock_child_run_conversation
mock_instance.interrupt = lambda msg=None: setattr(mock_instance, '_interrupt_requested', True) or setattr(mock_instance, '_interrupt_message', msg) mock_instance.interrupt = lambda msg=None: setattr(mock_instance, '_interrupt_requested', True) or setattr(mock_instance, '_interrupt_message', msg)
mock_instance.tools = [] mock_instance.tools = []
MockAgent.return_value = mock_instance MockAgent.return_value = mock_instance
# Register child manually (normally done by _build_child_agent)
parent._active_children.append(mock_instance)
result = _run_single_child( result = _run_single_child(
task_index=0, task_index=0,
goal="Do something slow", goal="Do something slow",
context=None, child=mock_instance,
toolsets=["terminal"],
model=None,
max_iterations=50,
parent_agent=parent, parent_agent=parent,
task_count=1,
) )
delegate_result[0] = result delegate_result[0] = result
except Exception as e: except Exception as e:

View file

@ -57,6 +57,7 @@ def main() -> int:
parent._interrupt_requested = False parent._interrupt_requested = False
parent._interrupt_message = None parent._interrupt_message = None
parent._active_children = [] parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True parent.quiet_mode = True
parent.model = "test/model" parent.model = "test/model"
parent.base_url = "http://localhost:1" parent.base_url = "http://localhost:1"

View file

@ -30,12 +30,14 @@ class TestInterruptPropagationToChild(unittest.TestCase):
parent._interrupt_requested = False parent._interrupt_requested = False
parent._interrupt_message = None parent._interrupt_message = None
parent._active_children = [] parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True parent.quiet_mode = True
child = AIAgent.__new__(AIAgent) child = AIAgent.__new__(AIAgent)
child._interrupt_requested = False child._interrupt_requested = False
child._interrupt_message = None child._interrupt_message = None
child._active_children = [] child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True child.quiet_mode = True
parent._active_children.append(child) parent._active_children.append(child)
@ -60,6 +62,7 @@ class TestInterruptPropagationToChild(unittest.TestCase):
child._interrupt_message = "msg" child._interrupt_message = "msg"
child.quiet_mode = True child.quiet_mode = True
child._active_children = [] child._active_children = []
child._active_children_lock = threading.Lock()
# Global is set # Global is set
set_interrupt(True) set_interrupt(True)
@ -78,6 +81,7 @@ class TestInterruptPropagationToChild(unittest.TestCase):
child._interrupt_requested = False child._interrupt_requested = False
child._interrupt_message = None child._interrupt_message = None
child._active_children = [] child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True child.quiet_mode = True
child.api_mode = "chat_completions" child.api_mode = "chat_completions"
child.log_prefix = "" child.log_prefix = ""
@ -119,12 +123,14 @@ class TestInterruptPropagationToChild(unittest.TestCase):
parent._interrupt_requested = False parent._interrupt_requested = False
parent._interrupt_message = None parent._interrupt_message = None
parent._active_children = [] parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True parent.quiet_mode = True
child = AIAgent.__new__(AIAgent) child = AIAgent.__new__(AIAgent)
child._interrupt_requested = False child._interrupt_requested = False
child._interrupt_message = None child._interrupt_message = None
child._active_children = [] child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True child.quiet_mode = True
# Register child (simulating what _run_single_child does) # Register child (simulating what _run_single_child does)

View file

@ -47,6 +47,28 @@ class TestCLIQuickCommands:
args = cli.console.print.call_args[0][0] args = cli.console.print.call_args[0][0]
assert "no output" in args.lower() assert "no output" in args.lower()
def test_alias_command_routes_to_target(self):
"""Alias quick commands rewrite to the target command."""
cli = self._make_cli({"shortcut": {"type": "alias", "target": "/help"}})
with patch.object(cli, "process_command", wraps=cli.process_command) as spy:
cli.process_command("/shortcut")
# Should recursively call process_command with /help
spy.assert_any_call("/help")
def test_alias_command_passes_args(self):
"""Alias quick commands forward user arguments to the target."""
cli = self._make_cli({"sc": {"type": "alias", "target": "/context"}})
with patch.object(cli, "process_command", wraps=cli.process_command) as spy:
cli.process_command("/sc some args")
spy.assert_any_call("/context some args")
def test_alias_no_target_shows_error(self):
cli = self._make_cli({"broken": {"type": "alias", "target": ""}})
cli.process_command("/broken")
cli.console.print.assert_called_once()
args = cli.console.print.call_args[0][0]
assert "no target defined" in args.lower()
def test_unsupported_type_shows_error(self): def test_unsupported_type_shows_error(self):
cli = self._make_cli({"bad": {"type": "prompt", "command": "echo hi"}}) cli = self._make_cli({"bad": {"type": "prompt", "command": "echo hi"}})
cli.process_command("/bad") cli.process_command("/bad")

View file

@ -55,6 +55,7 @@ class TestRealSubagentInterrupt(unittest.TestCase):
parent._interrupt_requested = False parent._interrupt_requested = False
parent._interrupt_message = None parent._interrupt_message = None
parent._active_children = [] parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True parent.quiet_mode = True
parent.model = "test/model" parent.model = "test/model"
parent.base_url = "http://localhost:1" parent.base_url = "http://localhost:1"
@ -103,19 +104,28 @@ class TestRealSubagentInterrupt(unittest.TestCase):
return original_run(self_agent, *args, **kwargs) return original_run(self_agent, *args, **kwargs)
with patch.object(AIAgent, 'run_conversation', patched_run): with patch.object(AIAgent, 'run_conversation', patched_run):
# Build a real child agent (AIAgent is NOT patched here,
# only run_conversation and _build_system_prompt are)
child = AIAgent(
base_url="http://localhost:1",
api_key="test-key",
model="test/model",
provider="test",
api_mode="chat_completions",
max_iterations=5,
enabled_toolsets=["terminal"],
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
platform="cli",
)
child._delegate_depth = 1
parent._active_children.append(child)
result = _run_single_child( result = _run_single_child(
task_index=0, task_index=0,
goal="Test task", goal="Test task",
context=None, child=child,
toolsets=["terminal"],
model="test/model",
max_iterations=5,
parent_agent=parent, parent_agent=parent,
task_count=1,
override_provider="test",
override_base_url="http://localhost:1",
override_api_key="test",
override_api_mode="chat_completions",
) )
result_holder[0] = result result_holder[0] = result
except Exception as e: except Exception as e:

View file

@ -12,6 +12,7 @@ Run with: python -m pytest tests/test_delegate.py -v
import json import json
import os import os
import sys import sys
import threading
import unittest import unittest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@ -44,6 +45,7 @@ def _make_mock_parent(depth=0):
parent._session_db = None parent._session_db = None
parent._delegate_depth = depth parent._delegate_depth = depth
parent._active_children = [] parent._active_children = []
parent._active_children_lock = threading.Lock()
return parent return parent
@ -722,7 +724,12 @@ class TestDelegationProviderIntegration(unittest.TestCase):
} }
parent = _make_mock_parent(depth=0) parent = _make_mock_parent(depth=0)
with patch("tools.delegate_tool._run_single_child") as mock_run: # Patch _build_child_agent since credentials are now passed there
# (agents are built in the main thread before being handed to workers)
with patch("tools.delegate_tool._build_child_agent") as mock_build, \
patch("tools.delegate_tool._run_single_child") as mock_run:
mock_child = MagicMock()
mock_build.return_value = mock_child
mock_run.return_value = { mock_run.return_value = {
"task_index": 0, "status": "completed", "task_index": 0, "status": "completed",
"summary": "Done", "api_calls": 1, "duration_seconds": 1.0 "summary": "Done", "api_calls": 1, "duration_seconds": 1.0
@ -731,7 +738,8 @@ class TestDelegationProviderIntegration(unittest.TestCase):
tasks = [{"goal": "Task A"}, {"goal": "Task B"}] tasks = [{"goal": "Task A"}, {"goal": "Task B"}]
delegate_task(tasks=tasks, parent_agent=parent) delegate_task(tasks=tasks, parent_agent=parent)
for call in mock_run.call_args_list: self.assertEqual(mock_build.call_count, 2)
for call in mock_build.call_args_list:
self.assertEqual(call.kwargs.get("model"), "meta-llama/llama-4-scout") self.assertEqual(call.kwargs.get("model"), "meta-llama/llama-4-scout")
self.assertEqual(call.kwargs.get("override_provider"), "openrouter") self.assertEqual(call.kwargs.get("override_provider"), "openrouter")
self.assertEqual(call.kwargs.get("override_base_url"), "https://openrouter.ai/api/v1") self.assertEqual(call.kwargs.get("override_base_url"), "https://openrouter.ai/api/v1")

View file

@ -16,13 +16,10 @@ The parent's context only sees the delegation call and the summary result,
never the child's intermediate tool calls or reasoning. never the child's intermediate tool calls or reasoning.
""" """
import contextlib
import io
import json import json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
import os import os
import sys
import time import time
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@ -150,7 +147,7 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
return _callback return _callback
def _run_single_child( def _build_child_agent(
task_index: int, task_index: int,
goal: str, goal: str,
context: Optional[str], context: Optional[str],
@ -158,16 +155,15 @@ def _run_single_child(
model: Optional[str], model: Optional[str],
max_iterations: int, max_iterations: int,
parent_agent, parent_agent,
task_count: int = 1,
# Credential overrides from delegation config (provider:model resolution) # Credential overrides from delegation config (provider:model resolution)
override_provider: Optional[str] = None, override_provider: Optional[str] = None,
override_base_url: Optional[str] = None, override_base_url: Optional[str] = None,
override_api_key: Optional[str] = None, override_api_key: Optional[str] = None,
override_api_mode: Optional[str] = None, override_api_mode: Optional[str] = None,
) -> Dict[str, Any]: ):
""" """
Spawn and run a single child agent. Called from within a thread. Build a child AIAgent on the main thread (thread-safe construction).
Returns a structured result dict. Returns the constructed child agent without running it.
When override_* params are set (from delegation config), the child uses When override_* params are set (from delegation config), the child uses
those credentials instead of inheriting from the parent. This enables those credentials instead of inheriting from the parent. This enables
@ -176,8 +172,6 @@ def _run_single_child(
""" """
from run_agent import AIAgent from run_agent import AIAgent
child_start = time.monotonic()
# When no explicit toolsets given, inherit from parent's enabled toolsets # When no explicit toolsets given, inherit from parent's enabled toolsets
# so disabled tools (e.g. web) don't leak to subagents. # so disabled tools (e.g. web) don't leak to subagents.
if toolsets: if toolsets:
@ -188,15 +182,13 @@ def _run_single_child(
child_toolsets = _strip_blocked_tools(DEFAULT_TOOLSETS) child_toolsets = _strip_blocked_tools(DEFAULT_TOOLSETS)
child_prompt = _build_child_system_prompt(goal, context) child_prompt = _build_child_system_prompt(goal, context)
try:
# Extract parent's API key so subagents inherit auth (e.g. Nous Portal). # Extract parent's API key so subagents inherit auth (e.g. Nous Portal).
parent_api_key = getattr(parent_agent, "api_key", None) parent_api_key = getattr(parent_agent, "api_key", None)
if (not parent_api_key) and hasattr(parent_agent, "_client_kwargs"): if (not parent_api_key) and hasattr(parent_agent, "_client_kwargs"):
parent_api_key = parent_agent._client_kwargs.get("api_key") parent_api_key = parent_agent._client_kwargs.get("api_key")
# Build progress callback to relay tool calls to parent display # Build progress callback to relay tool calls to parent display
child_progress_cb = _build_child_progress_callback(task_index, parent_agent, task_count) child_progress_cb = _build_child_progress_callback(task_index, parent_agent)
# Share the parent's iteration budget so subagent tool calls # Share the parent's iteration budget so subagent tool calls
# count toward the session-wide limit. # count toward the session-wide limit.
@ -241,11 +233,32 @@ def _run_single_child(
# Register child for interrupt propagation # Register child for interrupt propagation
if hasattr(parent_agent, '_active_children'): if hasattr(parent_agent, '_active_children'):
lock = getattr(parent_agent, '_active_children_lock', None)
if lock:
with lock:
parent_agent._active_children.append(child)
else:
parent_agent._active_children.append(child) parent_agent._active_children.append(child)
# Run with stdout/stderr suppressed to prevent interleaved output return child
devnull = io.StringIO()
with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull): def _run_single_child(
task_index: int,
goal: str,
child=None,
parent_agent=None,
**_kwargs,
) -> Dict[str, Any]:
"""
Run a pre-built child agent. Called from within a thread.
Returns a structured result dict.
"""
child_start = time.monotonic()
# Get the progress callback from the child agent
child_progress_cb = getattr(child, 'tool_progress_callback', None)
try:
result = child.run_conversation(user_message=goal) result = child.run_conversation(user_message=goal)
# Flush any remaining batched progress to gateway # Flush any remaining batched progress to gateway
@ -355,11 +368,15 @@ def _run_single_child(
# Unregister child from interrupt propagation # Unregister child from interrupt propagation
if hasattr(parent_agent, '_active_children'): if hasattr(parent_agent, '_active_children'):
try: try:
lock = getattr(parent_agent, '_active_children_lock', None)
if lock:
with lock:
parent_agent._active_children.remove(child)
else:
parent_agent._active_children.remove(child) parent_agent._active_children.remove(child)
except (ValueError, UnboundLocalError) as e: except (ValueError, UnboundLocalError) as e:
logger.debug("Could not remove child from active_children: %s", e) logger.debug("Could not remove child from active_children: %s", e)
def delegate_task( def delegate_task(
goal: Optional[str] = None, goal: Optional[str] = None,
context: Optional[str] = None, context: Optional[str] = None,
@ -428,51 +445,38 @@ def delegate_task(
# Track goal labels for progress display (truncated for readability) # Track goal labels for progress display (truncated for readability)
task_labels = [t["goal"][:40] for t in task_list] task_labels = [t["goal"][:40] for t in task_list]
if n_tasks == 1: # Build all child agents on the main thread (thread-safe construction)
# Single task -- run directly (no thread pool overhead) children = []
t = task_list[0] for i, t in enumerate(task_list):
result = _run_single_child( child = _build_child_agent(
task_index=0, task_index=i, goal=t["goal"], context=t.get("context"),
goal=t["goal"], toolsets=t.get("toolsets") or toolsets, model=creds["model"],
context=t.get("context"), max_iterations=effective_max_iter, parent_agent=parent_agent,
toolsets=t.get("toolsets") or toolsets, override_provider=creds["provider"], override_base_url=creds["base_url"],
model=creds["model"],
max_iterations=effective_max_iter,
parent_agent=parent_agent,
task_count=1,
override_provider=creds["provider"],
override_base_url=creds["base_url"],
override_api_key=creds["api_key"], override_api_key=creds["api_key"],
override_api_mode=creds["api_mode"], override_api_mode=creds["api_mode"],
) )
children.append((i, t, child))
if n_tasks == 1:
# Single task -- run directly (no thread pool overhead)
_i, _t, child = children[0]
result = _run_single_child(0, _t["goal"], child, parent_agent)
results.append(result) results.append(result)
else: else:
# Batch -- run in parallel with per-task progress lines # Batch -- run in parallel with per-task progress lines
completed_count = 0 completed_count = 0
spinner_ref = getattr(parent_agent, '_delegate_spinner', None) spinner_ref = getattr(parent_agent, '_delegate_spinner', None)
# Save stdout/stderr before the executor — redirect_stdout in child
# threads races on sys.stdout and can leave it as devnull permanently.
_saved_stdout = sys.stdout
_saved_stderr = sys.stderr
with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CHILDREN) as executor: with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CHILDREN) as executor:
futures = {} futures = {}
for i, t in enumerate(task_list): for i, t, child in children:
future = executor.submit( future = executor.submit(
_run_single_child, _run_single_child,
task_index=i, task_index=i,
goal=t["goal"], goal=t["goal"],
context=t.get("context"), child=child,
toolsets=t.get("toolsets") or toolsets,
model=creds["model"],
max_iterations=effective_max_iter,
parent_agent=parent_agent, parent_agent=parent_agent,
task_count=n_tasks,
override_provider=creds["provider"],
override_base_url=creds["base_url"],
override_api_key=creds["api_key"],
override_api_mode=creds["api_mode"],
) )
futures[future] = i futures[future] = i
@ -515,10 +519,6 @@ def delegate_task(
except Exception as e: except Exception as e:
logger.debug("Spinner update_text failed: %s", e) logger.debug("Spinner update_text failed: %s", e)
# Restore stdout/stderr in case redirect_stdout race left them as devnull
sys.stdout = _saved_stdout
sys.stderr = _saved_stderr
# Sort by task_index so results match input order # Sort by task_index so results match input order
results.sort(key=lambda r: r["task_index"]) results.sort(key=lambda r: r["task_index"])