merge: resolve conflicts with main (URL update to hermes-agent.nousresearch.com)
This commit is contained in:
commit
e976879cf2
29 changed files with 3110 additions and 115 deletions
124
tests/gateway/test_interrupt_key_match.py
Normal file
124
tests/gateway/test_interrupt_key_match.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
"""Tests verifying interrupt key consistency between adapter and gateway.
|
||||
|
||||
Regression test for a bug where monitor_for_interrupt() in _run_agent used
|
||||
source.chat_id to query the adapter, but the adapter stores interrupts under
|
||||
the full session key (build_session_key output). This mismatch meant
|
||||
interrupts were never detected, causing subagents to ignore new messages.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
class StubAdapter(BasePlatformAdapter):
    """No-op Telegram adapter stub used to exercise interrupt bookkeeping."""

    def __init__(self):
        # A minimal enabled config is all the base class needs.
        stub_config = PlatformConfig(enabled=True, token="test")
        super().__init__(stub_config, Platform.TELEGRAM)

    async def connect(self):
        # Pretend the connection always succeeds.
        return True

    async def disconnect(self):
        pass  # nothing to tear down

    async def send(self, chat_id, content, reply_to=None, metadata=None):
        # Report success without talking to any real platform.
        return SendResult(success=True, message_id="1")

    async def send_typing(self, chat_id, metadata=None):
        pass  # typing indicators are irrelevant for these tests

    async def get_chat_info(self, chat_id):
        return {"id": chat_id}
|
||||
|
||||
|
||||
def _source(chat_id="123456", chat_type="dm", thread_id=None):
    """Build a Telegram SessionSource with overridable chat fields."""
    fields = {
        "platform": Platform.TELEGRAM,
        "chat_id": chat_id,
        "chat_type": chat_type,
        "thread_id": thread_id,
    }
    return SessionSource(**fields)
|
||||
|
||||
|
||||
class TestInterruptKeyConsistency:
    """Ensure adapter interrupt methods are keyed by session_key, not chat_id."""

    def test_session_key_differs_from_chat_id_for_dm(self):
        """A DM's session key is a namespaced string, not the bare chat_id."""
        src = _source("123456", "dm")
        key = build_session_key(src)
        assert key != src.chat_id
        assert key == "agent:main:telegram:dm"

    def test_session_key_differs_from_chat_id_for_group(self):
        """A group's session key embeds the chat_id under the agent prefix."""
        src = _source("-1001234", "group")
        key = build_session_key(src)
        assert key != src.chat_id
        assert "agent:main:" in key
        assert src.chat_id in key

    @pytest.mark.asyncio
    async def test_has_pending_interrupt_requires_session_key(self):
        """has_pending_interrupt only fires when queried with the session key."""
        adapter = StubAdapter()
        src = _source("123456", "dm")
        key = build_session_key(src)

        # Store an already-set interrupt under the session key, exactly as
        # the adapter itself would.
        flag = asyncio.Event()
        adapter._active_sessions[key] = flag
        flag.set()

        # Querying with the session key finds it ...
        assert adapter.has_pending_interrupt(key) is True
        # ... while querying with the raw chat_id does not (the old bug).
        assert adapter.has_pending_interrupt(src.chat_id) is False

    @pytest.mark.asyncio
    async def test_get_pending_message_requires_session_key(self):
        """get_pending_message only returns the event for the session key."""
        adapter = StubAdapter()
        src = _source("123456", "dm")
        key = build_session_key(src)

        pending = MessageEvent(text="hello", source=src, message_id="42")
        adapter._pending_messages[key] = pending

        # chat_id lookup misses (this was the bug) ...
        assert adapter.get_pending_message(src.chat_id) is None
        # ... session-key lookup hits.
        assert adapter.get_pending_message(key) is pending

    @pytest.mark.asyncio
    async def test_handle_message_stores_under_session_key(self):
        """handle_message files pending messages under session_key, not chat_id."""
        adapter = StubAdapter()
        adapter.set_message_handler(lambda event: asyncio.sleep(0, result=None))

        src = _source("-1001234", "group")
        key = build_session_key(src)

        # Pretend a session is already running for this chat.
        adapter._active_sessions[key] = asyncio.Event()

        # A second message arriving mid-session becomes an interrupt.
        incoming = MessageEvent(text="interrupt!", source=src, message_id="2")
        await adapter.handle_message(incoming)

        # Filed under the session key, never under the bare chat_id.
        assert key in adapter._pending_messages
        assert src.chat_id not in adapter._pending_messages
        # And the session's interrupt flag was raised.
        assert adapter._active_sessions[key].is_set()
|
||||
|
|
@ -530,3 +530,277 @@ class TestMessageRouting:
|
|||
}
|
||||
await adapter._handle_slack_message(event)
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestFormatMessage — Markdown → mrkdwn conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFormatMessage:
    """Markdown to Slack mrkdwn conversion."""

    def test_bold_conversion(self, adapter):
        # Double asterisks collapse to mrkdwn's single-asterisk bold.
        assert adapter.format_message("**hello**") == "*hello*"

    def test_italic_asterisk_conversion(self, adapter):
        # Single asterisks become underscores (mrkdwn italics).
        assert adapter.format_message("*hello*") == "_hello_"

    def test_italic_underscore_preserved(self, adapter):
        # Underscore italics are already valid mrkdwn.
        assert adapter.format_message("_hello_") == "_hello_"

    def test_header_to_bold(self, adapter):
        # Headers have no mrkdwn equivalent, so they render as bold.
        assert adapter.format_message("## Section Title") == "*Section Title*"

    def test_header_with_bold_content(self, adapter):
        # **bold** inside a header must not end up double-wrapped.
        assert adapter.format_message("## **Title**") == "*Title*"

    def test_link_conversion(self, adapter):
        converted = adapter.format_message("[click here](https://example.com)")
        assert converted == "<https://example.com|click here>"

    def test_strikethrough(self, adapter):
        assert adapter.format_message("~~deleted~~") == "~deleted~"

    def test_code_block_preserved(self, adapter):
        # Fenced code passes through untouched, markdown inside included.
        snippet = "```python\nx = **not bold**\n```"
        assert adapter.format_message(snippet) == snippet

    def test_inline_code_preserved(self, adapter):
        sample = "Use `**raw**` syntax"
        assert adapter.format_message(sample) == "Use `**raw**` syntax"

    def test_mixed_content(self, adapter):
        converted = adapter.format_message("**Bold** and *italic* with `code`")
        assert "*Bold*" in converted
        assert "_italic_" in converted
        assert "`code`" in converted

    def test_empty_string(self, adapter):
        assert adapter.format_message("") == ""

    def test_none_passthrough(self, adapter):
        assert adapter.format_message(None) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestReactions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReactions:
    """Emoji reaction handling."""

    @pytest.mark.asyncio
    async def test_add_reaction_calls_api(self, adapter):
        adapter._app.client.reactions_add = AsyncMock()
        ok = await adapter._add_reaction("C123", "ts1", "eyes")
        assert ok is True
        adapter._app.client.reactions_add.assert_called_once_with(
            channel="C123", timestamp="ts1", name="eyes"
        )

    @pytest.mark.asyncio
    async def test_add_reaction_handles_error(self, adapter):
        # Slack raises e.g. "already_reacted"; the helper must swallow it.
        adapter._app.client.reactions_add = AsyncMock(side_effect=Exception("already_reacted"))
        ok = await adapter._add_reaction("C123", "ts1", "eyes")
        assert ok is False

    @pytest.mark.asyncio
    async def test_remove_reaction_calls_api(self, adapter):
        adapter._app.client.reactions_remove = AsyncMock()
        ok = await adapter._remove_reaction("C123", "ts1", "eyes")
        assert ok is True

    @pytest.mark.asyncio
    async def test_reactions_in_message_flow(self, adapter):
        """An eyes reaction is added on receipt and swapped for a check on completion."""
        adapter._app.client.reactions_add = AsyncMock()
        adapter._app.client.reactions_remove = AsyncMock()
        adapter._app.client.users_info = AsyncMock(return_value={
            "user": {"profile": {"display_name": "Tyler"}}
        })

        incoming = {
            "text": "hello",
            "user": "U_USER",
            "channel": "C123",
            "channel_type": "im",
            "ts": "1234567890.000001",
        }
        await adapter._handle_slack_message(incoming)

        adds = adapter._app.client.reactions_add.call_args_list
        removes = adapter._app.client.reactions_remove.call_args_list
        # Expected sequence: add eyes -> remove eyes -> add white_check_mark.
        assert len(adds) == 2
        assert adds[0].kwargs["name"] == "eyes"
        assert adds[1].kwargs["name"] == "white_check_mark"
        assert len(removes) == 1
        assert removes[0].kwargs["name"] == "eyes"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestUserNameResolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUserNameResolution:
    """User identity resolution via the Slack users_info API."""

    @pytest.mark.asyncio
    async def test_resolves_display_name(self, adapter):
        adapter._app.client.users_info = AsyncMock(return_value={
            "user": {"profile": {"display_name": "Tyler", "real_name": "Tyler B"}}
        })
        assert await adapter._resolve_user_name("U123") == "Tyler"

    @pytest.mark.asyncio
    async def test_falls_back_to_real_name(self, adapter):
        # Empty display_name -> fall back to real_name.
        adapter._app.client.users_info = AsyncMock(return_value={
            "user": {"profile": {"display_name": "", "real_name": "Tyler B"}}
        })
        assert await adapter._resolve_user_name("U123") == "Tyler B"

    @pytest.mark.asyncio
    async def test_caches_result(self, adapter):
        adapter._app.client.users_info = AsyncMock(return_value={
            "user": {"profile": {"display_name": "Tyler"}}
        })
        await adapter._resolve_user_name("U123")
        await adapter._resolve_user_name("U123")
        # The second lookup is served from the cache, not the API.
        assert adapter._app.client.users_info.call_count == 1

    @pytest.mark.asyncio
    async def test_handles_api_error(self, adapter):
        adapter._app.client.users_info = AsyncMock(side_effect=Exception("rate limited"))
        # On API failure the raw user_id doubles as the display name.
        assert await adapter._resolve_user_name("U123") == "U123"

    @pytest.mark.asyncio
    async def test_user_name_in_message_source(self, adapter):
        """The resolved name rides along on the MessageEvent's source."""
        adapter._app.client.users_info = AsyncMock(return_value={
            "user": {"profile": {"display_name": "Tyler"}}
        })
        adapter._app.client.reactions_add = AsyncMock()
        adapter._app.client.reactions_remove = AsyncMock()

        incoming = {
            "text": "hello",
            "user": "U_USER",
            "channel": "C123",
            "channel_type": "im",
            "ts": "1234567890.000001",
        }
        await adapter._handle_slack_message(incoming)

        # Inspect the MessageEvent handed to handle_message.
        forwarded = adapter.handle_message.call_args[0][0]
        assert forwarded.source.user_name == "Tyler"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestSlashCommands — expanded command set
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSlashCommands:
    """Slash command routing into gateway commands."""

    async def _routed_text(self, adapter, text):
        # Drive the handler with a minimal payload and return the text of
        # the message it forwarded to handle_message.
        await adapter._handle_slash_command(
            {"text": text, "user_id": "U1", "channel_id": "C1"}
        )
        return adapter.handle_message.call_args[0][0].text

    @pytest.mark.asyncio
    async def test_compact_maps_to_compress(self, adapter):
        # "compact" is an alias for the /compress command.
        assert await self._routed_text(adapter, "compact") == "/compress"

    @pytest.mark.asyncio
    async def test_resume_command(self, adapter):
        assert await self._routed_text(adapter, "resume my session") == "/resume my session"

    @pytest.mark.asyncio
    async def test_background_command(self, adapter):
        assert await self._routed_text(adapter, "background run tests") == "/background run tests"

    @pytest.mark.asyncio
    async def test_usage_command(self, adapter):
        assert await self._routed_text(adapter, "usage") == "/usage"

    @pytest.mark.asyncio
    async def test_reasoning_command(self, adapter):
        assert await self._routed_text(adapter, "reasoning") == "/reasoning"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestMessageSplitting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageSplitting:
    """Outbound messages longer than the platform limit are chunked."""

    @pytest.mark.asyncio
    async def test_long_message_split_into_chunks(self, adapter):
        """A 5000-char payload exceeds MAX_MESSAGE_LENGTH and is split."""
        adapter._app.client.chat_postMessage = AsyncMock(
            return_value={"ts": "ts1"}
        )
        await adapter.send("C123", "x" * 5000)
        # Multiple chunks imply multiple postMessage calls.
        assert adapter._app.client.chat_postMessage.call_count >= 2

    @pytest.mark.asyncio
    async def test_short_message_single_send(self, adapter):
        """Short payloads go out in exactly one API call."""
        adapter._app.client.chat_postMessage = AsyncMock(
            return_value={"ts": "ts1"}
        )
        await adapter.send("C123", "hello world")
        assert adapter._app.client.chat_postMessage.call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestReplyBroadcast
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReplyBroadcast:
    """Behavior of the reply_broadcast config flag on threaded sends."""

    @pytest.mark.asyncio
    async def test_broadcast_disabled_by_default(self, adapter):
        adapter._app.client.chat_postMessage = AsyncMock(
            return_value={"ts": "ts1"}
        )
        await adapter.send("C123", "hi", metadata={"thread_id": "parent_ts"})
        sent_kwargs = adapter._app.client.chat_postMessage.call_args.kwargs
        # The flag is omitted entirely unless enabled in config.
        assert "reply_broadcast" not in sent_kwargs

    @pytest.mark.asyncio
    async def test_broadcast_enabled_via_config(self, adapter):
        adapter.config.extra["reply_broadcast"] = True
        adapter._app.client.chat_postMessage = AsyncMock(
            return_value={"ts": "ts1"}
        )
        await adapter.send("C123", "hi", metadata={"thread_id": "parent_ts"})
        sent_kwargs = adapter._app.client.chat_postMessage.call_args.kwargs
        assert sent_kwargs.get("reply_broadcast") is True
|
||||
|
|
|
|||
340
tests/hermes_cli/test_claw.py
Normal file
340
tests/hermes_cli/test_claw.py
Normal file
|
|
@ -0,0 +1,340 @@
|
|||
"""Tests for hermes claw commands."""
|
||||
|
||||
from argparse import Namespace
|
||||
from types import ModuleType
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli import claw as claw_mod
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _find_migration_script
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindMigrationScript:
    """Script discovery across the known install locations."""

    def test_finds_project_root_script(self, tmp_path):
        root_script = tmp_path / "openclaw_to_hermes.py"
        root_script.write_text("# placeholder")
        with patch.object(claw_mod, "_OPENCLAW_SCRIPT", root_script):
            assert claw_mod._find_migration_script() == root_script

    def test_finds_installed_script(self, tmp_path):
        # Falls through to the installed copy when the repo script is absent.
        installed_script = tmp_path / "installed.py"
        installed_script.write_text("# placeholder")
        with (
            patch.object(claw_mod, "_OPENCLAW_SCRIPT", tmp_path / "nonexistent.py"),
            patch.object(claw_mod, "_OPENCLAW_SCRIPT_INSTALLED", installed_script),
        ):
            assert claw_mod._find_migration_script() == installed_script

    def test_returns_none_when_missing(self, tmp_path):
        # Neither location exists -> discovery reports None.
        with (
            patch.object(claw_mod, "_OPENCLAW_SCRIPT", tmp_path / "a.py"),
            patch.object(claw_mod, "_OPENCLAW_SCRIPT_INSTALLED", tmp_path / "b.py"),
        ):
            assert claw_mod._find_migration_script() is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# claw_command routing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClawCommand:
    """Routing in the claw_command entry point."""

    def test_routes_to_migrate(self):
        args = Namespace(claw_action="migrate", source=None, dry_run=True,
                         preset="full", overwrite=False, migrate_secrets=False,
                         workspace_target=None, skill_conflict="skip", yes=False)
        with patch.object(claw_mod, "_cmd_migrate") as migrate_stub:
            claw_mod.claw_command(args)
            migrate_stub.assert_called_once_with(args)

    def test_shows_help_for_no_action(self, capsys):
        # No subaction -> help text listing available actions is printed.
        claw_mod.claw_command(Namespace(claw_action=None))
        assert "migrate" in capsys.readouterr().out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _cmd_migrate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCmdMigrate:
    """The migrate command handler, exercised with a stubbed migration module."""

    @staticmethod
    def _args(source, **overrides):
        # Namespace mirroring the CLI parser's defaults; overridable per test.
        fields = dict(
            source=source, dry_run=True, preset="full", overwrite=False,
            migrate_secrets=False, workspace_target=None,
            skill_conflict="skip", yes=False,
        )
        fields.update(overrides)
        return Namespace(**fields)

    @staticmethod
    def _fake_module(report, selected=frozenset()):
        # Stand-in for the dynamically loaded openclaw_to_hermes module.
        module = ModuleType("openclaw_to_hermes")
        module.resolve_selected_options = MagicMock(return_value=set(selected))
        migrator = MagicMock()
        migrator.migrate.return_value = report
        module.Migrator = MagicMock(return_value=migrator)
        return module

    def test_error_when_source_missing(self, tmp_path, capsys):
        claw_mod._cmd_migrate(self._args(str(tmp_path / "nonexistent")))
        assert "not found" in capsys.readouterr().out

    def test_error_when_script_missing(self, tmp_path, capsys):
        source_dir = tmp_path / ".openclaw"
        source_dir.mkdir()
        with (
            patch.object(claw_mod, "_OPENCLAW_SCRIPT", tmp_path / "a.py"),
            patch.object(claw_mod, "_OPENCLAW_SCRIPT_INSTALLED", tmp_path / "b.py"),
        ):
            claw_mod._cmd_migrate(self._args(str(source_dir)))
        assert "Migration script not found" in capsys.readouterr().out

    def test_dry_run_succeeds(self, tmp_path, capsys):
        source_dir = tmp_path / ".openclaw"
        source_dir.mkdir()
        script = tmp_path / "script.py"
        script.write_text("# placeholder")

        module = self._fake_module(
            {
                "summary": {"migrated": 0, "skipped": 5, "conflict": 0, "error": 0},
                "items": [
                    {"kind": "soul", "status": "skipped", "reason": "Not found"},
                ],
                "preset": "full",
            },
            selected={"soul", "memory"},
        )

        with (
            patch.object(claw_mod, "_find_migration_script", return_value=script),
            patch.object(claw_mod, "_load_migration_module", return_value=module),
            patch.object(claw_mod, "get_config_path", return_value=tmp_path / "config.yaml"),
            patch.object(claw_mod, "save_config"),
            patch.object(claw_mod, "load_config", return_value={}),
        ):
            claw_mod._cmd_migrate(self._args(str(source_dir)))

        printed = capsys.readouterr().out
        assert "Dry Run Results" in printed
        assert "5 skipped" in printed

    def test_execute_with_confirmation(self, tmp_path, capsys):
        source_dir = tmp_path / ".openclaw"
        source_dir.mkdir()
        config_path = tmp_path / "config.yaml"
        config_path.write_text("agent:\n max_turns: 90\n")

        module = self._fake_module(
            {
                "summary": {"migrated": 2, "skipped": 1, "conflict": 0, "error": 0},
                "items": [
                    {"kind": "soul", "status": "migrated", "destination": str(tmp_path / "SOUL.md")},
                    {"kind": "memory", "status": "migrated", "destination": str(tmp_path / "memories/MEMORY.md")},
                ],
            },
            selected={"soul"},
        )

        with (
            patch.object(claw_mod, "_find_migration_script", return_value=tmp_path / "s.py"),
            patch.object(claw_mod, "_load_migration_module", return_value=module),
            patch.object(claw_mod, "get_config_path", return_value=config_path),
            patch.object(claw_mod, "prompt_yes_no", return_value=True),
        ):
            claw_mod._cmd_migrate(
                self._args(str(source_dir), dry_run=False, preset="user-data")
            )

        printed = capsys.readouterr().out
        assert "Migration Results" in printed
        assert "Migration complete!" in printed

    def test_execute_cancelled_by_user(self, tmp_path, capsys):
        source_dir = tmp_path / ".openclaw"
        source_dir.mkdir()
        config_path = tmp_path / "config.yaml"
        config_path.write_text("")

        with (
            patch.object(claw_mod, "_find_migration_script", return_value=tmp_path / "s.py"),
            patch.object(claw_mod, "prompt_yes_no", return_value=False),
        ):
            claw_mod._cmd_migrate(self._args(str(source_dir), dry_run=False))

        assert "Migration cancelled" in capsys.readouterr().out

    def test_execute_with_yes_skips_confirmation(self, tmp_path, capsys):
        source_dir = tmp_path / ".openclaw"
        source_dir.mkdir()
        config_path = tmp_path / "config.yaml"
        config_path.write_text("")

        module = self._fake_module(
            {
                "summary": {"migrated": 0, "skipped": 0, "conflict": 0, "error": 0},
                "items": [],
            }
        )

        with (
            patch.object(claw_mod, "_find_migration_script", return_value=tmp_path / "s.py"),
            patch.object(claw_mod, "_load_migration_module", return_value=module),
            patch.object(claw_mod, "get_config_path", return_value=config_path),
            patch.object(claw_mod, "prompt_yes_no") as prompt_stub,
        ):
            claw_mod._cmd_migrate(self._args(str(source_dir), dry_run=False, yes=True))

        # --yes must bypass the interactive confirmation entirely.
        prompt_stub.assert_not_called()

    def test_handles_migration_error(self, tmp_path, capsys):
        source_dir = tmp_path / ".openclaw"
        source_dir.mkdir()
        config_path = tmp_path / "config.yaml"
        config_path.write_text("")

        with (
            patch.object(claw_mod, "_find_migration_script", return_value=tmp_path / "s.py"),
            patch.object(claw_mod, "_load_migration_module", side_effect=RuntimeError("boom")),
            patch.object(claw_mod, "get_config_path", return_value=config_path),
            patch.object(claw_mod, "save_config"),
            patch.object(claw_mod, "load_config", return_value={}),
        ):
            claw_mod._cmd_migrate(self._args(str(source_dir)))

        assert "Migration failed" in capsys.readouterr().out

    def test_full_preset_enables_secrets(self, tmp_path, capsys):
        """The 'full' preset should set migrate_secrets=True automatically."""
        source_dir = tmp_path / ".openclaw"
        source_dir.mkdir()

        module = self._fake_module(
            {
                "summary": {"migrated": 0, "skipped": 0, "conflict": 0, "error": 0},
                "items": [],
            }
        )

        with (
            patch.object(claw_mod, "_find_migration_script", return_value=tmp_path / "s.py"),
            patch.object(claw_mod, "_load_migration_module", return_value=module),
            patch.object(claw_mod, "get_config_path", return_value=tmp_path / "config.yaml"),
            patch.object(claw_mod, "save_config"),
            patch.object(claw_mod, "load_config", return_value={}),
        ):
            claw_mod._cmd_migrate(self._args(str(source_dir)))

        # migrate_secrets=False on the CLI is overridden by the full preset.
        assert module.Migrator.call_args[1]["migrate_secrets"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _print_migration_report
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPrintMigrationReport:
    """Formatting of migration reports for the terminal."""

    def test_dry_run_report(self, capsys):
        report = {
            "summary": {"migrated": 2, "skipped": 1, "conflict": 1, "error": 0},
            "items": [
                {"kind": "soul", "status": "migrated", "destination": "/home/user/.hermes/SOUL.md"},
                {"kind": "memory", "status": "migrated", "destination": "/home/user/.hermes/memories/MEMORY.md"},
                {"kind": "skills", "status": "conflict", "reason": "already exists"},
                {"kind": "tts-assets", "status": "skipped", "reason": "not found"},
            ],
            "preset": "full",
        }
        claw_mod._print_migration_report(report, dry_run=True)
        printed = capsys.readouterr().out
        # Dry-run output uses conditional wording and mentions --dry-run.
        assert "Dry Run Results" in printed
        assert "Would migrate" in printed
        assert "2 would migrate" in printed
        assert "--dry-run" in printed

    def test_execute_report(self, capsys):
        report = {
            "summary": {"migrated": 3, "skipped": 0, "conflict": 0, "error": 0},
            "items": [
                {"kind": "soul", "status": "migrated", "destination": "/home/user/.hermes/SOUL.md"},
            ],
            "output_dir": "/home/user/.hermes/migration/openclaw/20250312T120000",
        }
        claw_mod._print_migration_report(report, dry_run=False)
        printed = capsys.readouterr().out
        assert "Migration Results" in printed
        assert "Migrated" in printed
        assert "Full report saved to" in printed

    def test_empty_report(self, capsys):
        report = {
            "summary": {"migrated": 0, "skipped": 0, "conflict": 0, "error": 0},
            "items": [],
        }
        claw_mod._print_migration_report(report, dry_run=False)
        assert "Nothing to migrate" in capsys.readouterr().out
|
||||
|
|
@ -160,7 +160,8 @@ class TestValidateFormatChecks:
|
|||
|
||||
def test_no_slash_model_rejected_if_not_in_api(self):
|
||||
result = _validate("gpt-5.4", api_models=["openai/gpt-5.4"])
|
||||
assert result["accepted"] is False
|
||||
assert result["accepted"] is True
|
||||
assert "not found" in result["message"]
|
||||
|
||||
|
||||
# -- validate — API found ----------------------------------------------------
|
||||
|
|
@ -184,37 +185,39 @@ class TestValidateApiFound:
|
|||
# -- validate — API not found ------------------------------------------------
|
||||
|
||||
class TestValidateApiNotFound:
|
||||
def test_model_not_in_api_rejected(self):
|
||||
def test_model_not_in_api_accepted_with_warning(self):
|
||||
result = _validate("anthropic/claude-nonexistent")
|
||||
assert result["accepted"] is False
|
||||
assert "not a valid model" in result["message"]
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is True
|
||||
assert "not found" in result["message"]
|
||||
|
||||
def test_rejection_includes_suggestions(self):
|
||||
def test_warning_includes_suggestions(self):
|
||||
result = _validate("anthropic/claude-opus-4.5")
|
||||
assert result["accepted"] is False
|
||||
assert "Did you mean" in result["message"]
|
||||
assert result["accepted"] is True
|
||||
assert "Similar models" in result["message"]
|
||||
|
||||
|
||||
# -- validate — API unreachable (fallback) -----------------------------------
|
||||
# -- validate — API unreachable — accept and persist everything ----------------
|
||||
|
||||
class TestValidateApiFallback:
|
||||
def test_known_catalog_model_accepted_when_api_down(self):
|
||||
def test_any_model_accepted_when_api_down(self):
|
||||
result = _validate("anthropic/claude-opus-4.6", api_models=None)
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is True
|
||||
|
||||
def test_unknown_model_session_only_when_api_down(self):
|
||||
def test_unknown_model_also_accepted_when_api_down(self):
|
||||
"""No hardcoded catalog gatekeeping — accept, persist, and warn."""
|
||||
result = _validate("anthropic/claude-next-gen", api_models=None)
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is False
|
||||
assert "session only" in result["message"].lower()
|
||||
assert result["persist"] is True
|
||||
assert "could not reach" in result["message"].lower()
|
||||
|
||||
def test_zai_known_model_accepted_when_api_down(self):
|
||||
def test_zai_model_accepted_when_api_down(self):
|
||||
result = _validate("glm-5", provider="zai", api_models=None)
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is True
|
||||
|
||||
def test_unknown_provider_session_only_when_api_down(self):
|
||||
def test_unknown_provider_accepted_when_api_down(self):
|
||||
result = _validate("some-model", provider="totally-unknown", api_models=None)
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is False
|
||||
assert result["persist"] is True
|
||||
|
|
|
|||
284
tests/hermes_cli/test_setup_openclaw_migration.py
Normal file
284
tests/hermes_cli/test_setup_openclaw_migration.py
Normal file
|
|
@ -0,0 +1,284 @@
|
|||
"""Tests for OpenClaw migration integration in the setup wizard."""
|
||||
|
||||
from argparse import Namespace
|
||||
from types import ModuleType
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from hermes_cli import setup as setup_mod
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _offer_openclaw_migration — unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOfferOpenclawMigration:
|
||||
"""Test the _offer_openclaw_migration helper in isolation."""
|
||||
|
||||
def test_skips_when_no_openclaw_dir(self, tmp_path):
|
||||
"""Should return False immediately when ~/.openclaw does not exist."""
|
||||
with patch("hermes_cli.setup.Path.home", return_value=tmp_path):
|
||||
assert setup_mod._offer_openclaw_migration(tmp_path / ".hermes") is False
|
||||
|
||||
def test_skips_when_migration_script_missing(self, tmp_path):
|
||||
"""Should return False when the migration script file is absent."""
|
||||
openclaw_dir = tmp_path / ".openclaw"
|
||||
openclaw_dir.mkdir()
|
||||
with (
|
||||
patch("hermes_cli.setup.Path.home", return_value=tmp_path),
|
||||
patch.object(setup_mod, "_OPENCLAW_SCRIPT", tmp_path / "nonexistent.py"),
|
||||
):
|
||||
assert setup_mod._offer_openclaw_migration(tmp_path / ".hermes") is False
|
||||
|
||||
def test_skips_when_user_declines(self, tmp_path):
|
||||
"""Should return False when user declines the migration prompt."""
|
||||
openclaw_dir = tmp_path / ".openclaw"
|
||||
openclaw_dir.mkdir()
|
||||
script = tmp_path / "openclaw_to_hermes.py"
|
||||
script.write_text("# placeholder")
|
||||
with (
|
||||
patch("hermes_cli.setup.Path.home", return_value=tmp_path),
|
||||
patch.object(setup_mod, "_OPENCLAW_SCRIPT", script),
|
||||
patch.object(setup_mod, "prompt_yes_no", return_value=False),
|
||||
):
|
||||
assert setup_mod._offer_openclaw_migration(tmp_path / ".hermes") is False
|
||||
|
||||
def test_runs_migration_when_user_accepts(self, tmp_path):
|
||||
"""Should dynamically load the script and run the Migrator."""
|
||||
openclaw_dir = tmp_path / ".openclaw"
|
||||
openclaw_dir.mkdir()
|
||||
|
||||
# Create a fake hermes home with config
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
config_path.write_text("agent:\n max_turns: 90\n")
|
||||
|
||||
# Build a fake migration module
|
||||
fake_mod = ModuleType("openclaw_to_hermes")
|
||||
fake_mod.resolve_selected_options = MagicMock(return_value={"soul", "memory"})
|
||||
fake_migrator = MagicMock()
|
||||
fake_migrator.migrate.return_value = {
|
||||
"summary": {"migrated": 3, "skipped": 1, "conflict": 0, "error": 0},
|
||||
"output_dir": str(hermes_home / "migration"),
|
||||
}
|
||||
fake_mod.Migrator = MagicMock(return_value=fake_migrator)
|
||||
|
||||
script = tmp_path / "openclaw_to_hermes.py"
|
||||
script.write_text("# placeholder")
|
||||
|
||||
with (
|
||||
patch("hermes_cli.setup.Path.home", return_value=tmp_path),
|
||||
patch.object(setup_mod, "_OPENCLAW_SCRIPT", script),
|
||||
patch.object(setup_mod, "prompt_yes_no", return_value=True),
|
||||
patch.object(setup_mod, "get_config_path", return_value=config_path),
|
||||
patch("importlib.util.spec_from_file_location") as mock_spec_fn,
|
||||
):
|
||||
# Wire up the fake module loading
|
||||
mock_spec = MagicMock()
|
||||
mock_spec.loader = MagicMock()
|
||||
mock_spec_fn.return_value = mock_spec
|
||||
|
||||
def exec_module(mod):
|
||||
mod.resolve_selected_options = fake_mod.resolve_selected_options
|
||||
mod.Migrator = fake_mod.Migrator
|
||||
|
||||
mock_spec.loader.exec_module = exec_module
|
||||
|
||||
result = setup_mod._offer_openclaw_migration(hermes_home)
|
||||
|
||||
assert result is True
|
||||
fake_mod.resolve_selected_options.assert_called_once_with(
|
||||
None, None, preset="full"
|
||||
)
|
||||
fake_mod.Migrator.assert_called_once()
|
||||
call_kwargs = fake_mod.Migrator.call_args[1]
|
||||
assert call_kwargs["execute"] is True
|
||||
assert call_kwargs["overwrite"] is False
|
||||
assert call_kwargs["migrate_secrets"] is True
|
||||
assert call_kwargs["preset_name"] == "full"
|
||||
fake_migrator.migrate.assert_called_once()
|
||||
|
||||
def test_handles_migration_error_gracefully(self, tmp_path):
|
||||
"""Should catch exceptions and return False."""
|
||||
openclaw_dir = tmp_path / ".openclaw"
|
||||
openclaw_dir.mkdir()
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
config_path.write_text("")
|
||||
|
||||
script = tmp_path / "openclaw_to_hermes.py"
|
||||
script.write_text("# placeholder")
|
||||
|
||||
with (
|
||||
patch("hermes_cli.setup.Path.home", return_value=tmp_path),
|
||||
patch.object(setup_mod, "_OPENCLAW_SCRIPT", script),
|
||||
patch.object(setup_mod, "prompt_yes_no", return_value=True),
|
||||
patch.object(setup_mod, "get_config_path", return_value=config_path),
|
||||
patch(
|
||||
"importlib.util.spec_from_file_location",
|
||||
side_effect=RuntimeError("boom"),
|
||||
),
|
||||
):
|
||||
result = setup_mod._offer_openclaw_migration(hermes_home)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_creates_config_if_missing(self, tmp_path):
|
||||
"""Should bootstrap config.yaml before running migration."""
|
||||
openclaw_dir = tmp_path / ".openclaw"
|
||||
openclaw_dir.mkdir()
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
# config does NOT exist yet
|
||||
|
||||
script = tmp_path / "openclaw_to_hermes.py"
|
||||
script.write_text("# placeholder")
|
||||
|
||||
with (
|
||||
patch("hermes_cli.setup.Path.home", return_value=tmp_path),
|
||||
patch.object(setup_mod, "_OPENCLAW_SCRIPT", script),
|
||||
patch.object(setup_mod, "prompt_yes_no", return_value=True),
|
||||
patch.object(setup_mod, "get_config_path", return_value=config_path),
|
||||
patch.object(setup_mod, "load_config", return_value={"agent": {}}),
|
||||
patch.object(setup_mod, "save_config") as mock_save,
|
||||
patch(
|
||||
"importlib.util.spec_from_file_location",
|
||||
side_effect=RuntimeError("stop early"),
|
||||
),
|
||||
):
|
||||
setup_mod._offer_openclaw_migration(hermes_home)
|
||||
|
||||
# save_config should have been called to bootstrap the file
|
||||
mock_save.assert_called_once_with({"agent": {}})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration with run_setup_wizard — first-time flow
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _first_time_args() -> Namespace:
|
||||
return Namespace(
|
||||
section=None,
|
||||
non_interactive=False,
|
||||
reset=False,
|
||||
)
|
||||
|
||||
|
||||
class TestSetupWizardOpenclawIntegration:
|
||||
"""Verify _offer_openclaw_migration is called during first-time setup."""
|
||||
|
||||
def test_migration_offered_during_first_time_setup(self, tmp_path):
|
||||
"""On first-time setup, _offer_openclaw_migration should be called."""
|
||||
args = _first_time_args()
|
||||
|
||||
with (
|
||||
patch.object(setup_mod, "ensure_hermes_home"),
|
||||
patch.object(setup_mod, "load_config", return_value={}),
|
||||
patch.object(setup_mod, "get_hermes_home", return_value=tmp_path),
|
||||
patch.object(setup_mod, "get_env_value", return_value=""),
|
||||
patch("hermes_cli.auth.get_active_provider", return_value=None),
|
||||
# User presses Enter to start
|
||||
patch("builtins.input", return_value=""),
|
||||
# Mock the migration offer
|
||||
patch.object(
|
||||
setup_mod, "_offer_openclaw_migration", return_value=False
|
||||
) as mock_migration,
|
||||
# Mock the actual setup sections so they don't run
|
||||
patch.object(setup_mod, "setup_model_provider"),
|
||||
patch.object(setup_mod, "setup_terminal_backend"),
|
||||
patch.object(setup_mod, "setup_agent_settings"),
|
||||
patch.object(setup_mod, "setup_gateway"),
|
||||
patch.object(setup_mod, "setup_tools"),
|
||||
patch.object(setup_mod, "save_config"),
|
||||
patch.object(setup_mod, "_print_setup_summary"),
|
||||
):
|
||||
setup_mod.run_setup_wizard(args)
|
||||
|
||||
mock_migration.assert_called_once_with(tmp_path)
|
||||
|
||||
def test_migration_reloads_config_on_success(self, tmp_path):
|
||||
"""When migration returns True, config should be reloaded."""
|
||||
args = _first_time_args()
|
||||
call_order = []
|
||||
|
||||
def tracking_load_config():
|
||||
call_order.append("load_config")
|
||||
return {}
|
||||
|
||||
with (
|
||||
patch.object(setup_mod, "ensure_hermes_home"),
|
||||
patch.object(setup_mod, "load_config", side_effect=tracking_load_config),
|
||||
patch.object(setup_mod, "get_hermes_home", return_value=tmp_path),
|
||||
patch.object(setup_mod, "get_env_value", return_value=""),
|
||||
patch("hermes_cli.auth.get_active_provider", return_value=None),
|
||||
patch("builtins.input", return_value=""),
|
||||
patch.object(setup_mod, "_offer_openclaw_migration", return_value=True),
|
||||
patch.object(setup_mod, "setup_model_provider"),
|
||||
patch.object(setup_mod, "setup_terminal_backend"),
|
||||
patch.object(setup_mod, "setup_agent_settings"),
|
||||
patch.object(setup_mod, "setup_gateway"),
|
||||
patch.object(setup_mod, "setup_tools"),
|
||||
patch.object(setup_mod, "save_config"),
|
||||
patch.object(setup_mod, "_print_setup_summary"),
|
||||
):
|
||||
setup_mod.run_setup_wizard(args)
|
||||
|
||||
# load_config called twice: once at start, once after migration
|
||||
assert call_order.count("load_config") == 2
|
||||
|
||||
def test_reloaded_config_flows_into_remaining_setup_sections(self, tmp_path):
|
||||
args = _first_time_args()
|
||||
initial_config = {}
|
||||
reloaded_config = {"model": {"provider": "openrouter"}}
|
||||
|
||||
with (
|
||||
patch.object(setup_mod, "ensure_hermes_home"),
|
||||
patch.object(
|
||||
setup_mod,
|
||||
"load_config",
|
||||
side_effect=[initial_config, reloaded_config],
|
||||
),
|
||||
patch.object(setup_mod, "get_hermes_home", return_value=tmp_path),
|
||||
patch.object(setup_mod, "get_env_value", return_value=""),
|
||||
patch("hermes_cli.auth.get_active_provider", return_value=None),
|
||||
patch("builtins.input", return_value=""),
|
||||
patch.object(setup_mod, "_offer_openclaw_migration", return_value=True),
|
||||
patch.object(setup_mod, "setup_model_provider") as setup_model_provider,
|
||||
patch.object(setup_mod, "setup_terminal_backend"),
|
||||
patch.object(setup_mod, "setup_agent_settings"),
|
||||
patch.object(setup_mod, "setup_gateway"),
|
||||
patch.object(setup_mod, "setup_tools"),
|
||||
patch.object(setup_mod, "save_config"),
|
||||
patch.object(setup_mod, "_print_setup_summary"),
|
||||
):
|
||||
setup_mod.run_setup_wizard(args)
|
||||
|
||||
setup_model_provider.assert_called_once_with(reloaded_config)
|
||||
|
||||
def test_migration_not_offered_for_existing_install(self, tmp_path):
|
||||
"""Returning users should not see the migration prompt."""
|
||||
args = _first_time_args()
|
||||
|
||||
with (
|
||||
patch.object(setup_mod, "ensure_hermes_home"),
|
||||
patch.object(setup_mod, "load_config", return_value={}),
|
||||
patch.object(setup_mod, "get_hermes_home", return_value=tmp_path),
|
||||
patch.object(
|
||||
setup_mod,
|
||||
"get_env_value",
|
||||
side_effect=lambda k: "sk-xxx" if k == "OPENROUTER_API_KEY" else "",
|
||||
),
|
||||
patch("hermes_cli.auth.get_active_provider", return_value=None),
|
||||
# Returning user picks "Exit"
|
||||
patch.object(setup_mod, "prompt_choice", return_value=9),
|
||||
patch.object(
|
||||
setup_mod, "_offer_openclaw_migration", return_value=False
|
||||
) as mock_migration,
|
||||
):
|
||||
setup_mod.run_setup_wizard(args)
|
||||
|
||||
mock_migration.assert_not_called()
|
||||
141
tests/run_interrupt_test.py
Normal file
141
tests/run_interrupt_test.py
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Run a real interrupt test with actual AIAgent + delegate child.
|
||||
|
||||
Not a pytest test — runs directly as a script for live testing.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from run_agent import AIAgent, IterationBudget
|
||||
from tools.delegate_tool import _run_single_child
|
||||
from tools.interrupt import set_interrupt, is_interrupted
|
||||
|
||||
set_interrupt(False)
|
||||
|
||||
# Create parent agent (minimal)
|
||||
parent = AIAgent.__new__(AIAgent)
|
||||
parent._interrupt_requested = False
|
||||
parent._interrupt_message = None
|
||||
parent._active_children = []
|
||||
parent.quiet_mode = True
|
||||
parent.model = "test/model"
|
||||
parent.base_url = "http://localhost:1"
|
||||
parent.api_key = "test"
|
||||
parent.provider = "test"
|
||||
parent.api_mode = "chat_completions"
|
||||
parent.platform = "cli"
|
||||
parent.enabled_toolsets = ["terminal", "file"]
|
||||
parent.providers_allowed = None
|
||||
parent.providers_ignored = None
|
||||
parent.providers_order = None
|
||||
parent.provider_sort = None
|
||||
parent.max_tokens = None
|
||||
parent.reasoning_config = None
|
||||
parent.prefill_messages = None
|
||||
parent._session_db = None
|
||||
parent._delegate_depth = 0
|
||||
parent._delegate_spinner = None
|
||||
parent.tool_progress_callback = None
|
||||
parent.iteration_budget = IterationBudget(max_total=100)
|
||||
parent._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1"}
|
||||
|
||||
child_started = threading.Event()
|
||||
result_holder = [None]
|
||||
|
||||
|
||||
def run_delegate():
|
||||
with patch("run_agent.OpenAI") as MockOpenAI:
|
||||
mock_client = MagicMock()
|
||||
|
||||
def slow_create(**kwargs):
|
||||
time.sleep(3)
|
||||
resp = MagicMock()
|
||||
resp.choices = [MagicMock()]
|
||||
resp.choices[0].message.content = "Done"
|
||||
resp.choices[0].message.tool_calls = None
|
||||
resp.choices[0].message.refusal = None
|
||||
resp.choices[0].finish_reason = "stop"
|
||||
resp.usage.prompt_tokens = 100
|
||||
resp.usage.completion_tokens = 10
|
||||
resp.usage.total_tokens = 110
|
||||
resp.usage.prompt_tokens_details = None
|
||||
return resp
|
||||
|
||||
mock_client.chat.completions.create = slow_create
|
||||
mock_client.close = MagicMock()
|
||||
MockOpenAI.return_value = mock_client
|
||||
|
||||
original_init = AIAgent.__init__
|
||||
|
||||
def patched_init(self_agent, *a, **kw):
|
||||
original_init(self_agent, *a, **kw)
|
||||
child_started.set()
|
||||
|
||||
with patch.object(AIAgent, "__init__", patched_init):
|
||||
try:
|
||||
result = _run_single_child(
|
||||
task_index=0,
|
||||
goal="Test slow task",
|
||||
context=None,
|
||||
toolsets=["terminal"],
|
||||
model="test/model",
|
||||
max_iterations=5,
|
||||
parent_agent=parent,
|
||||
task_count=1,
|
||||
override_provider="test",
|
||||
override_base_url="http://localhost:1",
|
||||
override_api_key="test",
|
||||
override_api_mode="chat_completions",
|
||||
)
|
||||
result_holder[0] = result
|
||||
except Exception as e:
|
||||
print(f"ERROR in delegate: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
print("Starting agent thread...")
|
||||
agent_thread = threading.Thread(target=run_delegate, daemon=True)
|
||||
agent_thread.start()
|
||||
|
||||
started = child_started.wait(timeout=10)
|
||||
if not started:
|
||||
print("ERROR: Child never started")
|
||||
sys.exit(1)
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
print(f"Active children: {len(parent._active_children)}")
|
||||
for i, c in enumerate(parent._active_children):
|
||||
print(f" Child {i}: _interrupt_requested={c._interrupt_requested}")
|
||||
|
||||
t0 = time.monotonic()
|
||||
parent.interrupt("User typed a new message")
|
||||
print(f"Called parent.interrupt()")
|
||||
|
||||
for i, c in enumerate(parent._active_children):
|
||||
print(f" Child {i} after interrupt: _interrupt_requested={c._interrupt_requested}")
|
||||
print(f"Global is_interrupted: {is_interrupted()}")
|
||||
|
||||
agent_thread.join(timeout=10)
|
||||
elapsed = time.monotonic() - t0
|
||||
print(f"Agent thread finished in {elapsed:.2f}s")
|
||||
|
||||
result = result_holder[0]
|
||||
if result:
|
||||
print(f"Status: {result['status']}")
|
||||
print(f"Duration: {result['duration_seconds']}s")
|
||||
if elapsed < 2.0:
|
||||
print("✅ PASS: Interrupt detected quickly!")
|
||||
else:
|
||||
print(f"❌ FAIL: Took {elapsed:.2f}s — interrupt was too slow or not detected")
|
||||
else:
|
||||
print("❌ FAIL: No result!")
|
||||
|
||||
set_interrupt(False)
|
||||
171
tests/test_cli_interrupt_subagent.py
Normal file
171
tests/test_cli_interrupt_subagent.py
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
"""End-to-end test simulating CLI interrupt during subagent execution.
|
||||
|
||||
Reproduces the exact scenario:
|
||||
1. Parent agent calls delegate_task
|
||||
2. Child agent is running (simulated with a slow tool)
|
||||
3. User "types a message" (simulated by calling parent.interrupt from another thread)
|
||||
4. Child should detect the interrupt and stop
|
||||
|
||||
This tests the COMPLETE path including _run_single_child, _active_children
|
||||
registration, interrupt propagation, and child detection.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
|
||||
from tools.interrupt import set_interrupt, is_interrupted
|
||||
|
||||
|
||||
class TestCLISubagentInterrupt(unittest.TestCase):
|
||||
"""Simulate exact CLI scenario."""
|
||||
|
||||
def setUp(self):
|
||||
set_interrupt(False)
|
||||
|
||||
def tearDown(self):
|
||||
set_interrupt(False)
|
||||
|
||||
def test_full_delegate_interrupt_flow(self):
|
||||
"""Full integration: parent runs delegate_task, main thread interrupts."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
interrupt_detected = threading.Event()
|
||||
child_started = threading.Event()
|
||||
child_api_call_count = 0
|
||||
|
||||
# Create a real-enough parent agent
|
||||
parent = AIAgent.__new__(AIAgent)
|
||||
parent._interrupt_requested = False
|
||||
parent._interrupt_message = None
|
||||
parent._active_children = []
|
||||
parent.quiet_mode = True
|
||||
parent.model = "test/model"
|
||||
parent.base_url = "http://localhost:1"
|
||||
parent.api_key = "test"
|
||||
parent.provider = "test"
|
||||
parent.api_mode = "chat_completions"
|
||||
parent.platform = "cli"
|
||||
parent.enabled_toolsets = ["terminal", "file"]
|
||||
parent.providers_allowed = None
|
||||
parent.providers_ignored = None
|
||||
parent.providers_order = None
|
||||
parent.provider_sort = None
|
||||
parent.max_tokens = None
|
||||
parent.reasoning_config = None
|
||||
parent.prefill_messages = None
|
||||
parent._session_db = None
|
||||
parent._delegate_depth = 0
|
||||
parent._delegate_spinner = None
|
||||
parent.tool_progress_callback = None
|
||||
|
||||
# We'll track what happens with _active_children
|
||||
original_children = parent._active_children
|
||||
|
||||
# Mock the child's run_conversation to simulate a slow operation
|
||||
# that checks _interrupt_requested like the real one does
|
||||
def mock_child_run_conversation(user_message, **kwargs):
|
||||
child_started.set()
|
||||
# Find the child in parent._active_children
|
||||
child = parent._active_children[-1] if parent._active_children else None
|
||||
|
||||
# Simulate the agent loop: poll _interrupt_requested like run_conversation does
|
||||
for i in range(100): # Up to 10 seconds (100 * 0.1s)
|
||||
if child and child._interrupt_requested:
|
||||
interrupt_detected.set()
|
||||
return {
|
||||
"final_response": "Interrupted!",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
"completed": False,
|
||||
"interrupted": True,
|
||||
"interrupt_message": child._interrupt_message,
|
||||
}
|
||||
time.sleep(0.1)
|
||||
|
||||
return {
|
||||
"final_response": "Finished without interrupt",
|
||||
"messages": [],
|
||||
"api_calls": 5,
|
||||
"completed": True,
|
||||
"interrupted": False,
|
||||
}
|
||||
|
||||
# Patch AIAgent to use our mock
|
||||
from tools.delegate_tool import _run_single_child
|
||||
from run_agent import IterationBudget
|
||||
|
||||
parent.iteration_budget = IterationBudget(max_total=100)
|
||||
|
||||
# Run delegate in a thread (simulates agent_thread)
|
||||
delegate_result = [None]
|
||||
delegate_error = [None]
|
||||
|
||||
def run_delegate():
|
||||
try:
|
||||
with patch('run_agent.AIAgent') as MockAgent:
|
||||
mock_instance = MagicMock()
|
||||
mock_instance._interrupt_requested = False
|
||||
mock_instance._interrupt_message = None
|
||||
mock_instance._active_children = []
|
||||
mock_instance.quiet_mode = True
|
||||
mock_instance.run_conversation = mock_child_run_conversation
|
||||
mock_instance.interrupt = lambda msg=None: setattr(mock_instance, '_interrupt_requested', True) or setattr(mock_instance, '_interrupt_message', msg)
|
||||
mock_instance.tools = []
|
||||
MockAgent.return_value = mock_instance
|
||||
|
||||
result = _run_single_child(
|
||||
task_index=0,
|
||||
goal="Do something slow",
|
||||
context=None,
|
||||
toolsets=["terminal"],
|
||||
model=None,
|
||||
max_iterations=50,
|
||||
parent_agent=parent,
|
||||
task_count=1,
|
||||
)
|
||||
delegate_result[0] = result
|
||||
except Exception as e:
|
||||
delegate_error[0] = e
|
||||
|
||||
agent_thread = threading.Thread(target=run_delegate, daemon=True)
|
||||
agent_thread.start()
|
||||
|
||||
# Wait for child to start
|
||||
assert child_started.wait(timeout=5), "Child never started!"
|
||||
|
||||
# Now simulate user interrupt (from main/process thread)
|
||||
time.sleep(0.2) # Give child a moment to be in its loop
|
||||
|
||||
print(f"Parent has {len(parent._active_children)} active children")
|
||||
assert len(parent._active_children) >= 1, f"Expected child in _active_children, got {len(parent._active_children)}"
|
||||
|
||||
# This is what the CLI does:
|
||||
parent.interrupt("Hey stop that")
|
||||
|
||||
print(f"Parent._interrupt_requested: {parent._interrupt_requested}")
|
||||
for i, child in enumerate(parent._active_children):
|
||||
print(f"Child {i}._interrupt_requested: {child._interrupt_requested}")
|
||||
|
||||
# Wait for child to detect interrupt
|
||||
detected = interrupt_detected.wait(timeout=3.0)
|
||||
|
||||
# Wait for delegate to finish
|
||||
agent_thread.join(timeout=5)
|
||||
|
||||
if delegate_error[0]:
|
||||
raise delegate_error[0]
|
||||
|
||||
assert detected, "Child never detected the interrupt!"
|
||||
result = delegate_result[0]
|
||||
assert result is not None, "Delegate returned no result"
|
||||
assert result["status"] == "interrupted", f"Expected 'interrupted', got '{result['status']}'"
|
||||
print(f"✓ Interrupt detected! Result: {result}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -31,7 +31,7 @@ class TestModelCommand:
|
|||
assert cli_obj.model == "anthropic/claude-sonnet-4.5"
|
||||
save_mock.assert_called_once_with("model.default", "anthropic/claude-sonnet-4.5")
|
||||
|
||||
def test_invalid_model_from_api_is_rejected(self, capsys):
|
||||
def test_unlisted_model_accepted_with_warning(self, capsys):
|
||||
cli_obj = self._make_cli()
|
||||
|
||||
with patch("hermes_cli.models.fetch_api_models",
|
||||
|
|
@ -40,12 +40,10 @@ class TestModelCommand:
|
|||
cli_obj.process_command("/model anthropic/fake-model")
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "not a valid model" in output
|
||||
assert "Model unchanged" in output
|
||||
assert cli_obj.model == "anthropic/claude-opus-4.6"
|
||||
save_mock.assert_not_called()
|
||||
assert "not found" in output or "Model changed" in output
|
||||
assert cli_obj.model == "anthropic/fake-model" # accepted
|
||||
|
||||
def test_api_unreachable_falls_back_session_only(self, capsys):
|
||||
def test_api_unreachable_accepts_and_persists(self, capsys):
|
||||
cli_obj = self._make_cli()
|
||||
|
||||
with patch("hermes_cli.models.fetch_api_models", return_value=None), \
|
||||
|
|
@ -53,12 +51,11 @@ class TestModelCommand:
|
|||
cli_obj.process_command("/model anthropic/claude-sonnet-next")
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "session only" in output
|
||||
assert "will revert on restart" in output
|
||||
assert "saved to config" in output
|
||||
assert cli_obj.model == "anthropic/claude-sonnet-next"
|
||||
save_mock.assert_not_called()
|
||||
save_mock.assert_called_once()
|
||||
|
||||
def test_no_slash_model_probes_api_and_rejects(self, capsys):
|
||||
def test_no_slash_model_accepted_with_warning(self, capsys):
|
||||
cli_obj = self._make_cli()
|
||||
|
||||
with patch("hermes_cli.models.fetch_api_models",
|
||||
|
|
@ -67,11 +64,8 @@ class TestModelCommand:
|
|||
cli_obj.process_command("/model gpt-5.4")
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "not a valid model" in output
|
||||
assert "Model unchanged" in output
|
||||
assert cli_obj.model == "anthropic/claude-opus-4.6" # unchanged
|
||||
assert cli_obj.agent is not None # not reset
|
||||
save_mock.assert_not_called()
|
||||
# Model is accepted (with warning) even if not in API listing
|
||||
assert cli_obj.model == "gpt-5.4"
|
||||
|
||||
def test_validation_crash_falls_back_to_save(self, capsys):
|
||||
cli_obj = self._make_cli()
|
||||
|
|
|
|||
189
tests/test_interactive_interrupt.py
Normal file
189
tests/test_interactive_interrupt.py
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Interactive interrupt test that mimics the exact CLI flow.
|
||||
|
||||
Starts an agent in a thread with a mock delegate_task that takes a while,
|
||||
then simulates the user typing a message via _interrupt_queue.
|
||||
|
||||
Logs every step to stderr (which isn't affected by redirect_stdout)
|
||||
so we can see exactly where the interrupt gets lost.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import queue
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import os
|
||||
|
||||
# Force stderr logging so redirect_stdout doesn't swallow it
|
||||
logging.basicConfig(level=logging.DEBUG, stream=sys.stderr,
|
||||
format="%(asctime)s [%(threadName)s] %(message)s")
|
||||
log = logging.getLogger("interrupt_test")
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from run_agent import AIAgent, IterationBudget
|
||||
from tools.interrupt import set_interrupt, is_interrupted
|
||||
|
||||
set_interrupt(False)
|
||||
|
||||
# ─── Create parent agent ───
|
||||
parent = AIAgent.__new__(AIAgent)
|
||||
parent._interrupt_requested = False
|
||||
parent._interrupt_message = None
|
||||
parent._active_children = []
|
||||
parent.quiet_mode = True
|
||||
parent.model = "test/model"
|
||||
parent.base_url = "http://localhost:1"
|
||||
parent.api_key = "test"
|
||||
parent.provider = "test"
|
||||
parent.api_mode = "chat_completions"
|
||||
parent.platform = "cli"
|
||||
parent.enabled_toolsets = ["terminal", "file"]
|
||||
parent.providers_allowed = None
|
||||
parent.providers_ignored = None
|
||||
parent.providers_order = None
|
||||
parent.provider_sort = None
|
||||
parent.max_tokens = None
|
||||
parent.reasoning_config = None
|
||||
parent.prefill_messages = None
|
||||
parent._session_db = None
|
||||
parent._delegate_depth = 0
|
||||
parent._delegate_spinner = None
|
||||
parent.tool_progress_callback = None
|
||||
parent.iteration_budget = IterationBudget(max_total=100)
|
||||
parent._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1"}
|
||||
|
||||
# Monkey-patch parent.interrupt to log
|
||||
_original_interrupt = AIAgent.interrupt
|
||||
def logged_interrupt(self, message=None):
|
||||
log.info(f"🔴 parent.interrupt() called with: {message!r}")
|
||||
log.info(f" _active_children count: {len(self._active_children)}")
|
||||
_original_interrupt(self, message)
|
||||
log.info(f" After interrupt: _interrupt_requested={self._interrupt_requested}")
|
||||
for i, c in enumerate(self._active_children):
|
||||
log.info(f" Child {i}._interrupt_requested={c._interrupt_requested}")
|
||||
parent.interrupt = lambda msg=None: logged_interrupt(parent, msg)
|
||||
|
||||
# ─── Simulate the exact CLI flow ───
|
||||
interrupt_queue = queue.Queue()
|
||||
child_running = threading.Event()
|
||||
agent_result = [None]
|
||||
|
||||
def make_slow_response(delay=2.0):
|
||||
"""API response that takes a while."""
|
||||
def create(**kwargs):
|
||||
log.info(f" 🌐 Mock API call starting (will take {delay}s)...")
|
||||
time.sleep(delay)
|
||||
log.info(f" 🌐 Mock API call completed")
|
||||
resp = MagicMock()
|
||||
resp.choices = [MagicMock()]
|
||||
resp.choices[0].message.content = "Done with the task"
|
||||
resp.choices[0].message.tool_calls = None
|
||||
resp.choices[0].message.refusal = None
|
||||
resp.choices[0].finish_reason = "stop"
|
||||
resp.usage.prompt_tokens = 100
|
||||
resp.usage.completion_tokens = 10
|
||||
resp.usage.total_tokens = 110
|
||||
resp.usage.prompt_tokens_details = None
|
||||
return resp
|
||||
return create
|
||||
|
||||
|
||||
def agent_thread_func():
    """Simulates the agent_thread in cli.py's chat() method.

    Runs _run_single_child() with a mocked (slow) OpenAI client so the child
    spends ~3s inside an API call, giving the main thread a window to
    interrupt it. The result dict is published via agent_result[0].
    """
    log.info("🟢 agent_thread starting")

    with patch("run_agent.OpenAI") as MockOpenAI:
        mock_client = MagicMock()
        mock_client.chat.completions.create = make_slow_response(delay=3.0)
        mock_client.close = MagicMock()
        MockOpenAI.return_value = mock_client

        from tools.delegate_tool import _run_single_child

        # Signal that child is about to start
        original_init = AIAgent.__init__

        def patched_init(self_agent, *a, **kw):
            # Wraps the real __init__ only to flip child_running once the
            # subagent object exists.
            log.info("🟡 Child AIAgent.__init__ called")
            original_init(self_agent, *a, **kw)
            child_running.set()
            log.info(f"🟡 Child started, parent._active_children = {len(parent._active_children)}")

        with patch.object(AIAgent, "__init__", patched_init):
            result = _run_single_child(
                task_index=0,
                goal="Do a slow thing",
                context=None,
                toolsets=["terminal"],
                model="test/model",
                max_iterations=3,
                parent_agent=parent,
                task_count=1,
                override_provider="test",
                override_base_url="http://localhost:1",
                override_api_key="test",
                override_api_mode="chat_completions",
            )
            agent_result[0] = result
            log.info(f"🟢 agent_thread finished. Result status: {result.get('status')}")
|
||||
|
||||
|
||||
# ─── Start agent thread (like chat() does) ───
agent_thread = threading.Thread(target=agent_thread_func, name="agent_thread", daemon=True)
agent_thread.start()

# ─── Wait for child to start ───
if not child_running.wait(timeout=10):
    print("FAIL: Child never started", file=sys.stderr)
    sys.exit(1)

# Give child time to enter its main loop and start API call
time.sleep(1.0)

# ─── Simulate user typing a message (like handle_enter does) ───
log.info("📝 Simulating user typing 'Hey stop that'")
interrupt_queue.put("Hey stop that")

# ─── Simulate chat() polling loop (like the real chat() method) ───
# Poll the queue with a short timeout while the agent thread is alive; on a
# message, call parent.interrupt() exactly like the real CLI would.
log.info("📡 Starting interrupt queue polling (like chat())")
interrupt_msg = None
poll_count = 0
while agent_thread.is_alive():
    try:
        interrupt_msg = interrupt_queue.get(timeout=0.1)
        if interrupt_msg:
            log.info(f"📨 Got interrupt message from queue: {interrupt_msg!r}")
            log.info(f"  Calling parent.interrupt()...")
            parent.interrupt(interrupt_msg)
            log.info(f"  parent.interrupt() returned. Breaking poll loop.")
            break
    except queue.Empty:
        poll_count += 1
        if poll_count % 20 == 0:  # Log every 2s
            log.info(f"  Still polling ({poll_count} iterations)...")

# ─── Wait for agent to finish ───
log.info("⏳ Waiting for agent_thread to join...")
t0 = time.monotonic()
agent_thread.join(timeout=10)
elapsed = time.monotonic() - t0
log.info(f"✅ agent_thread joined after {elapsed:.2f}s")

# ─── Check results ───
# Success criterion: the delegate reports "interrupted" AND the join was fast
# (< 2s), proving the 3s mock API call was aborted early.
result = agent_result[0]
if result:
    log.info(f"Result status: {result['status']}")
    log.info(f"Result duration: {result['duration_seconds']}s")
    if result["status"] == "interrupted" and elapsed < 2.0:
        print("✅ PASS: Interrupt worked correctly!", file=sys.stderr)
    else:
        print(f"❌ FAIL: status={result['status']}, elapsed={elapsed:.2f}s", file=sys.stderr)
else:
    print("❌ FAIL: No result returned", file=sys.stderr)

# Leave the process-wide interrupt flag clean for anything that runs after us.
set_interrupt(False)
|
||||
155
tests/test_interrupt_propagation.py
Normal file
155
tests/test_interrupt_propagation.py
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
"""Test interrupt propagation from parent to child agents.
|
||||
|
||||
Reproduces the CLI scenario: user sends a message while delegate_task is
|
||||
running, main thread calls parent.interrupt(), child should stop.
|
||||
"""
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
|
||||
from tools.interrupt import set_interrupt, is_interrupted, _interrupt_event
|
||||
|
||||
|
||||
class TestInterruptPropagationToChild(unittest.TestCase):
    """Verify interrupt propagates from parent to child agent.

    Agents are built via AIAgent.__new__ (skipping __init__ and any network
    setup) with only the attributes the interrupt machinery reads.
    """

    def setUp(self):
        # Start each test with the process-wide interrupt event cleared.
        set_interrupt(False)

    def tearDown(self):
        # Never leak a set interrupt flag into other tests.
        set_interrupt(False)

    def test_parent_interrupt_sets_child_flag(self):
        """When parent.interrupt() is called, child._interrupt_requested should be set."""
        from run_agent import AIAgent

        parent = AIAgent.__new__(AIAgent)
        parent._interrupt_requested = False
        parent._interrupt_message = None
        parent._active_children = []
        parent.quiet_mode = True

        child = AIAgent.__new__(AIAgent)
        child._interrupt_requested = False
        child._interrupt_message = None
        child._active_children = []
        child.quiet_mode = True

        # Register the child exactly as the delegate machinery would.
        parent._active_children.append(child)

        parent.interrupt("new user message")

        # interrupt() must set the parent flag, fan out to children (flag and
        # message), and raise the global interrupt event.
        assert parent._interrupt_requested is True
        assert child._interrupt_requested is True
        assert child._interrupt_message == "new user message"
        assert is_interrupted() is True

    def test_child_clear_interrupt_at_start_clears_global(self):
        """child.clear_interrupt() at start of run_conversation clears the GLOBAL event.

        This is the intended behavior at startup, but verify it doesn't
        accidentally clear an interrupt intended for a running child.
        """
        from run_agent import AIAgent

        child = AIAgent.__new__(AIAgent)
        child._interrupt_requested = True
        child._interrupt_message = "msg"
        child.quiet_mode = True
        child._active_children = []

        # Global is set
        set_interrupt(True)
        assert is_interrupted() is True

        # child.clear_interrupt() clears both
        child.clear_interrupt()
        assert child._interrupt_requested is False
        assert is_interrupted() is False

    def test_interrupt_during_child_api_call_detected(self):
        """Interrupt set during _interruptible_api_call is detected within 0.5s."""
        from run_agent import AIAgent

        child = AIAgent.__new__(AIAgent)
        child._interrupt_requested = False
        child._interrupt_message = None
        child._active_children = []
        child.quiet_mode = True
        child.api_mode = "chat_completions"
        child.log_prefix = ""
        child._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1234"}

        # Mock a slow API call
        mock_client = MagicMock()

        def slow_api_call(**kwargs):
            time.sleep(5)  # Would take 5s normally
            return MagicMock()

        mock_client.chat.completions.create = slow_api_call
        mock_client.close = MagicMock()
        child.client = mock_client

        # Set interrupt after 0.2s from another thread
        def set_interrupt_later():
            time.sleep(0.2)
            child.interrupt("stop!")

        t = threading.Thread(target=set_interrupt_later, daemon=True)
        t.start()

        start = time.monotonic()
        try:
            # NOTE(review): assumes _interruptible_api_call polls the flag
            # roughly every 0.3s and raises InterruptedError — the elapsed
            # bound below encodes that expectation.
            child._interruptible_api_call({"model": "test", "messages": []})
            self.fail("Should have raised InterruptedError")
        except InterruptedError:
            elapsed = time.monotonic() - start
            # Should detect within ~0.5s (0.2s delay + 0.3s poll interval)
            assert elapsed < 1.0, f"Took {elapsed:.2f}s to detect interrupt (expected < 1.0s)"
        finally:
            t.join(timeout=2)
            set_interrupt(False)

    def test_concurrent_interrupt_propagation(self):
        """Simulates exact CLI flow: parent runs delegate in thread, main thread interrupts."""
        from run_agent import AIAgent

        parent = AIAgent.__new__(AIAgent)
        parent._interrupt_requested = False
        parent._interrupt_message = None
        parent._active_children = []
        parent.quiet_mode = True

        child = AIAgent.__new__(AIAgent)
        child._interrupt_requested = False
        child._interrupt_message = None
        child._active_children = []
        child.quiet_mode = True

        # Register child (simulating what _run_single_child does)
        parent._active_children.append(child)

        # Simulate child running (checking flag in a loop)
        child_detected = threading.Event()

        def simulate_child_loop():
            # Spin (50ms granularity) until the interrupt flag appears,
            # mimicking a child's main-loop check.
            while not child._interrupt_requested:
                time.sleep(0.05)
            child_detected.set()

        child_thread = threading.Thread(target=simulate_child_loop, daemon=True)
        child_thread.start()

        # Small delay, then interrupt from "main thread"
        time.sleep(0.1)
        parent.interrupt("user typed something new")

        # Child should detect within 200ms
        detected = child_detected.wait(timeout=1.0)
        assert detected, "Child never detected the interrupt!"
        child_thread.join(timeout=1)
        set_interrupt(False)
|
||||
|
||||
|
||||
# Allow running this test module directly, outside of pytest.
if __name__ == "__main__":
    unittest.main()
|
||||
176
tests/test_real_interrupt_subagent.py
Normal file
176
tests/test_real_interrupt_subagent.py
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
"""Test real interrupt propagation through delegate_task with actual AIAgent.
|
||||
|
||||
This uses a real AIAgent with mocked HTTP responses to test the complete
|
||||
interrupt flow through _run_single_child → child.run_conversation().
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
|
||||
from tools.interrupt import set_interrupt, is_interrupted
|
||||
|
||||
|
||||
def _make_slow_api_response(delay=5.0):
|
||||
"""Create a mock that simulates a slow API response (like a real LLM call)."""
|
||||
def slow_create(**kwargs):
|
||||
# Simulate a slow API call
|
||||
time.sleep(delay)
|
||||
# Return a simple text response (no tool calls)
|
||||
resp = MagicMock()
|
||||
resp.choices = [MagicMock()]
|
||||
resp.choices[0].message = MagicMock()
|
||||
resp.choices[0].message.content = "Done"
|
||||
resp.choices[0].message.tool_calls = None
|
||||
resp.choices[0].message.refusal = None
|
||||
resp.choices[0].finish_reason = "stop"
|
||||
resp.usage = MagicMock()
|
||||
resp.usage.prompt_tokens = 100
|
||||
resp.usage.completion_tokens = 10
|
||||
resp.usage.total_tokens = 110
|
||||
resp.usage.prompt_tokens_details = None
|
||||
return resp
|
||||
return slow_create
|
||||
|
||||
|
||||
class TestRealSubagentInterrupt(unittest.TestCase):
    """Test interrupt with real AIAgent child through delegate_tool.

    Unlike the pure-flag tests, this drives the actual
    _run_single_child() -> child.run_conversation() path with a mocked
    OpenAI client, so interrupt detection inside a live API call is covered.
    """

    def setUp(self):
        set_interrupt(False)
        # AIAgent construction presumably requires an API key in the env —
        # provide a dummy one without clobbering a real value.
        os.environ.setdefault("OPENAI_API_KEY", "test-key")

    def tearDown(self):
        set_interrupt(False)

    def test_interrupt_child_during_api_call(self):
        """Real AIAgent child interrupted while making API call."""
        from run_agent import AIAgent, IterationBudget

        # Create a real parent agent (just enough to be a parent)
        parent = AIAgent.__new__(AIAgent)
        parent._interrupt_requested = False
        parent._interrupt_message = None
        parent._active_children = []
        parent.quiet_mode = True
        parent.model = "test/model"
        parent.base_url = "http://localhost:1"
        parent.api_key = "test"
        parent.provider = "test"
        parent.api_mode = "chat_completions"
        parent.platform = "cli"
        parent.enabled_toolsets = ["terminal", "file"]
        parent.providers_allowed = None
        parent.providers_ignored = None
        parent.providers_order = None
        parent.provider_sort = None
        parent.max_tokens = None
        parent.reasoning_config = None
        parent.prefill_messages = None
        parent._session_db = None
        parent._delegate_depth = 0
        parent._delegate_spinner = None
        parent.tool_progress_callback = None
        parent.iteration_budget = IterationBudget(max_total=100)
        parent._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1"}

        from tools.delegate_tool import _run_single_child

        child_started = threading.Event()
        result_holder = [None]  # delegate result, handed back from the thread
        error_holder = [None]   # any exception raised inside the thread

        def run_delegate():
            try:
                # Patch the OpenAI client creation inside AIAgent.__init__
                with patch('run_agent.OpenAI') as MockOpenAI:
                    mock_client = MagicMock()
                    # API call takes 5 seconds — should be interrupted before that
                    mock_client.chat.completions.create = _make_slow_api_response(delay=5.0)
                    mock_client.close = MagicMock()
                    MockOpenAI.return_value = mock_client

                    # Also need to patch the system prompt builder
                    with patch('run_agent.build_system_prompt', return_value="You are a test agent"):
                        # Signal when child starts
                        original_run = AIAgent.run_conversation

                        def patched_run(self_agent, *args, **kwargs):
                            child_started.set()
                            return original_run(self_agent, *args, **kwargs)

                        with patch.object(AIAgent, 'run_conversation', patched_run):
                            result = _run_single_child(
                                task_index=0,
                                goal="Test task",
                                context=None,
                                toolsets=["terminal"],
                                model="test/model",
                                max_iterations=5,
                                parent_agent=parent,
                                task_count=1,
                                override_provider="test",
                                override_base_url="http://localhost:1",
                                override_api_key="test",
                                override_api_mode="chat_completions",
                            )
                            result_holder[0] = result
            except Exception as e:
                import traceback
                traceback.print_exc()
                error_holder[0] = e

        agent_thread = threading.Thread(target=run_delegate, daemon=True)
        agent_thread.start()

        # Wait for child to start run_conversation
        started = child_started.wait(timeout=10)
        if not started:
            agent_thread.join(timeout=1)
            # Surface the thread's own failure if there was one; otherwise
            # fail with a diagnosis of our own.
            if error_holder[0]:
                raise error_holder[0]
            self.fail("Child never started run_conversation")

        # Give child time to enter main loop and start API call
        time.sleep(0.5)

        # Verify child is registered
        print(f"Active children: {len(parent._active_children)}")
        self.assertGreaterEqual(len(parent._active_children), 1,
                                "Child not registered in _active_children")

        # Interrupt! (simulating what CLI does)
        start = time.monotonic()
        parent.interrupt("User typed a new message")

        # Check propagation
        child = parent._active_children[0] if parent._active_children else None
        if child:
            print(f"Child._interrupt_requested after parent.interrupt(): {child._interrupt_requested}")
            self.assertTrue(child._interrupt_requested,
                            "Interrupt did not propagate to child!")

        # Wait for delegate to finish (should be fast since interrupted)
        agent_thread.join(timeout=5)
        elapsed = time.monotonic() - start

        if error_holder[0]:
            raise error_holder[0]

        result = result_holder[0]
        self.assertIsNotNone(result, "Delegate returned no result")
        print(f"Result status: {result['status']}, elapsed: {elapsed:.2f}s")
        print(f"Full result: {result}")

        # The child should have been interrupted, not completed the full 5s API call
        self.assertLess(elapsed, 3.0,
                        f"Took {elapsed:.2f}s — interrupt was not detected quickly enough")
        self.assertEqual(result["status"], "interrupted",
                         f"Expected 'interrupted', got '{result['status']}'")
||||
|
||||
|
||||
# Allow running this test module directly, outside of pytest.
if __name__ == "__main__":
    unittest.main()
|
||||
54
tests/test_redirect_stdout_issue.py
Normal file
54
tests/test_redirect_stdout_issue.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
"""Verify that redirect_stdout in _run_single_child is process-wide.
|
||||
|
||||
This demonstrates that contextlib.redirect_stdout changes sys.stdout
|
||||
for ALL threads, not just the current one. This means during subagent
|
||||
execution, all output from other threads (including the CLI's process_thread)
|
||||
is swallowed.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import io
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
|
||||
|
||||
class TestRedirectStdoutIsProcessWide(unittest.TestCase):
    """Demonstrate that contextlib.redirect_stdout swaps sys.stdout globally."""

    def test_redirect_stdout_affects_other_threads(self):
        """contextlib.redirect_stdout changes sys.stdout for ALL threads."""
        seen_streams = []
        original_stream = sys.stdout
        redirect_observed = threading.Event()

        def probe_stdout():
            """Runs in a different thread, tries to use sys.stdout."""
            # Give the main thread a moment to enter redirect_stdout.
            time.sleep(0.2)
            # If sys.stdout changed under us, the redirect leaked across threads.
            if sys.stdout is not original_stream:
                redirect_observed.set()
            seen_streams.append(sys.stdout)

        worker = threading.Thread(target=probe_stdout, daemon=True)
        worker.start()

        # Redirect stdout from the main thread while the worker is probing.
        sink = io.StringIO()
        with contextlib.redirect_stdout(sink):
            time.sleep(0.5)  # Hold the redirect long enough for the worker to look

        worker.join(timeout=2)

        # The worker must have seen the replacement stream, not the real one.
        self.assertTrue(
            redirect_observed.is_set(),
            "redirect_stdout was NOT process-wide — other thread still saw real stdout. "
            "This test's premise is wrong."
        )
        print("Confirmed: redirect_stdout IS process-wide — affects all threads")
|
||||
|
||||
|
||||
# Allow running this test module directly, outside of pytest.
if __name__ == "__main__":
    unittest.main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue