merge: resolve conflict with main in subagent interrupt test
commit fefc709b2c
75 changed files with 8124 additions and 1376 deletions
@@ -9,8 +9,7 @@ from agent.context_compressor import ContextCompressor
@pytest.fixture()
def compressor():
    """Create a ContextCompressor with mocked dependencies."""
    with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
         patch("agent.context_compressor.get_text_auxiliary_client", return_value=(None, None)):
    with patch("agent.context_compressor.get_model_context_length", return_value=100000):
        c = ContextCompressor(
            model="test/model",
            threshold_percent=0.85,

@@ -119,14 +118,11 @@ class TestGenerateSummaryNoneContent:
    """Regression: content=None (from tool-call-only assistant messages) must not crash."""

    def test_none_content_does_not_crash(self):
        mock_client = MagicMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "[CONTEXT SUMMARY]: tool calls happened"
        mock_client.chat.completions.create.return_value = mock_response

        with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
             patch("agent.context_compressor.get_text_auxiliary_client", return_value=(mock_client, "test-model")):
        with patch("agent.context_compressor.get_model_context_length", return_value=100000):
            c = ContextCompressor(model="test", quiet_mode=True)

            messages = [

@@ -139,14 +135,14 @@ class TestGenerateSummaryNoneContent:
                {"role": "user", "content": "thanks"},
            ]

            summary = c._generate_summary(messages)
            with patch("agent.context_compressor.call_llm", return_value=mock_response):
                summary = c._generate_summary(messages)
            assert isinstance(summary, str)
            assert "CONTEXT SUMMARY" in summary

    def test_none_content_in_system_message_compress(self):
        """System message with content=None should not crash during compress."""
        with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
             patch("agent.context_compressor.get_text_auxiliary_client", return_value=(None, None)):
        with patch("agent.context_compressor.get_model_context_length", return_value=100000):
            c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=2, protect_last_n=2)

            msgs = [{"role": "system", "content": None}] + [

@@ -165,12 +161,12 @@ class TestCompressWithClient:
        mock_response.choices[0].message.content = "[CONTEXT SUMMARY]: stuff happened"
        mock_client.chat.completions.create.return_value = mock_response

        with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
             patch("agent.context_compressor.get_text_auxiliary_client", return_value=(mock_client, "test-model")):
        with patch("agent.context_compressor.get_model_context_length", return_value=100000):
            c = ContextCompressor(model="test", quiet_mode=True)

            msgs = [{"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"} for i in range(10)]
            result = c.compress(msgs)
            with patch("agent.context_compressor.call_llm", return_value=mock_response):
                result = c.compress(msgs)

            # Should have summary message in the middle
            contents = [m.get("content", "") for m in result]

@@ -184,8 +180,7 @@ class TestCompressWithClient:
        mock_response.choices[0].message.content = "[CONTEXT SUMMARY]: compressed middle"
        mock_client.chat.completions.create.return_value = mock_response

        with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
             patch("agent.context_compressor.get_text_auxiliary_client", return_value=(mock_client, "test-model")):
        with patch("agent.context_compressor.get_model_context_length", return_value=100000):
            c = ContextCompressor(
                model="test",
                quiet_mode=True,

@@ -212,7 +207,8 @@ class TestCompressWithClient:
                {"role": "user", "content": "later 4"},
            ]

            result = c.compress(msgs)
            with patch("agent.context_compressor.call_llm", return_value=mock_response):
                result = c.compress(msgs)

            answered_ids = {
                msg.get("tool_call_id")

@@ -232,8 +228,7 @@ class TestCompressWithClient:
        mock_response.choices[0].message.content = "[CONTEXT SUMMARY]: stuff happened"
        mock_client.chat.completions.create.return_value = mock_response

        with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
             patch("agent.context_compressor.get_text_auxiliary_client", return_value=(mock_client, "test-model")):
        with patch("agent.context_compressor.get_model_context_length", return_value=100000):
            c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=2, protect_last_n=2)

            # Last head message (index 1) is "assistant" → summary should be "user"

@@ -245,7 +240,8 @@ class TestCompressWithClient:
                {"role": "user", "content": "msg 4"},
                {"role": "assistant", "content": "msg 5"},
            ]
            result = c.compress(msgs)
            with patch("agent.context_compressor.call_llm", return_value=mock_response):
                result = c.compress(msgs)
            summary_msg = [m for m in result if "CONTEXT SUMMARY" in (m.get("content") or "")]
            assert len(summary_msg) == 1
            assert summary_msg[0]["role"] == "user"

@@ -258,8 +254,7 @@ class TestCompressWithClient:
        mock_response.choices[0].message.content = "[CONTEXT SUMMARY]: stuff happened"
        mock_client.chat.completions.create.return_value = mock_response

        with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
             patch("agent.context_compressor.get_text_auxiliary_client", return_value=(mock_client, "test-model")):
        with patch("agent.context_compressor.get_model_context_length", return_value=100000):
            c = ContextCompressor(model="test", quiet_mode=True, protect_first_n=3, protect_last_n=2)

            # Last head message (index 2) is "user" → summary should be "assistant"

@@ -273,20 +268,18 @@ class TestCompressWithClient:
                {"role": "user", "content": "msg 6"},
                {"role": "assistant", "content": "msg 7"},
            ]
            result = c.compress(msgs)
            with patch("agent.context_compressor.call_llm", return_value=mock_response):
                result = c.compress(msgs)
            summary_msg = [m for m in result if "CONTEXT SUMMARY" in (m.get("content") or "")]
            assert len(summary_msg) == 1
            assert summary_msg[0]["role"] == "assistant"

    def test_summarization_does_not_start_tail_with_tool_outputs(self):
        mock_client = MagicMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "[CONTEXT SUMMARY]: compressed middle"
        mock_client.chat.completions.create.return_value = mock_response

        with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
             patch("agent.context_compressor.get_text_auxiliary_client", return_value=(mock_client, "test-model")):
        with patch("agent.context_compressor.get_model_context_length", return_value=100000):
            c = ContextCompressor(
                model="test",
                quiet_mode=True,

@@ -309,7 +302,8 @@ class TestCompressWithClient:
                {"role": "user", "content": "latest user"},
            ]

            result = c.compress(msgs)
            with patch("agent.context_compressor.call_llm", return_value=mock_response):
                result = c.compress(msgs)

            called_ids = {
                tc["id"]
@@ -1,6 +1,7 @@
"""Shared fixtures for the hermes-agent test suite."""

import os
import signal
import sys
import tempfile
from pathlib import Path

@@ -48,3 +49,21 @@ def mock_config():
        "memory": {"memory_enabled": False, "user_profile_enabled": False},
        "command_allowlist": [],
    }


# ── Global test timeout ─────────────────────────────────────────────────────
# Kill any individual test that takes longer than 30 seconds.
# Prevents hanging tests (subprocess spawns, blocking I/O) from stalling the
# entire test suite.

def _timeout_handler(signum, frame):
    raise TimeoutError("Test exceeded 30 second timeout")

@pytest.fixture(autouse=True)
def _enforce_test_timeout():
    """Kill any individual test that takes longer than 30 seconds."""
    old = signal.signal(signal.SIGALRM, _timeout_handler)
    signal.alarm(30)
    yield
    signal.alarm(0)
    signal.signal(signal.SIGALRM, old)
tests/gateway/test_discord_free_response.py (new file, 249 lines)
@@ -0,0 +1,249 @@
"""Tests for Discord free-response defaults and mention gating."""

from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import sys

import pytest

from gateway.config import PlatformConfig


def _ensure_discord_mock():
    """Install a mock discord module when discord.py isn't available."""
    if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
        return

    discord_mod = MagicMock()
    discord_mod.Intents.default.return_value = MagicMock()
    discord_mod.Client = MagicMock
    discord_mod.File = MagicMock
    discord_mod.DMChannel = type("DMChannel", (), {})
    discord_mod.Thread = type("Thread", (), {})
    discord_mod.ForumChannel = type("ForumChannel", (), {})
    discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
    discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3)
    discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4)
    discord_mod.Interaction = object
    discord_mod.Embed = MagicMock

    ext_mod = MagicMock()
    commands_mod = MagicMock()
    commands_mod.Bot = MagicMock
    ext_mod.commands = commands_mod

    sys.modules.setdefault("discord", discord_mod)
    sys.modules.setdefault("discord.ext", ext_mod)
    sys.modules.setdefault("discord.ext.commands", commands_mod)


_ensure_discord_mock()

import gateway.platforms.discord as discord_platform  # noqa: E402
from gateway.platforms.discord import DiscordAdapter  # noqa: E402


class FakeDMChannel:
    def __init__(self, channel_id: int = 1, name: str = "dm"):
        self.id = channel_id
        self.name = name


class FakeTextChannel:
    def __init__(self, channel_id: int = 1, name: str = "general", guild_name: str = "Hermes Server"):
        self.id = channel_id
        self.name = name
        self.guild = SimpleNamespace(name=guild_name)
        self.topic = None


class FakeForumChannel:
    def __init__(self, channel_id: int = 1, name: str = "support-forum", guild_name: str = "Hermes Server"):
        self.id = channel_id
        self.name = name
        self.guild = SimpleNamespace(name=guild_name)
        self.type = 15
        self.topic = None


class FakeThread:
    def __init__(self, channel_id: int = 1, name: str = "thread", parent=None, guild_name: str = "Hermes Server"):
        self.id = channel_id
        self.name = name
        self.parent = parent
        self.parent_id = getattr(parent, "id", None)
        self.guild = getattr(parent, "guild", None) or SimpleNamespace(name=guild_name)
        self.topic = None


@pytest.fixture
def adapter(monkeypatch):
    monkeypatch.setattr(discord_platform.discord, "DMChannel", FakeDMChannel, raising=False)
    monkeypatch.setattr(discord_platform.discord, "Thread", FakeThread, raising=False)
    monkeypatch.setattr(discord_platform.discord, "ForumChannel", FakeForumChannel, raising=False)

    config = PlatformConfig(enabled=True, token="fake-token")
    adapter = DiscordAdapter(config)
    adapter._client = SimpleNamespace(user=SimpleNamespace(id=999))
    adapter.handle_message = AsyncMock()
    return adapter


def make_message(*, channel, content: str, mentions=None):
    author = SimpleNamespace(id=42, display_name="Jezza", name="Jezza")
    return SimpleNamespace(
        id=123,
        content=content,
        mentions=list(mentions or []),
        attachments=[],
        reference=None,
        created_at=datetime.now(timezone.utc),
        channel=channel,
        author=author,
    )


@pytest.mark.asyncio
async def test_discord_defaults_to_require_mention(adapter, monkeypatch):
    """Default behavior: require @mention in server channels."""
    monkeypatch.delenv("DISCORD_REQUIRE_MENTION", raising=False)
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    message = make_message(channel=FakeTextChannel(channel_id=123), content="hello from channel")

    await adapter._handle_message(message)

    # Should be ignored — no mention, require_mention defaults to true
    adapter.handle_message.assert_not_awaited()


@pytest.mark.asyncio
async def test_discord_free_response_in_server_channels(adapter, monkeypatch):
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    message = make_message(channel=FakeTextChannel(channel_id=123), content="hello from channel")

    await adapter._handle_message(message)

    adapter.handle_message.assert_awaited_once()
    event = adapter.handle_message.await_args.args[0]
    assert event.text == "hello from channel"
    assert event.source.chat_id == "123"
    assert event.source.chat_type == "group"


@pytest.mark.asyncio
async def test_discord_free_response_in_threads(adapter, monkeypatch):
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    thread = FakeThread(channel_id=456, name="Ghost reader skill")
    message = make_message(channel=thread, content="hello from thread")

    await adapter._handle_message(message)

    adapter.handle_message.assert_awaited_once()
    event = adapter.handle_message.await_args.args[0]
    assert event.text == "hello from thread"
    assert event.source.chat_id == "456"
    assert event.source.thread_id == "456"
    assert event.source.chat_type == "thread"


@pytest.mark.asyncio
async def test_discord_forum_threads_are_handled_as_threads(adapter, monkeypatch):
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    forum = FakeForumChannel(channel_id=222, name="support-forum")
    thread = FakeThread(channel_id=456, name="Can Hermes reply here?", parent=forum)
    message = make_message(channel=thread, content="hello from forum post")

    await adapter._handle_message(message)

    adapter.handle_message.assert_awaited_once()
    event = adapter.handle_message.await_args.args[0]
    assert event.text == "hello from forum post"
    assert event.source.chat_id == "456"
    assert event.source.thread_id == "456"
    assert event.source.chat_type == "thread"
    assert event.source.chat_name == "Hermes Server / support-forum / Can Hermes reply here?"


@pytest.mark.asyncio
async def test_discord_can_still_require_mentions_when_enabled(adapter, monkeypatch):
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    message = make_message(channel=FakeTextChannel(channel_id=789), content="ignored without mention")

    await adapter._handle_message(message)

    adapter.handle_message.assert_not_awaited()


@pytest.mark.asyncio
async def test_discord_free_response_channel_overrides_mention_requirement(adapter, monkeypatch):
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
    monkeypatch.setenv("DISCORD_FREE_RESPONSE_CHANNELS", "789,999")

    message = make_message(channel=FakeTextChannel(channel_id=789), content="allowed without mention")

    await adapter._handle_message(message)

    adapter.handle_message.assert_awaited_once()
    event = adapter.handle_message.await_args.args[0]
    assert event.text == "allowed without mention"


@pytest.mark.asyncio
async def test_discord_forum_parent_in_free_response_list_allows_forum_thread(adapter, monkeypatch):
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
    monkeypatch.setenv("DISCORD_FREE_RESPONSE_CHANNELS", "222")

    forum = FakeForumChannel(channel_id=222, name="support-forum")
    thread = FakeThread(channel_id=333, name="Forum topic", parent=forum)
    message = make_message(channel=thread, content="allowed from forum thread")

    await adapter._handle_message(message)

    adapter.handle_message.assert_awaited_once()
    event = adapter.handle_message.await_args.args[0]
    assert event.text == "allowed from forum thread"
    assert event.source.chat_id == "333"


@pytest.mark.asyncio
async def test_discord_accepts_and_strips_bot_mentions_when_required(adapter, monkeypatch):
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    bot_user = adapter._client.user
    message = make_message(
        channel=FakeTextChannel(channel_id=321),
        content=f"<@{bot_user.id}> hello with mention",
        mentions=[bot_user],
    )

    await adapter._handle_message(message)

    adapter.handle_message.assert_awaited_once()
    event = adapter.handle_message.await_args.args[0]
    assert event.text == "hello with mention"


@pytest.mark.asyncio
async def test_discord_dms_ignore_mention_requirement(adapter, monkeypatch):
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
    monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)

    message = make_message(channel=FakeDMChannel(channel_id=654), content="dm without mention")

    await adapter._handle_message(message)

    adapter.handle_message.assert_awaited_once()
    event = adapter.handle_message.await_args.args[0]
    assert event.text == "dm without mention"
    assert event.source.chat_type == "dm"
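The tests above pin down the gating rules for server messages: mentions are required by default, DISCORD_REQUIRE_MENTION=false opens free response everywhere, DISCORD_FREE_RESPONSE_CHANNELS allow-lists channel (or forum parent) IDs, and DMs always go through. A minimal sketch of that decision, assuming a hypothetical helper name (_should_respond) and only the two environment variables exercised by the tests; the real DiscordAdapter logic may be structured differently:

import os

def _should_respond(channel_id: str, parent_id, is_dm: bool, mentions_bot: bool) -> bool:
    # Hypothetical reconstruction of the gating the tests above exercise.
    if is_dm:
        return True  # DMs ignore the mention requirement
    free_channels = {
        c.strip() for c in os.getenv("DISCORD_FREE_RESPONSE_CHANNELS", "").split(",") if c.strip()
    }
    if channel_id in free_channels or (parent_id and str(parent_id) in free_channels):
        return True  # explicit allow-list overrides the mention requirement
    require_mention = os.getenv("DISCORD_REQUIRE_MENTION", "true").lower() != "false"
    return mentions_bot or not require_mention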
tests/gateway/test_interrupt_key_match.py (new file, 124 lines)
@@ -0,0 +1,124 @@
"""Tests verifying interrupt key consistency between adapter and gateway.

Regression test for a bug where monitor_for_interrupt() in _run_agent used
source.chat_id to query the adapter, but the adapter stores interrupts under
the full session key (build_session_key output). This mismatch meant
interrupts were never detected, causing subagents to ignore new messages.
"""

import asyncio

import pytest

from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
from gateway.session import SessionSource, build_session_key


class StubAdapter(BasePlatformAdapter):
    """Minimal adapter for interrupt tests."""

    def __init__(self):
        super().__init__(PlatformConfig(enabled=True, token="test"), Platform.TELEGRAM)

    async def connect(self):
        return True

    async def disconnect(self):
        pass

    async def send(self, chat_id, content, reply_to=None, metadata=None):
        return SendResult(success=True, message_id="1")

    async def send_typing(self, chat_id, metadata=None):
        pass

    async def get_chat_info(self, chat_id):
        return {"id": chat_id}


def _source(chat_id="123456", chat_type="dm", thread_id=None):
    return SessionSource(
        platform=Platform.TELEGRAM,
        chat_id=chat_id,
        chat_type=chat_type,
        thread_id=thread_id,
    )


class TestInterruptKeyConsistency:
    """Ensure adapter interrupt methods are queried with session_key, not chat_id."""

    def test_session_key_differs_from_chat_id_for_dm(self):
        """Session key for a DM is NOT the same as chat_id."""
        source = _source("123456", "dm")
        session_key = build_session_key(source)
        assert session_key != source.chat_id
        assert session_key == "agent:main:telegram:dm"

    def test_session_key_differs_from_chat_id_for_group(self):
        """Session key for a group chat includes prefix, unlike raw chat_id."""
        source = _source("-1001234", "group")
        session_key = build_session_key(source)
        assert session_key != source.chat_id
        assert "agent:main:" in session_key
        assert source.chat_id in session_key

    @pytest.mark.asyncio
    async def test_has_pending_interrupt_requires_session_key(self):
        """has_pending_interrupt returns True only when queried with session_key."""
        adapter = StubAdapter()
        source = _source("123456", "dm")
        session_key = build_session_key(source)

        # Simulate adapter storing interrupt under session_key
        interrupt_event = asyncio.Event()
        adapter._active_sessions[session_key] = interrupt_event
        interrupt_event.set()

        # Using session_key → found
        assert adapter.has_pending_interrupt(session_key) is True

        # Using chat_id → NOT found (this was the bug)
        assert adapter.has_pending_interrupt(source.chat_id) is False

    @pytest.mark.asyncio
    async def test_get_pending_message_requires_session_key(self):
        """get_pending_message returns the event only with session_key."""
        adapter = StubAdapter()
        source = _source("123456", "dm")
        session_key = build_session_key(source)

        event = MessageEvent(text="hello", source=source, message_id="42")
        adapter._pending_messages[session_key] = event

        # Using chat_id → None (the bug)
        assert adapter.get_pending_message(source.chat_id) is None

        # Using session_key → found
        result = adapter.get_pending_message(session_key)
        assert result is event

    @pytest.mark.asyncio
    async def test_handle_message_stores_under_session_key(self):
        """handle_message stores pending messages under session_key, not chat_id."""
        adapter = StubAdapter()
        adapter.set_message_handler(lambda event: asyncio.sleep(0, result=None))

        source = _source("-1001234", "group")
        session_key = build_session_key(source)

        # Mark session as active
        adapter._active_sessions[session_key] = asyncio.Event()

        # Send a second message while session is active
        event = MessageEvent(text="interrupt!", source=source, message_id="2")
        await adapter.handle_message(event)

        # Stored under session_key
        assert session_key in adapter._pending_messages
        # NOT stored under chat_id
        assert source.chat_id not in adapter._pending_messages

        # Interrupt event was set
        assert adapter._active_sessions[session_key].is_set()
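The fix implied by the module docstring above is a small change on the gateway side: derive the adapter lookup key with build_session_key(source) instead of passing source.chat_id. A rough sketch, assuming the monitor_for_interrupt helper named in the docstring polls the adapter; the actual _run_agent code may look different:

from gateway.session import build_session_key

async def monitor_for_interrupt(adapter, source):
    # Query with the same key the adapter uses for storage; the old code
    # passed source.chat_id here, so pending interrupts were never found.
    session_key = build_session_key(source)
    if adapter.has_pending_interrupt(session_key):
        return adapter.get_pending_message(session_key)
    return None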
tests/hermes_cli/test_claw.py (new file, 340 lines)
@@ -0,0 +1,340 @@
"""Tests for hermes claw commands."""

from argparse import Namespace
from types import ModuleType
from unittest.mock import MagicMock, patch

import pytest

from hermes_cli import claw as claw_mod


# ---------------------------------------------------------------------------
# _find_migration_script
# ---------------------------------------------------------------------------


class TestFindMigrationScript:
    """Test script discovery in known locations."""

    def test_finds_project_root_script(self, tmp_path):
        script = tmp_path / "openclaw_to_hermes.py"
        script.write_text("# placeholder")
        with patch.object(claw_mod, "_OPENCLAW_SCRIPT", script):
            assert claw_mod._find_migration_script() == script

    def test_finds_installed_script(self, tmp_path):
        installed = tmp_path / "installed.py"
        installed.write_text("# placeholder")
        with (
            patch.object(claw_mod, "_OPENCLAW_SCRIPT", tmp_path / "nonexistent.py"),
            patch.object(claw_mod, "_OPENCLAW_SCRIPT_INSTALLED", installed),
        ):
            assert claw_mod._find_migration_script() == installed

    def test_returns_none_when_missing(self, tmp_path):
        with (
            patch.object(claw_mod, "_OPENCLAW_SCRIPT", tmp_path / "a.py"),
            patch.object(claw_mod, "_OPENCLAW_SCRIPT_INSTALLED", tmp_path / "b.py"),
        ):
            assert claw_mod._find_migration_script() is None


# ---------------------------------------------------------------------------
# claw_command routing
# ---------------------------------------------------------------------------


class TestClawCommand:
    """Test the claw_command router."""

    def test_routes_to_migrate(self):
        args = Namespace(claw_action="migrate", source=None, dry_run=True,
                         preset="full", overwrite=False, migrate_secrets=False,
                         workspace_target=None, skill_conflict="skip", yes=False)
        with patch.object(claw_mod, "_cmd_migrate") as mock:
            claw_mod.claw_command(args)
        mock.assert_called_once_with(args)

    def test_shows_help_for_no_action(self, capsys):
        args = Namespace(claw_action=None)
        claw_mod.claw_command(args)
        captured = capsys.readouterr()
        assert "migrate" in captured.out


# ---------------------------------------------------------------------------
# _cmd_migrate
# ---------------------------------------------------------------------------


class TestCmdMigrate:
    """Test the migrate command handler."""

    def test_error_when_source_missing(self, tmp_path, capsys):
        args = Namespace(
            source=str(tmp_path / "nonexistent"),
            dry_run=True, preset="full", overwrite=False,
            migrate_secrets=False, workspace_target=None,
            skill_conflict="skip", yes=False,
        )
        claw_mod._cmd_migrate(args)
        captured = capsys.readouterr()
        assert "not found" in captured.out

    def test_error_when_script_missing(self, tmp_path, capsys):
        openclaw_dir = tmp_path / ".openclaw"
        openclaw_dir.mkdir()
        args = Namespace(
            source=str(openclaw_dir),
            dry_run=True, preset="full", overwrite=False,
            migrate_secrets=False, workspace_target=None,
            skill_conflict="skip", yes=False,
        )
        with (
            patch.object(claw_mod, "_OPENCLAW_SCRIPT", tmp_path / "a.py"),
            patch.object(claw_mod, "_OPENCLAW_SCRIPT_INSTALLED", tmp_path / "b.py"),
        ):
            claw_mod._cmd_migrate(args)
        captured = capsys.readouterr()
        assert "Migration script not found" in captured.out

    def test_dry_run_succeeds(self, tmp_path, capsys):
        openclaw_dir = tmp_path / ".openclaw"
        openclaw_dir.mkdir()
        script = tmp_path / "script.py"
        script.write_text("# placeholder")

        # Build a fake migration module
        fake_mod = ModuleType("openclaw_to_hermes")
        fake_mod.resolve_selected_options = MagicMock(return_value={"soul", "memory"})
        fake_migrator = MagicMock()
        fake_migrator.migrate.return_value = {
            "summary": {"migrated": 0, "skipped": 5, "conflict": 0, "error": 0},
            "items": [
                {"kind": "soul", "status": "skipped", "reason": "Not found"},
            ],
            "preset": "full",
        }
        fake_mod.Migrator = MagicMock(return_value=fake_migrator)

        args = Namespace(
            source=str(openclaw_dir),
            dry_run=True, preset="full", overwrite=False,
            migrate_secrets=False, workspace_target=None,
            skill_conflict="skip", yes=False,
        )

        with (
            patch.object(claw_mod, "_find_migration_script", return_value=script),
            patch.object(claw_mod, "_load_migration_module", return_value=fake_mod),
            patch.object(claw_mod, "get_config_path", return_value=tmp_path / "config.yaml"),
            patch.object(claw_mod, "save_config"),
            patch.object(claw_mod, "load_config", return_value={}),
        ):
            claw_mod._cmd_migrate(args)

        captured = capsys.readouterr()
        assert "Dry Run Results" in captured.out
        assert "5 skipped" in captured.out

    def test_execute_with_confirmation(self, tmp_path, capsys):
        openclaw_dir = tmp_path / ".openclaw"
        openclaw_dir.mkdir()
        config_path = tmp_path / "config.yaml"
        config_path.write_text("agent:\n max_turns: 90\n")

        fake_mod = ModuleType("openclaw_to_hermes")
        fake_mod.resolve_selected_options = MagicMock(return_value={"soul"})
        fake_migrator = MagicMock()
        fake_migrator.migrate.return_value = {
            "summary": {"migrated": 2, "skipped": 1, "conflict": 0, "error": 0},
            "items": [
                {"kind": "soul", "status": "migrated", "destination": str(tmp_path / "SOUL.md")},
                {"kind": "memory", "status": "migrated", "destination": str(tmp_path / "memories/MEMORY.md")},
            ],
        }
        fake_mod.Migrator = MagicMock(return_value=fake_migrator)

        args = Namespace(
            source=str(openclaw_dir),
            dry_run=False, preset="user-data", overwrite=False,
            migrate_secrets=False, workspace_target=None,
            skill_conflict="skip", yes=False,
        )

        with (
            patch.object(claw_mod, "_find_migration_script", return_value=tmp_path / "s.py"),
            patch.object(claw_mod, "_load_migration_module", return_value=fake_mod),
            patch.object(claw_mod, "get_config_path", return_value=config_path),
            patch.object(claw_mod, "prompt_yes_no", return_value=True),
        ):
            claw_mod._cmd_migrate(args)

        captured = capsys.readouterr()
        assert "Migration Results" in captured.out
        assert "Migration complete!" in captured.out

    def test_execute_cancelled_by_user(self, tmp_path, capsys):
        openclaw_dir = tmp_path / ".openclaw"
        openclaw_dir.mkdir()
        config_path = tmp_path / "config.yaml"
        config_path.write_text("")

        args = Namespace(
            source=str(openclaw_dir),
            dry_run=False, preset="full", overwrite=False,
            migrate_secrets=False, workspace_target=None,
            skill_conflict="skip", yes=False,
        )

        with (
            patch.object(claw_mod, "_find_migration_script", return_value=tmp_path / "s.py"),
            patch.object(claw_mod, "prompt_yes_no", return_value=False),
        ):
            claw_mod._cmd_migrate(args)

        captured = capsys.readouterr()
        assert "Migration cancelled" in captured.out

    def test_execute_with_yes_skips_confirmation(self, tmp_path, capsys):
        openclaw_dir = tmp_path / ".openclaw"
        openclaw_dir.mkdir()
        config_path = tmp_path / "config.yaml"
        config_path.write_text("")

        fake_mod = ModuleType("openclaw_to_hermes")
        fake_mod.resolve_selected_options = MagicMock(return_value=set())
        fake_migrator = MagicMock()
        fake_migrator.migrate.return_value = {
            "summary": {"migrated": 0, "skipped": 0, "conflict": 0, "error": 0},
            "items": [],
        }
        fake_mod.Migrator = MagicMock(return_value=fake_migrator)

        args = Namespace(
            source=str(openclaw_dir),
            dry_run=False, preset="full", overwrite=False,
            migrate_secrets=False, workspace_target=None,
            skill_conflict="skip", yes=True,
        )

        with (
            patch.object(claw_mod, "_find_migration_script", return_value=tmp_path / "s.py"),
            patch.object(claw_mod, "_load_migration_module", return_value=fake_mod),
            patch.object(claw_mod, "get_config_path", return_value=config_path),
            patch.object(claw_mod, "prompt_yes_no") as mock_prompt,
        ):
            claw_mod._cmd_migrate(args)

        mock_prompt.assert_not_called()

    def test_handles_migration_error(self, tmp_path, capsys):
        openclaw_dir = tmp_path / ".openclaw"
        openclaw_dir.mkdir()
        config_path = tmp_path / "config.yaml"
        config_path.write_text("")

        args = Namespace(
            source=str(openclaw_dir),
            dry_run=True, preset="full", overwrite=False,
            migrate_secrets=False, workspace_target=None,
            skill_conflict="skip", yes=False,
        )

        with (
            patch.object(claw_mod, "_find_migration_script", return_value=tmp_path / "s.py"),
            patch.object(claw_mod, "_load_migration_module", side_effect=RuntimeError("boom")),
            patch.object(claw_mod, "get_config_path", return_value=config_path),
            patch.object(claw_mod, "save_config"),
            patch.object(claw_mod, "load_config", return_value={}),
        ):
            claw_mod._cmd_migrate(args)

        captured = capsys.readouterr()
        assert "Migration failed" in captured.out

    def test_full_preset_enables_secrets(self, tmp_path, capsys):
        """The 'full' preset should set migrate_secrets=True automatically."""
        openclaw_dir = tmp_path / ".openclaw"
        openclaw_dir.mkdir()

        fake_mod = ModuleType("openclaw_to_hermes")
        fake_mod.resolve_selected_options = MagicMock(return_value=set())
        fake_migrator = MagicMock()
        fake_migrator.migrate.return_value = {
            "summary": {"migrated": 0, "skipped": 0, "conflict": 0, "error": 0},
            "items": [],
        }
        fake_mod.Migrator = MagicMock(return_value=fake_migrator)

        args = Namespace(
            source=str(openclaw_dir),
            dry_run=True, preset="full", overwrite=False,
            migrate_secrets=False,  # Not explicitly set by user
            workspace_target=None,
            skill_conflict="skip", yes=False,
        )

        with (
            patch.object(claw_mod, "_find_migration_script", return_value=tmp_path / "s.py"),
            patch.object(claw_mod, "_load_migration_module", return_value=fake_mod),
            patch.object(claw_mod, "get_config_path", return_value=tmp_path / "config.yaml"),
            patch.object(claw_mod, "save_config"),
            patch.object(claw_mod, "load_config", return_value={}),
        ):
            claw_mod._cmd_migrate(args)

        # Migrator should have been called with migrate_secrets=True
        call_kwargs = fake_mod.Migrator.call_args[1]
        assert call_kwargs["migrate_secrets"] is True


# ---------------------------------------------------------------------------
# _print_migration_report
# ---------------------------------------------------------------------------


class TestPrintMigrationReport:
    """Test the report formatting function."""

    def test_dry_run_report(self, capsys):
        report = {
            "summary": {"migrated": 2, "skipped": 1, "conflict": 1, "error": 0},
            "items": [
                {"kind": "soul", "status": "migrated", "destination": "/home/user/.hermes/SOUL.md"},
                {"kind": "memory", "status": "migrated", "destination": "/home/user/.hermes/memories/MEMORY.md"},
                {"kind": "skills", "status": "conflict", "reason": "already exists"},
                {"kind": "tts-assets", "status": "skipped", "reason": "not found"},
            ],
            "preset": "full",
        }
        claw_mod._print_migration_report(report, dry_run=True)
        captured = capsys.readouterr()
        assert "Dry Run Results" in captured.out
        assert "Would migrate" in captured.out
        assert "2 would migrate" in captured.out
        assert "--dry-run" in captured.out

    def test_execute_report(self, capsys):
        report = {
            "summary": {"migrated": 3, "skipped": 0, "conflict": 0, "error": 0},
            "items": [
                {"kind": "soul", "status": "migrated", "destination": "/home/user/.hermes/SOUL.md"},
            ],
            "output_dir": "/home/user/.hermes/migration/openclaw/20250312T120000",
        }
        claw_mod._print_migration_report(report, dry_run=False)
        captured = capsys.readouterr()
        assert "Migration Results" in captured.out
        assert "Migrated" in captured.out
        assert "Full report saved to" in captured.out

    def test_empty_report(self, capsys):
        report = {
            "summary": {"migrated": 0, "skipped": 0, "conflict": 0, "error": 0},
            "items": [],
        }
        claw_mod._print_migration_report(report, dry_run=False)
        captured = capsys.readouterr()
        assert "Nothing to migrate" in captured.out
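test_full_preset_enables_secrets above depends on the handler upgrading migrate_secrets when the full preset is chosen, even if the user never passed the flag. A plausible sketch of that rule, with a hypothetical helper name; the real _cmd_migrate may express it inline:

def _effective_migrate_secrets(preset: str, migrate_secrets_flag: bool) -> bool:
    # Assumption: the "full" preset implies secrets migration even without an explicit flag.
    return True if preset == "full" else migrate_secrets_flag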
tests/hermes_cli/test_setup.py (new file, 97 lines)
@@ -0,0 +1,97 @@
import json

from hermes_cli.auth import _update_config_for_provider, get_active_provider
from hermes_cli.config import load_config, save_config
from hermes_cli.setup import setup_model_provider


def _clear_provider_env(monkeypatch):
    for key in (
        "NOUS_API_KEY",
        "OPENROUTER_API_KEY",
        "OPENAI_BASE_URL",
        "OPENAI_API_KEY",
        "LLM_MODEL",
    ):
        monkeypatch.delenv(key, raising=False)


def test_nous_oauth_setup_keeps_current_model_when_syncing_disk_provider(
    tmp_path, monkeypatch
):
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    _clear_provider_env(monkeypatch)

    config = load_config()

    prompt_choices = iter([0, 2])
    monkeypatch.setattr(
        "hermes_cli.setup.prompt_choice",
        lambda *args, **kwargs: next(prompt_choices),
    )
    monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")

    def _fake_login_nous(*args, **kwargs):
        auth_path = tmp_path / "auth.json"
        auth_path.write_text(json.dumps({"active_provider": "nous", "providers": {}}))
        _update_config_for_provider("nous", "https://inference.example.com/v1")

    monkeypatch.setattr("hermes_cli.auth._login_nous", _fake_login_nous)
    monkeypatch.setattr(
        "hermes_cli.auth.resolve_nous_runtime_credentials",
        lambda *args, **kwargs: {
            "base_url": "https://inference.example.com/v1",
            "api_key": "nous-key",
        },
    )
    monkeypatch.setattr(
        "hermes_cli.auth.fetch_nous_models",
        lambda *args, **kwargs: ["gemini-3-flash"],
    )

    setup_model_provider(config)
    save_config(config)

    reloaded = load_config()

    assert isinstance(reloaded["model"], dict)
    assert reloaded["model"]["provider"] == "nous"
    assert reloaded["model"]["base_url"] == "https://inference.example.com/v1"
    assert reloaded["model"]["default"] == "anthropic/claude-opus-4.6"


def test_custom_setup_clears_active_oauth_provider(tmp_path, monkeypatch):
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    _clear_provider_env(monkeypatch)

    auth_path = tmp_path / "auth.json"
    auth_path.write_text(json.dumps({"active_provider": "nous", "providers": {}}))

    config = load_config()

    monkeypatch.setattr("hermes_cli.setup.prompt_choice", lambda *args, **kwargs: 3)

    prompt_values = iter(
        [
            "https://custom.example/v1",
            "custom-api-key",
            "custom/model",
            "",
        ]
    )
    monkeypatch.setattr(
        "hermes_cli.setup.prompt",
        lambda *args, **kwargs: next(prompt_values),
    )

    setup_model_provider(config)
    save_config(config)

    reloaded = load_config()

    assert get_active_provider() is None
    assert isinstance(reloaded["model"], dict)
    assert reloaded["model"]["provider"] == "custom"
    assert reloaded["model"]["base_url"] == "https://custom.example/v1"
    assert reloaded["model"]["default"] == "custom/model"
tests/hermes_cli/test_setup_openclaw_migration.py (new file, 284 lines)
@@ -0,0 +1,284 @@
"""Tests for OpenClaw migration integration in the setup wizard."""

from argparse import Namespace
from types import ModuleType
from unittest.mock import MagicMock, patch

from hermes_cli import setup as setup_mod


# ---------------------------------------------------------------------------
# _offer_openclaw_migration — unit tests
# ---------------------------------------------------------------------------


class TestOfferOpenclawMigration:
    """Test the _offer_openclaw_migration helper in isolation."""

    def test_skips_when_no_openclaw_dir(self, tmp_path):
        """Should return False immediately when ~/.openclaw does not exist."""
        with patch("hermes_cli.setup.Path.home", return_value=tmp_path):
            assert setup_mod._offer_openclaw_migration(tmp_path / ".hermes") is False

    def test_skips_when_migration_script_missing(self, tmp_path):
        """Should return False when the migration script file is absent."""
        openclaw_dir = tmp_path / ".openclaw"
        openclaw_dir.mkdir()
        with (
            patch("hermes_cli.setup.Path.home", return_value=tmp_path),
            patch.object(setup_mod, "_OPENCLAW_SCRIPT", tmp_path / "nonexistent.py"),
        ):
            assert setup_mod._offer_openclaw_migration(tmp_path / ".hermes") is False

    def test_skips_when_user_declines(self, tmp_path):
        """Should return False when user declines the migration prompt."""
        openclaw_dir = tmp_path / ".openclaw"
        openclaw_dir.mkdir()
        script = tmp_path / "openclaw_to_hermes.py"
        script.write_text("# placeholder")
        with (
            patch("hermes_cli.setup.Path.home", return_value=tmp_path),
            patch.object(setup_mod, "_OPENCLAW_SCRIPT", script),
            patch.object(setup_mod, "prompt_yes_no", return_value=False),
        ):
            assert setup_mod._offer_openclaw_migration(tmp_path / ".hermes") is False

    def test_runs_migration_when_user_accepts(self, tmp_path):
        """Should dynamically load the script and run the Migrator."""
        openclaw_dir = tmp_path / ".openclaw"
        openclaw_dir.mkdir()

        # Create a fake hermes home with config
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        config_path = hermes_home / "config.yaml"
        config_path.write_text("agent:\n max_turns: 90\n")

        # Build a fake migration module
        fake_mod = ModuleType("openclaw_to_hermes")
        fake_mod.resolve_selected_options = MagicMock(return_value={"soul", "memory"})
        fake_migrator = MagicMock()
        fake_migrator.migrate.return_value = {
            "summary": {"migrated": 3, "skipped": 1, "conflict": 0, "error": 0},
            "output_dir": str(hermes_home / "migration"),
        }
        fake_mod.Migrator = MagicMock(return_value=fake_migrator)

        script = tmp_path / "openclaw_to_hermes.py"
        script.write_text("# placeholder")

        with (
            patch("hermes_cli.setup.Path.home", return_value=tmp_path),
            patch.object(setup_mod, "_OPENCLAW_SCRIPT", script),
            patch.object(setup_mod, "prompt_yes_no", return_value=True),
            patch.object(setup_mod, "get_config_path", return_value=config_path),
            patch("importlib.util.spec_from_file_location") as mock_spec_fn,
        ):
            # Wire up the fake module loading
            mock_spec = MagicMock()
            mock_spec.loader = MagicMock()
            mock_spec_fn.return_value = mock_spec

            def exec_module(mod):
                mod.resolve_selected_options = fake_mod.resolve_selected_options
                mod.Migrator = fake_mod.Migrator

            mock_spec.loader.exec_module = exec_module

            result = setup_mod._offer_openclaw_migration(hermes_home)

        assert result is True
        fake_mod.resolve_selected_options.assert_called_once_with(
            None, None, preset="full"
        )
        fake_mod.Migrator.assert_called_once()
        call_kwargs = fake_mod.Migrator.call_args[1]
        assert call_kwargs["execute"] is True
        assert call_kwargs["overwrite"] is False
        assert call_kwargs["migrate_secrets"] is True
        assert call_kwargs["preset_name"] == "full"
        fake_migrator.migrate.assert_called_once()

    def test_handles_migration_error_gracefully(self, tmp_path):
        """Should catch exceptions and return False."""
        openclaw_dir = tmp_path / ".openclaw"
        openclaw_dir.mkdir()
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        config_path = hermes_home / "config.yaml"
        config_path.write_text("")

        script = tmp_path / "openclaw_to_hermes.py"
        script.write_text("# placeholder")

        with (
            patch("hermes_cli.setup.Path.home", return_value=tmp_path),
            patch.object(setup_mod, "_OPENCLAW_SCRIPT", script),
            patch.object(setup_mod, "prompt_yes_no", return_value=True),
            patch.object(setup_mod, "get_config_path", return_value=config_path),
            patch(
                "importlib.util.spec_from_file_location",
                side_effect=RuntimeError("boom"),
            ),
        ):
            result = setup_mod._offer_openclaw_migration(hermes_home)

        assert result is False

    def test_creates_config_if_missing(self, tmp_path):
        """Should bootstrap config.yaml before running migration."""
        openclaw_dir = tmp_path / ".openclaw"
        openclaw_dir.mkdir()
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        config_path = hermes_home / "config.yaml"
        # config does NOT exist yet

        script = tmp_path / "openclaw_to_hermes.py"
        script.write_text("# placeholder")

        with (
            patch("hermes_cli.setup.Path.home", return_value=tmp_path),
            patch.object(setup_mod, "_OPENCLAW_SCRIPT", script),
            patch.object(setup_mod, "prompt_yes_no", return_value=True),
            patch.object(setup_mod, "get_config_path", return_value=config_path),
            patch.object(setup_mod, "load_config", return_value={"agent": {}}),
            patch.object(setup_mod, "save_config") as mock_save,
            patch(
                "importlib.util.spec_from_file_location",
                side_effect=RuntimeError("stop early"),
            ),
        ):
            setup_mod._offer_openclaw_migration(hermes_home)

        # save_config should have been called to bootstrap the file
        mock_save.assert_called_once_with({"agent": {}})


# ---------------------------------------------------------------------------
# Integration with run_setup_wizard — first-time flow
# ---------------------------------------------------------------------------


def _first_time_args() -> Namespace:
    return Namespace(
        section=None,
        non_interactive=False,
        reset=False,
    )


class TestSetupWizardOpenclawIntegration:
    """Verify _offer_openclaw_migration is called during first-time setup."""

    def test_migration_offered_during_first_time_setup(self, tmp_path):
        """On first-time setup, _offer_openclaw_migration should be called."""
        args = _first_time_args()

        with (
            patch.object(setup_mod, "ensure_hermes_home"),
            patch.object(setup_mod, "load_config", return_value={}),
            patch.object(setup_mod, "get_hermes_home", return_value=tmp_path),
            patch.object(setup_mod, "get_env_value", return_value=""),
            patch("hermes_cli.auth.get_active_provider", return_value=None),
            # User presses Enter to start
            patch("builtins.input", return_value=""),
            # Mock the migration offer
            patch.object(
                setup_mod, "_offer_openclaw_migration", return_value=False
            ) as mock_migration,
            # Mock the actual setup sections so they don't run
            patch.object(setup_mod, "setup_model_provider"),
            patch.object(setup_mod, "setup_terminal_backend"),
            patch.object(setup_mod, "setup_agent_settings"),
            patch.object(setup_mod, "setup_gateway"),
            patch.object(setup_mod, "setup_tools"),
            patch.object(setup_mod, "save_config"),
            patch.object(setup_mod, "_print_setup_summary"),
        ):
            setup_mod.run_setup_wizard(args)

        mock_migration.assert_called_once_with(tmp_path)

    def test_migration_reloads_config_on_success(self, tmp_path):
        """When migration returns True, config should be reloaded."""
        args = _first_time_args()
        call_order = []

        def tracking_load_config():
            call_order.append("load_config")
            return {}

        with (
            patch.object(setup_mod, "ensure_hermes_home"),
            patch.object(setup_mod, "load_config", side_effect=tracking_load_config),
            patch.object(setup_mod, "get_hermes_home", return_value=tmp_path),
            patch.object(setup_mod, "get_env_value", return_value=""),
            patch("hermes_cli.auth.get_active_provider", return_value=None),
            patch("builtins.input", return_value=""),
            patch.object(setup_mod, "_offer_openclaw_migration", return_value=True),
            patch.object(setup_mod, "setup_model_provider"),
            patch.object(setup_mod, "setup_terminal_backend"),
            patch.object(setup_mod, "setup_agent_settings"),
            patch.object(setup_mod, "setup_gateway"),
            patch.object(setup_mod, "setup_tools"),
            patch.object(setup_mod, "save_config"),
            patch.object(setup_mod, "_print_setup_summary"),
        ):
            setup_mod.run_setup_wizard(args)

        # load_config called twice: once at start, once after migration
        assert call_order.count("load_config") == 2

    def test_reloaded_config_flows_into_remaining_setup_sections(self, tmp_path):
        args = _first_time_args()
        initial_config = {}
        reloaded_config = {"model": {"provider": "openrouter"}}

        with (
            patch.object(setup_mod, "ensure_hermes_home"),
            patch.object(
                setup_mod,
                "load_config",
                side_effect=[initial_config, reloaded_config],
            ),
            patch.object(setup_mod, "get_hermes_home", return_value=tmp_path),
            patch.object(setup_mod, "get_env_value", return_value=""),
            patch("hermes_cli.auth.get_active_provider", return_value=None),
            patch("builtins.input", return_value=""),
            patch.object(setup_mod, "_offer_openclaw_migration", return_value=True),
            patch.object(setup_mod, "setup_model_provider") as setup_model_provider,
            patch.object(setup_mod, "setup_terminal_backend"),
            patch.object(setup_mod, "setup_agent_settings"),
            patch.object(setup_mod, "setup_gateway"),
            patch.object(setup_mod, "setup_tools"),
            patch.object(setup_mod, "save_config"),
            patch.object(setup_mod, "_print_setup_summary"),
        ):
            setup_mod.run_setup_wizard(args)

        setup_model_provider.assert_called_once_with(reloaded_config)

    def test_migration_not_offered_for_existing_install(self, tmp_path):
        """Returning users should not see the migration prompt."""
        args = _first_time_args()

        with (
            patch.object(setup_mod, "ensure_hermes_home"),
            patch.object(setup_mod, "load_config", return_value={}),
            patch.object(setup_mod, "get_hermes_home", return_value=tmp_path),
            patch.object(
                setup_mod,
                "get_env_value",
                side_effect=lambda k: "sk-xxx" if k == "OPENROUTER_API_KEY" else "",
            ),
            patch("hermes_cli.auth.get_active_provider", return_value=None),
            # Returning user picks "Exit"
            patch.object(setup_mod, "prompt_choice", return_value=9),
            patch.object(
                setup_mod, "_offer_openclaw_migration", return_value=False
            ) as mock_migration,
        ):
            setup_mod.run_setup_wizard(args)

        mock_migration.assert_not_called()
@ -1,13 +1,23 @@
|
|||
from io import StringIO
|
||||
|
||||
import pytest
|
||||
from rich.console import Console
|
||||
|
||||
from hermes_cli.skills_hub import do_list
|
||||
|
||||
|
||||
def test_do_list_initializes_hub_dir(monkeypatch, tmp_path):
|
||||
class _DummyLockFile:
|
||||
def __init__(self, installed):
|
||||
self._installed = installed
|
||||
|
||||
def list_installed(self):
|
||||
return self._installed
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def hub_env(monkeypatch, tmp_path):
|
||||
"""Set up isolated hub directory paths and return (monkeypatch, tmp_path)."""
|
||||
import tools.skills_hub as hub
|
||||
import tools.skills_tool as skills_tool
|
||||
|
||||
hub_dir = tmp_path / "skills" / ".hub"
|
||||
monkeypatch.setattr(hub, "SKILLS_DIR", tmp_path / "skills")
|
||||
|
|
@ -17,15 +27,98 @@ def test_do_list_initializes_hub_dir(monkeypatch, tmp_path):
|
|||
monkeypatch.setattr(hub, "AUDIT_LOG", hub_dir / "audit.log")
|
||||
monkeypatch.setattr(hub, "TAPS_FILE", hub_dir / "taps.json")
|
||||
monkeypatch.setattr(hub, "INDEX_CACHE_DIR", hub_dir / "index-cache")
|
||||
|
||||
return hub_dir
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures for common skill setups
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_HUB_ENTRY = {"name": "hub-skill", "source": "github", "trust_level": "community"}
|
||||
|
||||
_ALL_THREE_SKILLS = [
|
||||
{"name": "hub-skill", "category": "x", "description": "hub"},
|
||||
{"name": "builtin-skill", "category": "x", "description": "builtin"},
|
||||
{"name": "local-skill", "category": "x", "description": "local"},
|
||||
]
|
||||
|
||||
_BUILTIN_MANIFEST = {"builtin-skill": "abc123"}
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def three_source_env(monkeypatch, hub_env):
|
||||
"""Populate hub/builtin/local skills for source-classification tests."""
|
||||
import tools.skills_hub as hub
|
||||
import tools.skills_sync as skills_sync
|
||||
import tools.skills_tool as skills_tool
|
||||
|
||||
monkeypatch.setattr(hub, "HubLockFile", lambda: _DummyLockFile([_HUB_ENTRY]))
|
||||
monkeypatch.setattr(skills_tool, "_find_all_skills", lambda: list(_ALL_THREE_SKILLS))
|
||||
monkeypatch.setattr(skills_sync, "_read_manifest", lambda: dict(_BUILTIN_MANIFEST))
|
||||
|
||||
return hub_env
|
||||
|
||||
|
||||
def _capture(source_filter: str = "all") -> str:
|
||||
"""Run do_list into a string buffer and return the output."""
|
||||
sink = StringIO()
|
||||
console = Console(file=sink, force_terminal=False, color_system=None)
|
||||
do_list(source_filter=source_filter, console=console)
|
||||
return sink.getvalue()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_do_list_initializes_hub_dir(monkeypatch, hub_env):
|
||||
import tools.skills_sync as skills_sync
|
||||
import tools.skills_tool as skills_tool
|
||||
|
||||
monkeypatch.setattr(skills_tool, "_find_all_skills", lambda: [])
|
||||
monkeypatch.setattr(skills_sync, "_read_manifest", lambda: {})
|
||||
|
||||
console = Console(file=StringIO(), force_terminal=False, color_system=None)
|
||||
|
||||
hub_dir = hub_env
|
||||
assert not hub_dir.exists()
|
||||
|
||||
do_list(console=console)
|
||||
_capture()
|
||||
|
||||
assert hub_dir.exists()
|
||||
assert (hub_dir / "lock.json").exists()
|
||||
assert (hub_dir / "quarantine").is_dir()
|
||||
assert (hub_dir / "index-cache").is_dir()
|
||||
|
||||
|
||||
def test_do_list_distinguishes_hub_builtin_and_local(three_source_env):
|
||||
output = _capture()
|
||||
|
||||
assert "hub-skill" in output
|
||||
assert "builtin-skill" in output
|
||||
assert "local-skill" in output
|
||||
assert "1 hub-installed, 1 builtin, 1 local" in output
|
||||
|
||||
|
||||
def test_do_list_filter_local(three_source_env):
|
||||
output = _capture(source_filter="local")
|
||||
|
||||
assert "local-skill" in output
|
||||
assert "builtin-skill" not in output
|
||||
assert "hub-skill" not in output
|
||||
|
||||
|
||||
def test_do_list_filter_hub(three_source_env):
|
||||
output = _capture(source_filter="hub")
|
||||
|
||||
assert "hub-skill" in output
|
||||
assert "builtin-skill" not in output
|
||||
assert "local-skill" not in output
|
||||
|
||||
|
||||
def test_do_list_filter_builtin(three_source_env):
|
||||
output = _capture(source_filter="builtin")
|
||||
|
||||
assert "builtin-skill" in output
|
||||
assert "hub-skill" not in output
|
||||
assert "local-skill" not in output
@ -579,7 +579,7 @@ class WebToolsTester:
"results": self.test_results,
|
||||
"environment": {
|
||||
"firecrawl_api_key": check_firecrawl_api_key(),
|
||||
"nous_api_key": check_auxiliary_model(),
|
||||
"auxiliary_model": check_auxiliary_model(),
|
||||
"debug_mode": get_debug_session_info()["enabled"]
|
||||
}
|
||||
}
141
tests/run_interrupt_test.py
Normal file
@ -0,0 +1,141 @@
#!/usr/bin/env python3
|
||||
"""Run a real interrupt test with actual AIAgent + delegate child.
|
||||
|
||||
Not a pytest test — runs directly as a script for live testing.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from run_agent import AIAgent, IterationBudget
|
||||
from tools.delegate_tool import _run_single_child
|
||||
from tools.interrupt import set_interrupt, is_interrupted
|
||||
|
||||
set_interrupt(False)
|
||||
|
||||
# Create parent agent (minimal)
|
||||
parent = AIAgent.__new__(AIAgent)
|
||||
parent._interrupt_requested = False
|
||||
parent._interrupt_message = None
|
||||
parent._active_children = []
|
||||
parent.quiet_mode = True
|
||||
parent.model = "test/model"
|
||||
parent.base_url = "http://localhost:1"
|
||||
parent.api_key = "test"
|
||||
parent.provider = "test"
|
||||
parent.api_mode = "chat_completions"
|
||||
parent.platform = "cli"
|
||||
parent.enabled_toolsets = ["terminal", "file"]
|
||||
parent.providers_allowed = None
|
||||
parent.providers_ignored = None
|
||||
parent.providers_order = None
|
||||
parent.provider_sort = None
|
||||
parent.max_tokens = None
|
||||
parent.reasoning_config = None
|
||||
parent.prefill_messages = None
|
||||
parent._session_db = None
|
||||
parent._delegate_depth = 0
|
||||
parent._delegate_spinner = None
|
||||
parent.tool_progress_callback = None
|
||||
parent.iteration_budget = IterationBudget(max_total=100)
|
||||
parent._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1"}
|
||||
|
||||
child_started = threading.Event()
|
||||
result_holder = [None]
|
||||
|
||||
|
||||
def run_delegate():
|
||||
with patch("run_agent.OpenAI") as MockOpenAI:
|
||||
mock_client = MagicMock()
|
||||
|
||||
def slow_create(**kwargs):
|
||||
time.sleep(3)
|
||||
resp = MagicMock()
|
||||
resp.choices = [MagicMock()]
|
||||
resp.choices[0].message.content = "Done"
|
||||
resp.choices[0].message.tool_calls = None
|
||||
resp.choices[0].message.refusal = None
|
||||
resp.choices[0].finish_reason = "stop"
|
||||
resp.usage.prompt_tokens = 100
|
||||
resp.usage.completion_tokens = 10
|
||||
resp.usage.total_tokens = 110
|
||||
resp.usage.prompt_tokens_details = None
|
||||
return resp
|
||||
|
||||
mock_client.chat.completions.create = slow_create
|
||||
mock_client.close = MagicMock()
|
||||
MockOpenAI.return_value = mock_client
|
||||
|
||||
original_init = AIAgent.__init__
|
||||
|
||||
def patched_init(self_agent, *a, **kw):
|
||||
original_init(self_agent, *a, **kw)
|
||||
child_started.set()
|
||||
|
||||
with patch.object(AIAgent, "__init__", patched_init):
|
||||
try:
|
||||
result = _run_single_child(
|
||||
task_index=0,
|
||||
goal="Test slow task",
|
||||
context=None,
|
||||
toolsets=["terminal"],
|
||||
model="test/model",
|
||||
max_iterations=5,
|
||||
parent_agent=parent,
|
||||
task_count=1,
|
||||
override_provider="test",
|
||||
override_base_url="http://localhost:1",
|
||||
override_api_key="test",
|
||||
override_api_mode="chat_completions",
|
||||
)
|
||||
result_holder[0] = result
|
||||
except Exception as e:
|
||||
print(f"ERROR in delegate: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
print("Starting agent thread...")
|
||||
agent_thread = threading.Thread(target=run_delegate, daemon=True)
|
||||
agent_thread.start()
|
||||
|
||||
started = child_started.wait(timeout=10)
|
||||
if not started:
|
||||
print("ERROR: Child never started")
|
||||
sys.exit(1)
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
print(f"Active children: {len(parent._active_children)}")
|
||||
for i, c in enumerate(parent._active_children):
|
||||
print(f" Child {i}: _interrupt_requested={c._interrupt_requested}")
|
||||
|
||||
t0 = time.monotonic()
|
||||
parent.interrupt("User typed a new message")
|
||||
print(f"Called parent.interrupt()")
|
||||
|
||||
for i, c in enumerate(parent._active_children):
|
||||
print(f" Child {i} after interrupt: _interrupt_requested={c._interrupt_requested}")
|
||||
print(f"Global is_interrupted: {is_interrupted()}")
|
||||
|
||||
agent_thread.join(timeout=10)
|
||||
elapsed = time.monotonic() - t0
|
||||
print(f"Agent thread finished in {elapsed:.2f}s")
|
||||
|
||||
result = result_holder[0]
|
||||
if result:
|
||||
print(f"Status: {result['status']}")
|
||||
print(f"Duration: {result['duration_seconds']}s")
|
||||
if elapsed < 2.0:
|
||||
print("✅ PASS: Interrupt detected quickly!")
|
||||
else:
|
||||
print(f"❌ FAIL: Took {elapsed:.2f}s — interrupt was too slow or not detected")
|
||||
else:
|
||||
print("❌ FAIL: No result!")
|
||||
|
||||
set_interrupt(False)
@ -6,6 +6,11 @@ Verifies that:
- Preflight compression proactively compresses oversized sessions before API calls
|
||||
"""
|
||||
|
||||
import pytest
|
||||
pytestmark = pytest.mark.skip(reason="Hangs in non-interactive environments")
|
||||
|
||||
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
@ -28,6 +28,8 @@ from unittest.mock import patch
import pytest
|
||||
|
||||
pytestmark = pytest.mark.skip(reason="Live API integration test — hangs in batch runs")
|
||||
|
||||
# Ensure repo root is importable
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
@ -229,13 +229,14 @@ class TestVisionModelOverride:
def test_default_model_when_no_override(self, monkeypatch):
|
||||
monkeypatch.delenv("AUXILIARY_VISION_MODEL", raising=False)
|
||||
from tools.vision_tools import _handle_vision_analyze, DEFAULT_VISION_MODEL
|
||||
from tools.vision_tools import _handle_vision_analyze
|
||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=MagicMock) as mock_tool:
|
||||
mock_tool.return_value = '{"success": true}'
|
||||
_handle_vision_analyze({"image_url": "http://test.jpg", "question": "test"})
|
||||
call_args = mock_tool.call_args
|
||||
expected = DEFAULT_VISION_MODEL or "google/gemini-3-flash-preview"
|
||||
assert call_args[0][2] == expected
|
||||
# With no AUXILIARY_VISION_MODEL env var, model should be None
|
||||
# (the centralized call_llm router picks the provider default)
|
||||
assert call_args[0][2] is None
|
||||
|
||||
|
||||
# ── DEFAULT_CONFIG shape tests ───────────────────────────────────────────────
171
tests/test_cli_interrupt_subagent.py
Normal file
@ -0,0 +1,171 @@
"""End-to-end test simulating CLI interrupt during subagent execution.
|
||||
|
||||
Reproduces the exact scenario:
|
||||
1. Parent agent calls delegate_task
|
||||
2. Child agent is running (simulated with a slow tool)
|
||||
3. User "types a message" (simulated by calling parent.interrupt from another thread)
|
||||
4. Child should detect the interrupt and stop
|
||||
|
||||
This tests the COMPLETE path including _run_single_child, _active_children
|
||||
registration, interrupt propagation, and child detection.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
|
||||
from tools.interrupt import set_interrupt, is_interrupted
|
||||
|
||||
|
||||
class TestCLISubagentInterrupt(unittest.TestCase):
|
||||
"""Simulate exact CLI scenario."""
|
||||
|
||||
def setUp(self):
|
||||
set_interrupt(False)
|
||||
|
||||
def tearDown(self):
|
||||
set_interrupt(False)
|
||||
|
||||
def test_full_delegate_interrupt_flow(self):
|
||||
"""Full integration: parent runs delegate_task, main thread interrupts."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
interrupt_detected = threading.Event()
|
||||
child_started = threading.Event()
|
||||
child_api_call_count = 0
|
||||
|
||||
# Create a real-enough parent agent
|
||||
parent = AIAgent.__new__(AIAgent)
|
||||
parent._interrupt_requested = False
|
||||
parent._interrupt_message = None
|
||||
parent._active_children = []
|
||||
parent.quiet_mode = True
|
||||
parent.model = "test/model"
|
||||
parent.base_url = "http://localhost:1"
|
||||
parent.api_key = "test"
|
||||
parent.provider = "test"
|
||||
parent.api_mode = "chat_completions"
|
||||
parent.platform = "cli"
|
||||
parent.enabled_toolsets = ["terminal", "file"]
|
||||
parent.providers_allowed = None
|
||||
parent.providers_ignored = None
|
||||
parent.providers_order = None
|
||||
parent.provider_sort = None
|
||||
parent.max_tokens = None
|
||||
parent.reasoning_config = None
|
||||
parent.prefill_messages = None
|
||||
parent._session_db = None
|
||||
parent._delegate_depth = 0
|
||||
parent._delegate_spinner = None
|
||||
parent.tool_progress_callback = None
|
||||
|
||||
# We'll track what happens with _active_children
|
||||
original_children = parent._active_children
|
||||
|
||||
# Mock the child's run_conversation to simulate a slow operation
|
||||
# that checks _interrupt_requested like the real one does
|
||||
def mock_child_run_conversation(user_message, **kwargs):
|
||||
child_started.set()
|
||||
# Find the child in parent._active_children
|
||||
child = parent._active_children[-1] if parent._active_children else None
|
||||
|
||||
# Simulate the agent loop: poll _interrupt_requested like run_conversation does
|
||||
for i in range(100): # Up to 10 seconds (100 * 0.1s)
|
||||
if child and child._interrupt_requested:
|
||||
interrupt_detected.set()
|
||||
return {
|
||||
"final_response": "Interrupted!",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
"completed": False,
|
||||
"interrupted": True,
|
||||
"interrupt_message": child._interrupt_message,
|
||||
}
|
||||
time.sleep(0.1)
|
||||
|
||||
return {
|
||||
"final_response": "Finished without interrupt",
|
||||
"messages": [],
|
||||
"api_calls": 5,
|
||||
"completed": True,
|
||||
"interrupted": False,
|
||||
}
|
||||
|
||||
# Patch AIAgent to use our mock
|
||||
from tools.delegate_tool import _run_single_child
|
||||
from run_agent import IterationBudget
|
||||
|
||||
parent.iteration_budget = IterationBudget(max_total=100)
|
||||
|
||||
# Run delegate in a thread (simulates agent_thread)
|
||||
delegate_result = [None]
|
||||
delegate_error = [None]
|
||||
|
||||
def run_delegate():
|
||||
try:
|
||||
with patch('run_agent.AIAgent') as MockAgent:
|
||||
mock_instance = MagicMock()
|
||||
mock_instance._interrupt_requested = False
|
||||
mock_instance._interrupt_message = None
|
||||
mock_instance._active_children = []
|
||||
mock_instance.quiet_mode = True
|
||||
mock_instance.run_conversation = mock_child_run_conversation
|
||||
mock_instance.interrupt = lambda msg=None: setattr(mock_instance, '_interrupt_requested', True) or setattr(mock_instance, '_interrupt_message', msg)
|
||||
mock_instance.tools = []
|
||||
MockAgent.return_value = mock_instance
|
||||
|
||||
result = _run_single_child(
|
||||
task_index=0,
|
||||
goal="Do something slow",
|
||||
context=None,
|
||||
toolsets=["terminal"],
|
||||
model=None,
|
||||
max_iterations=50,
|
||||
parent_agent=parent,
|
||||
task_count=1,
|
||||
)
|
||||
delegate_result[0] = result
|
||||
except Exception as e:
|
||||
delegate_error[0] = e
|
||||
|
||||
agent_thread = threading.Thread(target=run_delegate, daemon=True)
|
||||
agent_thread.start()
|
||||
|
||||
# Wait for child to start
|
||||
assert child_started.wait(timeout=5), "Child never started!"
|
||||
|
||||
# Now simulate user interrupt (from main/process thread)
|
||||
time.sleep(0.2) # Give child a moment to be in its loop
|
||||
|
||||
print(f"Parent has {len(parent._active_children)} active children")
|
||||
assert len(parent._active_children) >= 1, f"Expected child in _active_children, got {len(parent._active_children)}"
|
||||
|
||||
# This is what the CLI does:
|
||||
parent.interrupt("Hey stop that")
|
||||
|
||||
print(f"Parent._interrupt_requested: {parent._interrupt_requested}")
|
||||
for i, child in enumerate(parent._active_children):
|
||||
print(f"Child {i}._interrupt_requested: {child._interrupt_requested}")
|
||||
|
||||
# Wait for child to detect interrupt
|
||||
detected = interrupt_detected.wait(timeout=3.0)
|
||||
|
||||
# Wait for delegate to finish
|
||||
agent_thread.join(timeout=5)
|
||||
|
||||
if delegate_error[0]:
|
||||
raise delegate_error[0]
|
||||
|
||||
assert detected, "Child never detected the interrupt!"
|
||||
result = delegate_result[0]
|
||||
assert result is not None, "Delegate returned no result"
|
||||
assert result["status"] == "interrupted", f"Expected 'interrupted', got '{result['status']}'"
|
||||
print(f"✓ Interrupt detected! Result: {result}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
@ -93,8 +93,8 @@ class TestModelCommand:
output = capsys.readouterr().out
|
||||
assert "anthropic/claude-opus-4.6" in output
|
||||
assert "OpenRouter" in output
|
||||
assert "Available models" in output
|
||||
assert "provider:model-name" in output
|
||||
assert "Authenticated providers" in output or "Switch model" in output
|
||||
assert "provider" in output and "model" in output
|
||||
|
||||
# -- provider switching tests -------------------------------------------
@ -197,21 +197,28 @@ def test_codex_provider_replaces_incompatible_default_model(monkeypatch):
assert shell.model == "gpt-5.2-codex"
|
||||
|
||||
|
||||
def test_codex_provider_trusts_explicit_envvar_model(monkeypatch):
|
||||
"""When the user explicitly sets LLM_MODEL, we trust their choice and
|
||||
let the API be the judge — even if it's a non-OpenAI model. Only
|
||||
provider prefixes are stripped; the bare model passes through."""
|
||||
def test_codex_provider_uses_config_model(monkeypatch):
|
||||
"""Model comes from config.yaml, not LLM_MODEL env var.
|
||||
Config.yaml is the single source of truth to avoid multi-agent conflicts."""
|
||||
cli = _import_cli()
|
||||
|
||||
monkeypatch.setenv("LLM_MODEL", "claude-opus-4-6")
|
||||
# LLM_MODEL env var should be IGNORED (even if set)
|
||||
monkeypatch.setenv("LLM_MODEL", "should-be-ignored")
|
||||
monkeypatch.delenv("OPENAI_MODEL", raising=False)
|
||||
|
||||
# Set model via config
|
||||
monkeypatch.setitem(cli.CLI_CONFIG, "model", {
|
||||
"default": "gpt-5.2-codex",
|
||||
"provider": "openai-codex",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
})
|
||||
|
||||
def _runtime_resolve(**kwargs):
|
||||
return {
|
||||
"provider": "openai-codex",
|
||||
"api_mode": "codex_responses",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"api_key": "test-key",
|
||||
"api_key": "fake-codex-token",
|
||||
"source": "env/config",
|
||||
}
@ -220,11 +227,12 @@ def test_codex_provider_trusts_explicit_envvar_model(monkeypatch):
shell = cli.HermesCLI(compact=True, max_turns=1)
|
||||
|
||||
assert shell._model_is_default is False
|
||||
assert shell._ensure_runtime_credentials() is True
|
||||
assert shell.provider == "openai-codex"
|
||||
# User explicitly chose this model — it passes through untouched
|
||||
assert shell.model == "claude-opus-4-6"
|
||||
# Model from config (may be normalized by codex provider logic)
|
||||
assert "codex" in shell.model.lower()
|
||||
# LLM_MODEL env var is NOT used
|
||||
assert shell.model != "should-be-ignored"
|
||||
|
||||
|
||||
def test_codex_provider_preserves_explicit_codex_model(monkeypatch):
@ -35,7 +35,7 @@ def _make_agent(fallback_model=None):
patch("run_agent.OpenAI"),
|
||||
):
|
||||
agent = AIAgent(
|
||||
api_key="test-key-primary",
|
||||
api_key="test-key",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
@ -45,6 +45,14 @@ def _make_agent(fallback_model=None):
return agent
|
||||
|
||||
|
||||
def _mock_resolve(base_url="https://openrouter.ai/api/v1", api_key="test-key"):
|
||||
"""Helper to create a mock client for resolve_provider_client."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.api_key = api_key
|
||||
mock_client.base_url = base_url
|
||||
return mock_client
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _try_activate_fallback()
|
||||
# =============================================================================
@ -71,9 +79,13 @@ class TestTryActivateFallback:
agent = _make_agent(
|
||||
fallback_model={"provider": "openrouter", "model": "anthropic/claude-sonnet-4"},
|
||||
)
|
||||
with (
|
||||
patch.dict("os.environ", {"OPENROUTER_API_KEY": "sk-or-fallback-key"}),
|
||||
patch("run_agent.OpenAI") as mock_openai,
|
||||
mock_client = _mock_resolve(
|
||||
api_key="sk-or-fallback-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
with patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(mock_client, "anthropic/claude-sonnet-4"),
|
||||
):
|
||||
result = agent._try_activate_fallback()
|
||||
assert result is True
@ -81,36 +93,37 @@ class TestTryActivateFallback:
assert agent.model == "anthropic/claude-sonnet-4"
|
||||
assert agent.provider == "openrouter"
|
||||
assert agent.api_mode == "chat_completions"
|
||||
mock_openai.assert_called_once()
|
||||
call_kwargs = mock_openai.call_args[1]
|
||||
assert call_kwargs["api_key"] == "sk-or-fallback-key"
|
||||
assert "openrouter" in call_kwargs["base_url"].lower()
|
||||
# OpenRouter should get attribution headers
|
||||
assert "default_headers" in call_kwargs
|
||||
assert agent.client is mock_client
|
||||
|
||||
def test_activates_zai_fallback(self):
|
||||
agent = _make_agent(
|
||||
fallback_model={"provider": "zai", "model": "glm-5"},
|
||||
)
|
||||
with (
|
||||
patch.dict("os.environ", {"ZAI_API_KEY": "sk-zai-key"}),
|
||||
patch("run_agent.OpenAI") as mock_openai,
|
||||
mock_client = _mock_resolve(
|
||||
api_key="sk-zai-key",
|
||||
base_url="https://open.z.ai/api/v1",
|
||||
)
|
||||
with patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(mock_client, "glm-5"),
|
||||
):
|
||||
result = agent._try_activate_fallback()
|
||||
assert result is True
|
||||
assert agent.model == "glm-5"
|
||||
assert agent.provider == "zai"
|
||||
call_kwargs = mock_openai.call_args[1]
|
||||
assert call_kwargs["api_key"] == "sk-zai-key"
|
||||
assert "z.ai" in call_kwargs["base_url"].lower()
|
||||
assert agent.client is mock_client
|
||||
|
||||
def test_activates_kimi_fallback(self):
|
||||
agent = _make_agent(
|
||||
fallback_model={"provider": "kimi-coding", "model": "kimi-k2.5"},
|
||||
)
|
||||
with (
|
||||
patch.dict("os.environ", {"KIMI_API_KEY": "sk-kimi-key"}),
|
||||
patch("run_agent.OpenAI"),
|
||||
mock_client = _mock_resolve(
|
||||
api_key="sk-kimi-key",
|
||||
base_url="https://api.moonshot.ai/v1",
|
||||
)
|
||||
with patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(mock_client, "kimi-k2.5"),
|
||||
):
|
||||
assert agent._try_activate_fallback() is True
|
||||
assert agent.model == "kimi-k2.5"
@ -120,23 +133,30 @@ class TestTryActivateFallback:
agent = _make_agent(
|
||||
fallback_model={"provider": "minimax", "model": "MiniMax-M2.5"},
|
||||
)
|
||||
with (
|
||||
patch.dict("os.environ", {"MINIMAX_API_KEY": "sk-mm-key"}),
|
||||
patch("run_agent.OpenAI") as mock_openai,
|
||||
mock_client = _mock_resolve(
|
||||
api_key="sk-mm-key",
|
||||
base_url="https://api.minimax.io/v1",
|
||||
)
|
||||
with patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(mock_client, "MiniMax-M2.5"),
|
||||
):
|
||||
assert agent._try_activate_fallback() is True
|
||||
assert agent.model == "MiniMax-M2.5"
|
||||
assert agent.provider == "minimax"
|
||||
call_kwargs = mock_openai.call_args[1]
|
||||
assert "minimax.io" in call_kwargs["base_url"]
|
||||
assert agent.client is mock_client
|
||||
|
||||
def test_only_fires_once(self):
|
||||
agent = _make_agent(
|
||||
fallback_model={"provider": "openrouter", "model": "anthropic/claude-sonnet-4"},
|
||||
)
|
||||
with (
|
||||
patch.dict("os.environ", {"OPENROUTER_API_KEY": "sk-or-key"}),
|
||||
patch("run_agent.OpenAI"),
|
||||
mock_client = _mock_resolve(
|
||||
api_key="sk-or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
with patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(mock_client, "anthropic/claude-sonnet-4"),
|
||||
):
|
||||
assert agent._try_activate_fallback() is True
|
||||
# Second attempt should return False
@ -147,9 +167,10 @@ class TestTryActivateFallback:
agent = _make_agent(
|
||||
fallback_model={"provider": "minimax", "model": "MiniMax-M2.5"},
|
||||
)
|
||||
# Ensure MINIMAX_API_KEY is not in the environment
|
||||
env = {k: v for k, v in os.environ.items() if k != "MINIMAX_API_KEY"}
|
||||
with patch.dict("os.environ", env, clear=True):
|
||||
with patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(None, None),
|
||||
):
|
||||
assert agent._try_activate_fallback() is False
|
||||
assert agent._fallback_activated is False
@ -163,22 +184,29 @@ class TestTryActivateFallback:
"api_key_env": "MY_CUSTOM_KEY",
|
||||
},
|
||||
)
|
||||
with (
|
||||
patch.dict("os.environ", {"MY_CUSTOM_KEY": "custom-secret"}),
|
||||
patch("run_agent.OpenAI") as mock_openai,
|
||||
mock_client = _mock_resolve(
|
||||
api_key="custom-secret",
|
||||
base_url="http://localhost:8080/v1",
|
||||
)
|
||||
with patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(mock_client, "my-model"),
|
||||
):
|
||||
assert agent._try_activate_fallback() is True
|
||||
call_kwargs = mock_openai.call_args[1]
|
||||
assert call_kwargs["base_url"] == "http://localhost:8080/v1"
|
||||
assert call_kwargs["api_key"] == "custom-secret"
|
||||
assert agent.client is mock_client
|
||||
assert agent.model == "my-model"
|
||||
|
||||
def test_prompt_caching_enabled_for_claude_on_openrouter(self):
|
||||
agent = _make_agent(
|
||||
fallback_model={"provider": "openrouter", "model": "anthropic/claude-sonnet-4"},
|
||||
)
|
||||
with (
|
||||
patch.dict("os.environ", {"OPENROUTER_API_KEY": "sk-or-key"}),
|
||||
patch("run_agent.OpenAI"),
|
||||
mock_client = _mock_resolve(
|
||||
api_key="sk-or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
with patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(mock_client, "anthropic/claude-sonnet-4"),
|
||||
):
|
||||
agent._try_activate_fallback()
|
||||
assert agent._use_prompt_caching is True
@ -187,9 +215,13 @@ class TestTryActivateFallback:
agent = _make_agent(
|
||||
fallback_model={"provider": "openrouter", "model": "google/gemini-2.5-flash"},
|
||||
)
|
||||
with (
|
||||
patch.dict("os.environ", {"OPENROUTER_API_KEY": "sk-or-key"}),
|
||||
patch("run_agent.OpenAI"),
|
||||
mock_client = _mock_resolve(
|
||||
api_key="sk-or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
with patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(mock_client, "google/gemini-2.5-flash"),
|
||||
):
|
||||
agent._try_activate_fallback()
|
||||
assert agent._use_prompt_caching is False
@ -198,9 +230,13 @@ class TestTryActivateFallback:
agent = _make_agent(
|
||||
fallback_model={"provider": "zai", "model": "glm-5"},
|
||||
)
|
||||
with (
|
||||
patch.dict("os.environ", {"ZAI_API_KEY": "sk-zai-key"}),
|
||||
patch("run_agent.OpenAI"),
|
||||
mock_client = _mock_resolve(
|
||||
api_key="sk-zai-key",
|
||||
base_url="https://open.z.ai/api/v1",
|
||||
)
|
||||
with patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(mock_client, "glm-5"),
|
||||
):
|
||||
agent._try_activate_fallback()
|
||||
assert agent._use_prompt_caching is False
@ -210,35 +246,36 @@ class TestTryActivateFallback:
agent = _make_agent(
|
||||
fallback_model={"provider": "zai", "model": "glm-5"},
|
||||
)
|
||||
with (
|
||||
patch.dict("os.environ", {"Z_AI_API_KEY": "sk-alt-key"}),
|
||||
patch("run_agent.OpenAI") as mock_openai,
|
||||
mock_client = _mock_resolve(
|
||||
api_key="sk-alt-key",
|
||||
base_url="https://open.z.ai/api/v1",
|
||||
)
|
||||
with patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(mock_client, "glm-5"),
|
||||
):
|
||||
assert agent._try_activate_fallback() is True
|
||||
call_kwargs = mock_openai.call_args[1]
|
||||
assert call_kwargs["api_key"] == "sk-alt-key"
|
||||
assert agent.client is mock_client
|
||||
|
||||
def test_activates_codex_fallback(self):
|
||||
"""OpenAI Codex fallback should use OAuth credentials and codex_responses mode."""
|
||||
agent = _make_agent(
|
||||
fallback_model={"provider": "openai-codex", "model": "gpt-5.3-codex"},
|
||||
)
|
||||
mock_creds = {
|
||||
"api_key": "codex-oauth-token",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
}
|
||||
with (
|
||||
patch("hermes_cli.auth.resolve_codex_runtime_credentials", return_value=mock_creds),
|
||||
patch("run_agent.OpenAI") as mock_openai,
|
||||
mock_client = _mock_resolve(
|
||||
api_key="codex-oauth-token",
|
||||
base_url="https://chatgpt.com/backend-api/codex",
|
||||
)
|
||||
with patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(mock_client, "gpt-5.3-codex"),
|
||||
):
|
||||
result = agent._try_activate_fallback()
|
||||
assert result is True
|
||||
assert agent.model == "gpt-5.3-codex"
|
||||
assert agent.provider == "openai-codex"
|
||||
assert agent.api_mode == "codex_responses"
|
||||
call_kwargs = mock_openai.call_args[1]
|
||||
assert call_kwargs["api_key"] == "codex-oauth-token"
|
||||
assert "chatgpt.com" in call_kwargs["base_url"]
|
||||
assert agent.client is mock_client
|
||||
|
||||
def test_codex_fallback_fails_gracefully_without_credentials(self):
|
||||
"""Codex fallback should return False if no OAuth credentials available."""
@ -246,8 +283,8 @@ class TestTryActivateFallback:
fallback_model={"provider": "openai-codex", "model": "gpt-5.3-codex"},
|
||||
)
|
||||
with patch(
|
||||
"hermes_cli.auth.resolve_codex_runtime_credentials",
|
||||
side_effect=Exception("No Codex credentials"),
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(None, None),
|
||||
):
|
||||
assert agent._try_activate_fallback() is False
|
||||
assert agent._fallback_activated is False
@ -257,22 +294,20 @@ class TestTryActivateFallback:
agent = _make_agent(
|
||||
fallback_model={"provider": "nous", "model": "nous-hermes-3"},
|
||||
)
|
||||
mock_creds = {
|
||||
"api_key": "nous-agent-key-abc",
|
||||
"base_url": "https://inference-api.nousresearch.com/v1",
|
||||
}
|
||||
with (
|
||||
patch("hermes_cli.auth.resolve_nous_runtime_credentials", return_value=mock_creds),
|
||||
patch("run_agent.OpenAI") as mock_openai,
|
||||
mock_client = _mock_resolve(
|
||||
api_key="nous-agent-key-abc",
|
||||
base_url="https://inference-api.nousresearch.com/v1",
|
||||
)
|
||||
with patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(mock_client, "nous-hermes-3"),
|
||||
):
|
||||
result = agent._try_activate_fallback()
|
||||
assert result is True
|
||||
assert agent.model == "nous-hermes-3"
|
||||
assert agent.provider == "nous"
|
||||
assert agent.api_mode == "chat_completions"
|
||||
call_kwargs = mock_openai.call_args[1]
|
||||
assert call_kwargs["api_key"] == "nous-agent-key-abc"
|
||||
assert "nousresearch.com" in call_kwargs["base_url"]
|
||||
assert agent.client is mock_client
|
||||
|
||||
def test_nous_fallback_fails_gracefully_without_login(self):
|
||||
"""Nous fallback should return False if not logged in."""
@ -280,8 +315,8 @@ class TestTryActivateFallback:
fallback_model={"provider": "nous", "model": "nous-hermes-3"},
|
||||
)
|
||||
with patch(
|
||||
"hermes_cli.auth.resolve_nous_runtime_credentials",
|
||||
side_effect=Exception("Not logged in to Nous Portal"),
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(None, None),
|
||||
):
|
||||
assert agent._try_activate_fallback() is False
|
||||
assert agent._fallback_activated is False
@ -315,7 +350,7 @@ class TestFallbackInit:
# =============================================================================
|
||||
|
||||
class TestProviderCredentials:
|
||||
"""Verify that each supported provider resolves its API key correctly."""
|
||||
"""Verify that each supported provider resolves via the centralized router."""
|
||||
|
||||
@pytest.mark.parametrize("provider,env_var,base_url_fragment", [
|
||||
("openrouter", "OPENROUTER_API_KEY", "openrouter"),
@ -328,12 +363,15 @@ class TestProviderCredentials:
agent = _make_agent(
|
||||
fallback_model={"provider": provider, "model": "test-model"},
|
||||
)
|
||||
with (
|
||||
patch.dict("os.environ", {env_var: "test-key-123"}),
|
||||
patch("run_agent.OpenAI") as mock_openai,
|
||||
mock_client = MagicMock()
|
||||
mock_client.api_key = "test-api-key"
|
||||
mock_client.base_url = f"https://{base_url_fragment}/v1"
|
||||
with patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(mock_client, "test-model"),
|
||||
):
|
||||
result = agent._try_activate_fallback()
|
||||
assert result is True, f"Failed to activate fallback for {provider}"
|
||||
call_kwargs = mock_openai.call_args[1]
|
||||
assert call_kwargs["api_key"] == "test-key-123"
|
||||
assert base_url_fragment in call_kwargs["base_url"].lower()
|
||||
assert agent.client is mock_client
|
||||
assert agent.model == "test-model"
|
||||
assert agent.provider == provider
|
@ -98,10 +98,9 @@ class TestFlushMemoriesUsesAuxiliaryClient:
|
|||
def test_flush_uses_auxiliary_when_available(self, monkeypatch):
|
||||
agent = _make_agent(monkeypatch, api_mode="codex_responses", provider="openai-codex")
|
||||
|
||||
mock_aux_client = MagicMock()
|
||||
mock_aux_client.chat.completions.create.return_value = _chat_response_with_memory_call()
|
||||
mock_response = _chat_response_with_memory_call()
|
||||
|
||||
with patch("agent.auxiliary_client.get_text_auxiliary_client", return_value=(mock_aux_client, "gpt-4o-mini")):
|
||||
with patch("agent.auxiliary_client.call_llm", return_value=mock_response) as mock_call:
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
@ -110,9 +109,9 @@ class TestFlushMemoriesUsesAuxiliaryClient:
with patch("tools.memory_tool.memory_tool", return_value="Saved.") as mock_memory:
|
||||
agent.flush_memories(messages)
|
||||
|
||||
mock_aux_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_aux_client.chat.completions.create.call_args
|
||||
assert call_kwargs.kwargs.get("model") == "gpt-4o-mini" or call_kwargs[1].get("model") == "gpt-4o-mini"
|
||||
mock_call.assert_called_once()
|
||||
call_kwargs = mock_call.call_args
|
||||
assert call_kwargs.kwargs.get("task") == "flush_memories"
|
||||
|
||||
def test_flush_uses_main_client_when_no_auxiliary(self, monkeypatch):
|
||||
"""Non-Codex mode with no auxiliary falls back to self.client."""
@ -120,7 +119,7 @@ class TestFlushMemoriesUsesAuxiliaryClient:
agent.client = MagicMock()
|
||||
agent.client.chat.completions.create.return_value = _chat_response_with_memory_call()
|
||||
|
||||
with patch("agent.auxiliary_client.get_text_auxiliary_client", return_value=(None, None)):
|
||||
with patch("agent.auxiliary_client.call_llm", side_effect=RuntimeError("no provider")):
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
@ -135,10 +134,9 @@ class TestFlushMemoriesUsesAuxiliaryClient:
"""Verify that memory tool calls from the flush response actually get executed."""
|
||||
agent = _make_agent(monkeypatch, api_mode="chat_completions", provider="openrouter")
|
||||
|
||||
mock_aux_client = MagicMock()
|
||||
mock_aux_client.chat.completions.create.return_value = _chat_response_with_memory_call()
|
||||
mock_response = _chat_response_with_memory_call()
|
||||
|
||||
with patch("agent.auxiliary_client.get_text_auxiliary_client", return_value=(mock_aux_client, "gpt-4o-mini")):
|
||||
with patch("agent.auxiliary_client.call_llm", return_value=mock_response):
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"},
@ -157,10 +155,9 @@ class TestFlushMemoriesUsesAuxiliaryClient:
"""After flush, the flush prompt and any response should be removed from messages."""
|
||||
agent = _make_agent(monkeypatch, api_mode="chat_completions", provider="openrouter")
|
||||
|
||||
mock_aux_client = MagicMock()
|
||||
mock_aux_client.chat.completions.create.return_value = _chat_response_with_memory_call()
|
||||
mock_response = _chat_response_with_memory_call()
|
||||
|
||||
with patch("agent.auxiliary_client.get_text_auxiliary_client", return_value=(mock_aux_client, "gpt-4o-mini")):
|
||||
with patch("agent.auxiliary_client.call_llm", return_value=mock_response):
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"},
@ -202,7 +199,7 @@ class TestFlushMemoriesCodexFallback:
model="gpt-5-codex",
|
||||
)
|
||||
|
||||
with patch("agent.auxiliary_client.get_text_auxiliary_client", return_value=(None, None)), \
|
||||
with patch("agent.auxiliary_client.call_llm", side_effect=RuntimeError("no provider")), \
|
||||
patch.object(agent, "_run_codex_stream", return_value=codex_response) as mock_stream, \
|
||||
patch.object(agent, "_build_api_kwargs") as mock_build, \
|
||||
patch("tools.memory_tool.memory_tool", return_value="Saved.") as mock_memory:
189
tests/test_interactive_interrupt.py
Normal file
@ -0,0 +1,189 @@
#!/usr/bin/env python3
|
||||
"""Interactive interrupt test that mimics the exact CLI flow.
|
||||
|
||||
Starts an agent in a thread with a mock delegate_task that takes a while,
|
||||
then simulates the user typing a message via _interrupt_queue.
|
||||
|
||||
Logs every step to stderr (which isn't affected by redirect_stdout)
|
||||
so we can see exactly where the interrupt gets lost.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import queue
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import os
|
||||
|
||||
# Force stderr logging so redirect_stdout doesn't swallow it
|
||||
logging.basicConfig(level=logging.DEBUG, stream=sys.stderr,
|
||||
format="%(asctime)s [%(threadName)s] %(message)s")
|
||||
log = logging.getLogger("interrupt_test")
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from run_agent import AIAgent, IterationBudget
|
||||
from tools.interrupt import set_interrupt, is_interrupted
|
||||
|
||||
set_interrupt(False)
|
||||
|
||||
# ─── Create parent agent ───
|
||||
parent = AIAgent.__new__(AIAgent)
|
||||
parent._interrupt_requested = False
|
||||
parent._interrupt_message = None
|
||||
parent._active_children = []
|
||||
parent.quiet_mode = True
|
||||
parent.model = "test/model"
|
||||
parent.base_url = "http://localhost:1"
|
||||
parent.api_key = "test"
|
||||
parent.provider = "test"
|
||||
parent.api_mode = "chat_completions"
|
||||
parent.platform = "cli"
|
||||
parent.enabled_toolsets = ["terminal", "file"]
|
||||
parent.providers_allowed = None
|
||||
parent.providers_ignored = None
|
||||
parent.providers_order = None
|
||||
parent.provider_sort = None
|
||||
parent.max_tokens = None
|
||||
parent.reasoning_config = None
|
||||
parent.prefill_messages = None
|
||||
parent._session_db = None
|
||||
parent._delegate_depth = 0
|
||||
parent._delegate_spinner = None
|
||||
parent.tool_progress_callback = None
|
||||
parent.iteration_budget = IterationBudget(max_total=100)
|
||||
parent._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1"}
|
||||
|
||||
# Monkey-patch parent.interrupt to log
|
||||
_original_interrupt = AIAgent.interrupt
|
||||
def logged_interrupt(self, message=None):
|
||||
log.info(f"🔴 parent.interrupt() called with: {message!r}")
|
||||
log.info(f" _active_children count: {len(self._active_children)}")
|
||||
_original_interrupt(self, message)
|
||||
log.info(f" After interrupt: _interrupt_requested={self._interrupt_requested}")
|
||||
for i, c in enumerate(self._active_children):
|
||||
log.info(f" Child {i}._interrupt_requested={c._interrupt_requested}")
|
||||
parent.interrupt = lambda msg=None: logged_interrupt(parent, msg)
|
||||
|
||||
# ─── Simulate the exact CLI flow ───
|
||||
interrupt_queue = queue.Queue()
|
||||
child_running = threading.Event()
|
||||
agent_result = [None]
|
||||
|
||||
def make_slow_response(delay=2.0):
|
||||
"""API response that takes a while."""
|
||||
def create(**kwargs):
|
||||
log.info(f" 🌐 Mock API call starting (will take {delay}s)...")
|
||||
time.sleep(delay)
|
||||
log.info(f" 🌐 Mock API call completed")
|
||||
resp = MagicMock()
|
||||
resp.choices = [MagicMock()]
|
||||
resp.choices[0].message.content = "Done with the task"
|
||||
resp.choices[0].message.tool_calls = None
|
||||
resp.choices[0].message.refusal = None
|
||||
resp.choices[0].finish_reason = "stop"
|
||||
resp.usage.prompt_tokens = 100
|
||||
resp.usage.completion_tokens = 10
|
||||
resp.usage.total_tokens = 110
|
||||
resp.usage.prompt_tokens_details = None
|
||||
return resp
|
||||
return create
|
||||
|
||||
|
||||
def agent_thread_func():
|
||||
"""Simulates the agent_thread in cli.py's chat() method."""
|
||||
log.info("🟢 agent_thread starting")
|
||||
|
||||
with patch("run_agent.OpenAI") as MockOpenAI:
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = make_slow_response(delay=3.0)
|
||||
mock_client.close = MagicMock()
|
||||
MockOpenAI.return_value = mock_client
|
||||
|
||||
from tools.delegate_tool import _run_single_child
|
||||
|
||||
# Signal that child is about to start
|
||||
original_init = AIAgent.__init__
|
||||
def patched_init(self_agent, *a, **kw):
|
||||
log.info("🟡 Child AIAgent.__init__ called")
|
||||
original_init(self_agent, *a, **kw)
|
||||
child_running.set()
|
||||
log.info(f"🟡 Child started, parent._active_children = {len(parent._active_children)}")
|
||||
|
||||
with patch.object(AIAgent, "__init__", patched_init):
|
||||
result = _run_single_child(
|
||||
task_index=0,
|
||||
goal="Do a slow thing",
|
||||
context=None,
|
||||
toolsets=["terminal"],
|
||||
model="test/model",
|
||||
max_iterations=3,
|
||||
parent_agent=parent,
|
||||
task_count=1,
|
||||
override_provider="test",
|
||||
override_base_url="http://localhost:1",
|
||||
override_api_key="test",
|
||||
override_api_mode="chat_completions",
|
||||
)
|
||||
agent_result[0] = result
|
||||
log.info(f"🟢 agent_thread finished. Result status: {result.get('status')}")
|
||||
|
||||
|
||||
# ─── Start agent thread (like chat() does) ───
|
||||
agent_thread = threading.Thread(target=agent_thread_func, name="agent_thread", daemon=True)
|
||||
agent_thread.start()
|
||||
|
||||
# ─── Wait for child to start ───
|
||||
if not child_running.wait(timeout=10):
|
||||
print("FAIL: Child never started", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Give child time to enter its main loop and start API call
|
||||
time.sleep(1.0)
|
||||
|
||||
# ─── Simulate user typing a message (like handle_enter does) ───
|
||||
log.info("📝 Simulating user typing 'Hey stop that'")
|
||||
interrupt_queue.put("Hey stop that")
|
||||
|
||||
# ─── Simulate chat() polling loop (like the real chat() method) ───
|
||||
log.info("📡 Starting interrupt queue polling (like chat())")
|
||||
interrupt_msg = None
|
||||
poll_count = 0
|
||||
while agent_thread.is_alive():
|
||||
try:
|
||||
interrupt_msg = interrupt_queue.get(timeout=0.1)
|
||||
if interrupt_msg:
|
||||
log.info(f"📨 Got interrupt message from queue: {interrupt_msg!r}")
|
||||
log.info(f" Calling parent.interrupt()...")
|
||||
parent.interrupt(interrupt_msg)
|
||||
log.info(f" parent.interrupt() returned. Breaking poll loop.")
|
||||
break
|
||||
except queue.Empty:
|
||||
poll_count += 1
|
||||
if poll_count % 20 == 0: # Log every 2s
|
||||
log.info(f" Still polling ({poll_count} iterations)...")
|
||||
|
||||
# ─── Wait for agent to finish ───
|
||||
log.info("⏳ Waiting for agent_thread to join...")
|
||||
t0 = time.monotonic()
|
||||
agent_thread.join(timeout=10)
|
||||
elapsed = time.monotonic() - t0
|
||||
log.info(f"✅ agent_thread joined after {elapsed:.2f}s")
|
||||
|
||||
# ─── Check results ───
|
||||
result = agent_result[0]
|
||||
if result:
|
||||
log.info(f"Result status: {result['status']}")
|
||||
log.info(f"Result duration: {result['duration_seconds']}s")
|
||||
if result["status"] == "interrupted" and elapsed < 2.0:
|
||||
print("✅ PASS: Interrupt worked correctly!", file=sys.stderr)
|
||||
else:
|
||||
print(f"❌ FAIL: status={result['status']}, elapsed={elapsed:.2f}s", file=sys.stderr)
|
||||
else:
|
||||
print("❌ FAIL: No result returned", file=sys.stderr)
|
||||
|
||||
set_interrupt(False)
|
||||
155
tests/test_interrupt_propagation.py
Normal file
@ -0,0 +1,155 @@
"""Test interrupt propagation from parent to child agents.
|
||||
|
||||
Reproduces the CLI scenario: user sends a message while delegate_task is
|
||||
running, main thread calls parent.interrupt(), child should stop.
|
||||
"""
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
|
||||
from tools.interrupt import set_interrupt, is_interrupted, _interrupt_event
|
||||
|
||||
|
||||
class TestInterruptPropagationToChild(unittest.TestCase):
|
||||
"""Verify interrupt propagates from parent to child agent."""
|
||||
|
||||
def setUp(self):
|
||||
set_interrupt(False)
|
||||
|
||||
def tearDown(self):
|
||||
set_interrupt(False)
|
||||
|
||||
def test_parent_interrupt_sets_child_flag(self):
|
||||
"""When parent.interrupt() is called, child._interrupt_requested should be set."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
parent = AIAgent.__new__(AIAgent)
|
||||
parent._interrupt_requested = False
|
||||
parent._interrupt_message = None
|
||||
parent._active_children = []
|
||||
parent.quiet_mode = True
|
||||
|
||||
child = AIAgent.__new__(AIAgent)
|
||||
child._interrupt_requested = False
|
||||
child._interrupt_message = None
|
||||
child._active_children = []
|
||||
child.quiet_mode = True
|
||||
|
||||
parent._active_children.append(child)
|
||||
|
||||
parent.interrupt("new user message")
|
||||
|
||||
assert parent._interrupt_requested is True
|
||||
assert child._interrupt_requested is True
|
||||
assert child._interrupt_message == "new user message"
|
||||
assert is_interrupted() is True
|
||||
|
||||
def test_child_clear_interrupt_at_start_clears_global(self):
|
||||
"""child.clear_interrupt() at start of run_conversation clears the GLOBAL event.
|
||||
|
||||
This is the intended behavior at startup, but verify it doesn't
|
||||
accidentally clear an interrupt intended for a running child.
|
||||
"""
|
||||
from run_agent import AIAgent
|
||||
|
||||
child = AIAgent.__new__(AIAgent)
|
||||
child._interrupt_requested = True
|
||||
child._interrupt_message = "msg"
|
||||
child.quiet_mode = True
|
||||
child._active_children = []
|
||||
|
||||
# Global is set
|
||||
set_interrupt(True)
|
||||
assert is_interrupted() is True
|
||||
|
||||
# child.clear_interrupt() clears both
|
||||
child.clear_interrupt()
|
||||
assert child._interrupt_requested is False
|
||||
assert is_interrupted() is False
|
||||
|
||||
def test_interrupt_during_child_api_call_detected(self):
|
||||
"""Interrupt set during _interruptible_api_call is detected within 0.5s."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
child = AIAgent.__new__(AIAgent)
|
||||
child._interrupt_requested = False
|
||||
child._interrupt_message = None
|
||||
child._active_children = []
|
||||
child.quiet_mode = True
|
||||
child.api_mode = "chat_completions"
|
||||
child.log_prefix = ""
|
||||
child._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1234"}
|
||||
|
||||
# Mock a slow API call
|
||||
mock_client = MagicMock()
|
||||
def slow_api_call(**kwargs):
|
||||
time.sleep(5) # Would take 5s normally
|
||||
return MagicMock()
|
||||
mock_client.chat.completions.create = slow_api_call
|
||||
mock_client.close = MagicMock()
|
||||
child.client = mock_client
|
||||
|
||||
# Set interrupt after 0.2s from another thread
|
||||
def set_interrupt_later():
|
||||
time.sleep(0.2)
|
||||
child.interrupt("stop!")
|
||||
t = threading.Thread(target=set_interrupt_later, daemon=True)
|
||||
t.start()
|
||||
|
||||
start = time.monotonic()
|
||||
try:
|
||||
child._interruptible_api_call({"model": "test", "messages": []})
|
||||
self.fail("Should have raised InterruptedError")
|
||||
except InterruptedError:
|
||||
elapsed = time.monotonic() - start
|
||||
# Should detect within ~0.5s (0.2s delay + 0.3s poll interval)
|
||||
assert elapsed < 1.0, f"Took {elapsed:.2f}s to detect interrupt (expected < 1.0s)"
|
||||
finally:
|
||||
t.join(timeout=2)
|
||||
set_interrupt(False)
|
||||
|
||||
def test_concurrent_interrupt_propagation(self):
|
||||
"""Simulates exact CLI flow: parent runs delegate in thread, main thread interrupts."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
parent = AIAgent.__new__(AIAgent)
|
||||
parent._interrupt_requested = False
|
||||
parent._interrupt_message = None
|
||||
parent._active_children = []
|
||||
parent.quiet_mode = True
|
||||
|
||||
child = AIAgent.__new__(AIAgent)
|
||||
child._interrupt_requested = False
|
||||
child._interrupt_message = None
|
||||
child._active_children = []
|
||||
child.quiet_mode = True
|
||||
|
||||
# Register child (simulating what _run_single_child does)
|
||||
parent._active_children.append(child)
|
||||
|
||||
# Simulate child running (checking flag in a loop)
|
||||
child_detected = threading.Event()
|
||||
def simulate_child_loop():
|
||||
while not child._interrupt_requested:
|
||||
time.sleep(0.05)
|
||||
child_detected.set()
|
||||
|
||||
child_thread = threading.Thread(target=simulate_child_loop, daemon=True)
|
||||
child_thread.start()
|
||||
|
||||
# Small delay, then interrupt from "main thread"
|
||||
time.sleep(0.1)
|
||||
parent.interrupt("user typed something new")
|
||||
|
||||
# Child should detect within 200ms
|
||||
detected = child_detected.wait(timeout=1.0)
|
||||
assert detected, "Child never detected the interrupt!"
|
||||
child_thread.join(timeout=1)
|
||||
set_interrupt(False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
@ -342,6 +342,90 @@ class TestExtractReasoningFormats(unittest.TestCase):
self.assertIsNone(result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inline <think> block extraction fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestInlineThinkBlockExtraction(unittest.TestCase):
|
||||
"""Test _build_assistant_message extracts inline <think> blocks as reasoning
|
||||
when no structured API-level reasoning fields are present."""
|
||||
|
||||
def _build_msg(self, content, reasoning=None, reasoning_content=None, reasoning_details=None, tool_calls=None):
|
||||
"""Create a mock API response message."""
|
||||
msg = SimpleNamespace(content=content, tool_calls=tool_calls)
|
||||
if reasoning is not None:
|
||||
msg.reasoning = reasoning
|
||||
if reasoning_content is not None:
|
||||
msg.reasoning_content = reasoning_content
|
||||
if reasoning_details is not None:
|
||||
msg.reasoning_details = reasoning_details
|
||||
return msg
|
||||
|
||||
def _make_agent(self):
|
||||
"""Create a minimal agent with _build_assistant_message."""
|
||||
from run_agent import AIAgent
|
||||
agent = MagicMock(spec=AIAgent)
|
||||
agent._build_assistant_message = AIAgent._build_assistant_message.__get__(agent)
|
||||
agent._extract_reasoning = AIAgent._extract_reasoning.__get__(agent)
|
||||
agent.verbose_logging = False
|
||||
agent.reasoning_callback = None
|
||||
return agent
|
||||
|
||||
def test_single_think_block_extracted(self):
|
||||
agent = self._make_agent()
|
||||
api_msg = self._build_msg("<think>Let me calculate 2+2=4.</think>The answer is 4.")
|
||||
result = agent._build_assistant_message(api_msg, "stop")
|
||||
self.assertEqual(result["reasoning"], "Let me calculate 2+2=4.")
|
||||
|
||||
def test_multiple_think_blocks_extracted(self):
|
||||
agent = self._make_agent()
|
||||
api_msg = self._build_msg("<think>First thought.</think>Some text<think>Second thought.</think>More text")
|
||||
result = agent._build_assistant_message(api_msg, "stop")
|
||||
self.assertIn("First thought.", result["reasoning"])
|
||||
self.assertIn("Second thought.", result["reasoning"])
|
||||
|
||||
def test_no_think_blocks_no_reasoning(self):
|
||||
agent = self._make_agent()
|
||||
api_msg = self._build_msg("Just a plain response.")
|
||||
result = agent._build_assistant_message(api_msg, "stop")
|
||||
# No structured reasoning AND no inline think blocks → None
|
||||
self.assertIsNone(result["reasoning"])
|
||||
|
||||
def test_structured_reasoning_takes_priority(self):
|
||||
"""When structured API reasoning exists, inline think blocks should NOT override."""
|
||||
agent = self._make_agent()
|
||||
api_msg = self._build_msg(
|
||||
"<think>Inline thought.</think>Response text.",
|
||||
reasoning="Structured reasoning from API.",
|
||||
)
|
||||
result = agent._build_assistant_message(api_msg, "stop")
|
||||
self.assertEqual(result["reasoning"], "Structured reasoning from API.")
|
||||
|
||||
def test_empty_think_block_ignored(self):
|
||||
agent = self._make_agent()
|
||||
api_msg = self._build_msg("<think></think>Hello!")
|
||||
result = agent._build_assistant_message(api_msg, "stop")
|
||||
# Empty think block should not produce reasoning
|
||||
self.assertIsNone(result["reasoning"])
|
||||
|
||||
def test_multiline_think_block(self):
|
||||
agent = self._make_agent()
|
||||
api_msg = self._build_msg("<think>\nStep 1: Analyze.\nStep 2: Solve.\n</think>Done.")
|
||||
result = agent._build_assistant_message(api_msg, "stop")
|
||||
self.assertIn("Step 1: Analyze.", result["reasoning"])
|
||||
self.assertIn("Step 2: Solve.", result["reasoning"])
|
||||
|
||||
def test_callback_fires_for_inline_think(self):
|
||||
"""Reasoning callback should fire when reasoning is extracted from inline think blocks."""
|
||||
agent = self._make_agent()
|
||||
captured = []
|
||||
agent.reasoning_callback = lambda t: captured.append(t)
|
||||
api_msg = self._build_msg("<think>Deep analysis here.</think>Answer.")
|
||||
agent._build_assistant_message(api_msg, "stop")
|
||||
self.assertEqual(len(captured), 1)
|
||||
self.assertIn("Deep analysis", captured[0])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config defaults
|
||||
# ---------------------------------------------------------------------------
54
tests/test_redirect_stdout_issue.py
Normal file
@ -0,0 +1,54 @@
"""Verify that redirect_stdout in _run_single_child is process-wide.
|
||||
|
||||
This demonstrates that contextlib.redirect_stdout changes sys.stdout
|
||||
for ALL threads, not just the current one. This means during subagent
|
||||
execution, all output from other threads (including the CLI's process_thread)
|
||||
is swallowed.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import io
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
|
||||
|
||||
class TestRedirectStdoutIsProcessWide(unittest.TestCase):
|
||||
|
||||
def test_redirect_stdout_affects_other_threads(self):
|
||||
"""contextlib.redirect_stdout changes sys.stdout for ALL threads."""
|
||||
captured_from_other_thread = []
|
||||
real_stdout = sys.stdout
|
||||
other_thread_saw_devnull = threading.Event()
|
||||
|
||||
def other_thread_work():
|
||||
"""Runs in a different thread, tries to use sys.stdout."""
|
||||
time.sleep(0.2) # Let redirect_stdout take effect
|
||||
# Check what sys.stdout is
|
||||
if sys.stdout is not real_stdout:
|
||||
other_thread_saw_devnull.set()
|
||||
# Try to print — this should go to devnull
|
||||
captured_from_other_thread.append(sys.stdout)
|
||||
|
||||
t = threading.Thread(target=other_thread_work, daemon=True)
|
||||
t.start()
|
||||
|
||||
# redirect_stdout in main thread
|
||||
devnull = io.StringIO()
|
||||
with contextlib.redirect_stdout(devnull):
|
||||
time.sleep(0.5) # Let the other thread check during redirect
|
||||
|
||||
t.join(timeout=2)
|
||||
|
||||
# The other thread should have seen devnull, NOT the real stdout
|
||||
self.assertTrue(
|
||||
other_thread_saw_devnull.is_set(),
|
||||
"redirect_stdout was NOT process-wide — other thread still saw real stdout. "
|
||||
"This test's premise is wrong."
|
||||
)
|
||||
print("Confirmed: redirect_stdout IS process-wide — affects all threads")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -960,7 +960,7 @@ class TestFlushSentinelNotLeaked:
agent.client.chat.completions.create.return_value = mock_response
|
||||
|
||||
# Bypass auxiliary client so flush uses agent.client directly
|
||||
with patch("agent.auxiliary_client.get_text_auxiliary_client", return_value=(None, None)):
|
||||
with patch("agent.auxiliary_client.call_llm", side_effect=RuntimeError("no provider")):
|
||||
agent.flush_memories(messages, min_turns=0)
|
||||
|
||||
# Check what was actually sent to the API
@ -1466,3 +1466,83 @@ class TestBudgetPressure:
messages[-1]["content"] = last_content + f"\n\n{warning}"
|
||||
assert "plain text result" in messages[-1]["content"]
|
||||
assert "BUDGET WARNING" in messages[-1]["content"]
|
||||
|
||||
|
||||
class TestSafeWriter:
|
||||
"""Verify _SafeWriter guards stdout against OSError (broken pipes)."""
|
||||
|
||||
def test_write_delegates_normally(self):
|
||||
"""When stdout is healthy, _SafeWriter is transparent."""
|
||||
from run_agent import _SafeWriter
|
||||
from io import StringIO
|
||||
inner = StringIO()
|
||||
writer = _SafeWriter(inner)
|
||||
writer.write("hello")
|
||||
assert inner.getvalue() == "hello"
|
||||
|
||||
def test_write_catches_oserror(self):
|
||||
"""OSError on write is silently caught, returns len(data)."""
|
||||
from run_agent import _SafeWriter
|
||||
from unittest.mock import MagicMock
|
||||
inner = MagicMock()
|
||||
inner.write.side_effect = OSError(5, "Input/output error")
|
||||
writer = _SafeWriter(inner)
|
||||
result = writer.write("hello")
|
||||
assert result == 5 # len("hello")
|
||||
|
||||
def test_flush_catches_oserror(self):
|
||||
"""OSError on flush is silently caught."""
|
||||
from run_agent import _SafeWriter
|
||||
from unittest.mock import MagicMock
|
||||
inner = MagicMock()
|
||||
inner.flush.side_effect = OSError(5, "Input/output error")
|
||||
writer = _SafeWriter(inner)
|
||||
writer.flush() # should not raise
|
||||
|
||||
def test_print_survives_broken_stdout(self, monkeypatch):
|
||||
"""print() through _SafeWriter doesn't crash on broken pipe."""
|
||||
import sys
|
||||
from run_agent import _SafeWriter
|
||||
from unittest.mock import MagicMock
|
||||
broken = MagicMock()
|
||||
broken.write.side_effect = OSError(5, "Input/output error")
|
||||
original = sys.stdout
|
||||
sys.stdout = _SafeWriter(broken)
|
||||
try:
|
||||
print("this should not crash") # would raise without _SafeWriter
|
||||
finally:
|
||||
sys.stdout = original
|
||||
|
||||
def test_installed_in_run_conversation(self, agent):
|
||||
"""run_conversation installs _SafeWriter on sys.stdout."""
|
||||
import sys
|
||||
from run_agent import _SafeWriter
|
||||
resp = _mock_response(content="Done", finish_reason="stop")
|
||||
agent.client.chat.completions.create.return_value = resp
|
||||
original = sys.stdout
|
||||
try:
|
||||
with (
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
):
|
||||
agent.run_conversation("test")
|
||||
assert isinstance(sys.stdout, _SafeWriter)
|
||||
finally:
|
||||
sys.stdout = original
|
||||
|
||||
def test_double_wrap_prevented(self):
|
||||
"""Wrapping an already-wrapped stream doesn't add layers."""
|
||||
import sys
|
||||
from run_agent import _SafeWriter
|
||||
from io import StringIO
|
||||
inner = StringIO()
|
||||
wrapped = _SafeWriter(inner)
|
||||
# isinstance check should prevent double-wrapping
|
||||
assert isinstance(wrapped, _SafeWriter)
|
||||
# The guard in run_conversation checks isinstance before wrapping
|
||||
if not isinstance(wrapped, _SafeWriter):
|
||||
wrapped = _SafeWriter(wrapped)
|
||||
# Still just one layer
|
||||
wrapped.write("test")
|
||||
assert inner.getvalue() == "test"
|
||||
|
|
|
|||
|
|
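For readers skimming the diff, a hypothetical sketch of the wrapper these tests describe may help; the real run_agent._SafeWriter can differ in details, but the asserted contract is: delegate write/flush, swallow OSError, and report the full length on a failed write.

```python
# Illustrative sketch only (the repo's _SafeWriter is the source of truth):
# a stdout wrapper that never lets a broken pipe (EPIPE/EIO) crash print().
class SafeWriterSketch:
    def __init__(self, inner):
        self._inner = inner

    def write(self, data):
        try:
            return self._inner.write(data)
        except OSError:
            return len(data)  # claim success so callers don't retry or raise

    def flush(self):
        try:
            self._inner.flush()
        except OSError:
            pass  # a broken pipe on flush is ignored

    def __getattr__(self, name):
        # Everything else (encoding, isatty, ...) is delegated to the wrapped stream.
        return getattr(self._inner, name)
```

The double-wrap test above relies on an isinstance check before wrapping, which a wrapper shaped like this would satisfy the same way.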
@ -158,29 +158,6 @@ def test_custom_endpoint_auto_provider_prefers_openai_key(monkeypatch):
assert resolved["api_key"] == "sk-vllm-key"


def test_resolve_runtime_provider_nous_api(monkeypatch):
"""Nous Portal API key provider resolves via the api_key path."""
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "nous-api")
monkeypatch.setattr(
rp,
"resolve_api_key_provider_credentials",
lambda pid: {
"provider": "nous-api",
"api_key": "nous-test-key",
"base_url": "https://inference-api.nousresearch.com/v1",
"source": "NOUS_API_KEY",
},
)

resolved = rp.resolve_runtime_provider(requested="nous-api")

assert resolved["provider"] == "nous-api"
assert resolved["api_mode"] == "chat_completions"
assert resolved["base_url"] == "https://inference-api.nousresearch.com/v1"
assert resolved["api_key"] == "nous-test-key"
assert resolved["requested_provider"] == "nous-api"


def test_explicit_openrouter_skips_openai_base_url(monkeypatch):
"""When the user explicitly requests openrouter, OPENAI_BASE_URL
(which may point to a custom endpoint) must not override the
@ -1,5 +1,7 @@
"""Tests for the dangerous command approval module."""

from unittest.mock import patch as mock_patch

from tools.approval import (
approve_session,
clear_session,
@ -7,6 +9,7 @@ from tools.approval import (
has_pending,
is_approved,
pop_pending,
prompt_dangerous_approval,
submit_pending,
)
@ -338,3 +341,63 @@ class TestFindExecFullPathRm:
assert dangerous is False
assert key is None


class TestViewFullCommand:
"""Tests for the 'view full command' option in prompt_dangerous_approval."""

def test_view_then_once_fallback(self):
"""Pressing 'v' shows the full command, then 'o' approves once."""
long_cmd = "rm -rf " + "a" * 200
inputs = iter(["v", "o"])
with mock_patch("builtins.input", side_effect=inputs):
result = prompt_dangerous_approval(long_cmd, "recursive delete")
assert result == "once"

def test_view_then_deny_fallback(self):
"""Pressing 'v' shows the full command, then 'd' denies."""
long_cmd = "rm -rf " + "b" * 200
inputs = iter(["v", "d"])
with mock_patch("builtins.input", side_effect=inputs):
result = prompt_dangerous_approval(long_cmd, "recursive delete")
assert result == "deny"

def test_view_then_session_fallback(self):
"""Pressing 'v' shows the full command, then 's' approves for session."""
long_cmd = "rm -rf " + "c" * 200
inputs = iter(["v", "s"])
with mock_patch("builtins.input", side_effect=inputs):
result = prompt_dangerous_approval(long_cmd, "recursive delete")
assert result == "session"

def test_view_then_always_fallback(self):
"""Pressing 'v' shows the full command, then 'a' approves always."""
long_cmd = "rm -rf " + "d" * 200
inputs = iter(["v", "a"])
with mock_patch("builtins.input", side_effect=inputs):
result = prompt_dangerous_approval(long_cmd, "recursive delete")
assert result == "always"

def test_view_not_shown_for_short_command(self):
"""Short commands don't offer the view option; 'v' falls through to deny."""
short_cmd = "rm -rf /tmp"
with mock_patch("builtins.input", return_value="v"):
result = prompt_dangerous_approval(short_cmd, "recursive delete")
# 'v' is not a valid choice for short commands, should deny
assert result == "deny"

def test_once_without_view(self):
"""Directly pressing 'o' without viewing still works."""
long_cmd = "rm -rf " + "e" * 200
with mock_patch("builtins.input", return_value="o"):
result = prompt_dangerous_approval(long_cmd, "recursive delete")
assert result == "once"

def test_view_ignored_after_already_shown(self):
"""After viewing once, 'v' on a now-untruncated display falls through to deny."""
long_cmd = "rm -rf " + "f" * 200
inputs = iter(["v", "v"]) # second 'v' should not match since is_truncated is False
with mock_patch("builtins.input", side_effect=inputs):
result = prompt_dangerous_approval(long_cmd, "recursive delete")
# After first 'v', is_truncated becomes False, so second 'v' -> deny
assert result == "deny"
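Taken together, these tests imply a prompt loop roughly like the sketch below; the option letters and the "view once, then the option disappears" behavior come straight from the assertions, while the exact wording and truncation length are assumptions, not the repo's implementation.

```python
# Hypothetical sketch of prompt_dangerous_approval's control flow (illustrative only).
def prompt_dangerous_approval_sketch(command, reason, display_limit=120):
    is_truncated = len(command) > display_limit
    while True:
        shown = command[:display_limit] + "..." if is_truncated else command
        options = "[o]nce / [s]ession / [a]lways / [d]eny"
        if is_truncated:
            options += " / [v]iew full command"
        choice = input(f"Dangerous command ({reason}): {shown}\n{options} > ").strip().lower()
        if is_truncated and choice == "v":
            print(command)        # show the untruncated command
            is_truncated = False  # 'v' is only offered while the display is truncated
            continue
        if choice == "o":
            return "once"
        if choice == "s":
            return "session"
        if choice == "a":
            return "always"
        return "deny"             # anything unrecognized (including a stray 'v') denies
```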
@ -137,8 +137,7 @@ class TestBrowserVisionAnnotate:

with (
patch("tools.browser_tool._run_browser_command") as mock_cmd,
patch("tools.browser_tool._aux_vision_client") as mock_client,
patch("tools.browser_tool._DEFAULT_VISION_MODEL", "test-model"),
patch("tools.browser_tool.call_llm") as mock_call_llm,
patch("tools.browser_tool._get_vision_model", return_value="test-model"),
):
mock_cmd.return_value = {"success": True, "data": {}}
@ -159,8 +158,7 @@ class TestBrowserVisionAnnotate:

with (
patch("tools.browser_tool._run_browser_command") as mock_cmd,
patch("tools.browser_tool._aux_vision_client") as mock_client,
patch("tools.browser_tool._DEFAULT_VISION_MODEL", "test-model"),
patch("tools.browser_tool.call_llm") as mock_call_llm,
patch("tools.browser_tool._get_vision_model", return_value="test-model"),
):
mock_cmd.return_value = {"success": True, "data": {}}
@ -1,5 +1,6 @@
#!/usr/bin/env python3
"""

Tests for the code execution sandbox (programmatic tool calling).

These tests monkeypatch handle_function_call so they don't require API keys
@ -11,6 +12,10 @@ Run with: python -m pytest tests/test_code_execution.py -v
or: python tests/test_code_execution.py
"""

import pytest
pytestmark = pytest.mark.skip(reason="Hangs in non-interactive environments")


import json
import os
import sys
@ -8,6 +8,11 @@ Every test with output validates against a known-good value AND
asserts zero contamination from shell noise via _assert_clean().
"""

import pytest
pytestmark = pytest.mark.skip(reason="Hangs in non-interactive environments")



import json
import os
import sys
@ -1828,8 +1828,8 @@ class TestSamplingCallbackText:
)

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
"agent.auxiliary_client.call_llm",
return_value=fake_client.chat.completions.create.return_value,
):
params = _make_sampling_params()
result = asyncio.run(self.handler(None, params))
@ -1847,13 +1847,13 @@ class TestSamplingCallbackText:
fake_client.chat.completions.create.return_value = _make_llm_response()

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
):
"agent.auxiliary_client.call_llm",
return_value=fake_client.chat.completions.create.return_value,
) as mock_call:
params = _make_sampling_params(system_prompt="Be helpful")
asyncio.run(self.handler(None, params))

call_args = fake_client.chat.completions.create.call_args
call_args = mock_call.call_args
messages = call_args.kwargs["messages"]
assert messages[0] == {"role": "system", "content": "Be helpful"}
@ -1865,8 +1865,8 @@ class TestSamplingCallbackText:
)

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
"agent.auxiliary_client.call_llm",
return_value=fake_client.chat.completions.create.return_value,
):
params = _make_sampling_params()
result = asyncio.run(self.handler(None, params))
@ -1889,8 +1889,8 @@ class TestSamplingCallbackToolUse:
fake_client.chat.completions.create.return_value = _make_llm_tool_response()

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
"agent.auxiliary_client.call_llm",
return_value=fake_client.chat.completions.create.return_value,
):
params = _make_sampling_params()
result = asyncio.run(self.handler(None, params))
@ -1916,8 +1916,8 @@ class TestSamplingCallbackToolUse:
)

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
"agent.auxiliary_client.call_llm",
return_value=fake_client.chat.completions.create.return_value,
):
result = asyncio.run(self.handler(None, _make_sampling_params()))
@ -1939,8 +1939,8 @@ class TestToolLoopGovernance:
fake_client.chat.completions.create.return_value = _make_llm_tool_response()

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
"agent.auxiliary_client.call_llm",
return_value=fake_client.chat.completions.create.return_value,
):
params = _make_sampling_params()
# Round 1, 2: allowed
@ -1956,24 +1956,26 @@ class TestToolLoopGovernance:
def test_text_response_resets_counter(self):
"""A text response resets the tool loop counter."""
handler = SamplingHandler("tl2", {"max_tool_rounds": 1})
fake_client = MagicMock()

# Use a list to hold the current response, so the side_effect can
# pick up changes between calls.
responses = [_make_llm_tool_response()]

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
"agent.auxiliary_client.call_llm",
side_effect=lambda **kw: responses[0],
):
# Tool response (round 1 of 1 allowed)
fake_client.chat.completions.create.return_value = _make_llm_tool_response()
r1 = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(r1, CreateMessageResultWithTools)

# Text response resets counter
fake_client.chat.completions.create.return_value = _make_llm_response()
responses[0] = _make_llm_response()
r2 = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(r2, CreateMessageResult)

# Tool response again (should succeed since counter was reset)
fake_client.chat.completions.create.return_value = _make_llm_tool_response()
responses[0] = _make_llm_tool_response()
r3 = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(r3, CreateMessageResultWithTools)
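The single-element responses list in the hunk above is a small but handy mocking idiom: because the side_effect closes over the list, the test can change what the next call returns without re-patching. A self-contained sketch with illustrative names (not from the repo):

```python
from types import SimpleNamespace
from unittest.mock import patch

svc = SimpleNamespace(fetch=lambda: "real")  # stand-in for the callable being patched
responses = ["tool-call response"]           # one slot, read on every call

with patch.object(svc, "fetch", side_effect=lambda: responses[0]):
    assert svc.fetch() == "tool-call response"
    responses[0] = "text response"           # swap what the next call returns
    assert svc.fetch() == "text response"
```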
@ -1984,8 +1986,8 @@ class TestToolLoopGovernance:
fake_client.chat.completions.create.return_value = _make_llm_tool_response()

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
"agent.auxiliary_client.call_llm",
return_value=fake_client.chat.completions.create.return_value,
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, ErrorData)
@ -2003,8 +2005,8 @@ class TestSamplingErrors:
fake_client.chat.completions.create.return_value = _make_llm_response()

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
"agent.auxiliary_client.call_llm",
return_value=fake_client.chat.completions.create.return_value,
):
# First call succeeds
r1 = asyncio.run(handler(None, _make_sampling_params()))
@ -2017,20 +2019,16 @@ class TestSamplingErrors:

def test_timeout_error(self):
handler = SamplingHandler("to", {"timeout": 0.05})
fake_client = MagicMock()

def slow_call(**kwargs):
import threading
# Use an event to ensure the thread truly blocks long enough
evt = threading.Event()
evt.wait(5) # blocks for up to 5 seconds (cancelled by timeout)
return _make_llm_response()

fake_client.chat.completions.create.side_effect = slow_call

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
"agent.auxiliary_client.call_llm",
side_effect=slow_call,
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, ErrorData)
@ -2041,12 +2039,11 @@ class TestSamplingErrors:
handler = SamplingHandler("np", {})

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(None, None),
"agent.auxiliary_client.call_llm",
side_effect=RuntimeError("No LLM provider configured"),
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, ErrorData)
assert "No LLM provider" in result.message
assert handler.metrics["errors"] == 1

def test_empty_choices_returns_error(self):
@ -2060,8 +2057,8 @@ class TestSamplingErrors:
)

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
"agent.auxiliary_client.call_llm",
return_value=fake_client.chat.completions.create.return_value,
):
result = asyncio.run(handler(None, _make_sampling_params()))
@ -2080,8 +2077,8 @@ class TestSamplingErrors:
)

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
"agent.auxiliary_client.call_llm",
return_value=fake_client.chat.completions.create.return_value,
):
result = asyncio.run(handler(None, _make_sampling_params()))
@ -2099,8 +2096,8 @@ class TestSamplingErrors:
)

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
"agent.auxiliary_client.call_llm",
return_value=fake_client.chat.completions.create.return_value,
):
result = asyncio.run(handler(None, _make_sampling_params()))
@ -2120,19 +2117,19 @@ class TestModelWhitelist:
fake_client.chat.completions.create.return_value = _make_llm_response()

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "test-model"),
"agent.auxiliary_client.call_llm",
return_value=fake_client.chat.completions.create.return_value,
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, CreateMessageResult)

def test_disallowed_model_rejected(self):
handler = SamplingHandler("wl2", {"allowed_models": ["gpt-4o"]})
handler = SamplingHandler("wl2", {"allowed_models": ["gpt-4o"], "model": "test-model"})
fake_client = MagicMock()

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "gpt-3.5-turbo"),
"agent.auxiliary_client.call_llm",
return_value=fake_client.chat.completions.create.return_value,
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, ErrorData)
@ -2145,8 +2142,8 @@ class TestModelWhitelist:
fake_client.chat.completions.create.return_value = _make_llm_response()

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "any-model"),
"agent.auxiliary_client.call_llm",
return_value=fake_client.chat.completions.create.return_value,
):
result = asyncio.run(handler(None, _make_sampling_params()))
assert isinstance(result, CreateMessageResult)
@ -2166,8 +2163,8 @@ class TestMalformedToolCallArgs:
)

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
"agent.auxiliary_client.call_llm",
return_value=fake_client.chat.completions.create.return_value,
):
result = asyncio.run(handler(None, _make_sampling_params()))
@ -2194,8 +2191,8 @@ class TestMalformedToolCallArgs:
fake_client.chat.completions.create.return_value = response

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
"agent.auxiliary_client.call_llm",
return_value=fake_client.chat.completions.create.return_value,
):
result = asyncio.run(handler(None, _make_sampling_params()))
@ -2214,8 +2211,8 @@ class TestMetricsTracking:
fake_client.chat.completions.create.return_value = _make_llm_response()

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
"agent.auxiliary_client.call_llm",
return_value=fake_client.chat.completions.create.return_value,
):
asyncio.run(handler(None, _make_sampling_params()))
@ -2229,8 +2226,8 @@ class TestMetricsTracking:
fake_client.chat.completions.create.return_value = _make_llm_tool_response()

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(fake_client, "default-model"),
"agent.auxiliary_client.call_llm",
return_value=fake_client.chat.completions.create.return_value,
):
asyncio.run(handler(None, _make_sampling_params()))
@ -2241,8 +2238,8 @@ class TestMetricsTracking:
handler = SamplingHandler("met3", {})

with patch(
"agent.auxiliary_client.get_text_auxiliary_client",
return_value=(None, None),
"agent.auxiliary_client.call_llm",
side_effect=RuntimeError("No LLM provider configured"),
):
asyncio.run(handler(None, _make_sampling_params()))
@ -189,16 +189,14 @@ class TestSessionSearch:
{"role": "assistant", "content": "hi there"},
]

# Mock the summarizer to return a simple summary
import tools.session_search_tool as sst
original_client = sst._async_aux_client
sst._async_aux_client = None # Disable summarizer → returns None

result = json.loads(session_search(
query="test", db=mock_db, current_session_id=current_sid,
))

sst._async_aux_client = original_client
# Mock async_call_llm to raise RuntimeError → summarizer returns None
from unittest.mock import AsyncMock, patch as _patch
with _patch("tools.session_search_tool.async_call_llm",
new_callable=AsyncMock,
side_effect=RuntimeError("no provider")):
result = json.loads(session_search(
query="test", db=mock_db, current_session_id=current_sid,
))

assert result["success"] is True
# Current session should be skipped, only other_sid should appear
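The replacement above uses new_callable=AsyncMock with a side_effect so the awaited summarizer raises instead of reaching a real provider. A self-contained sketch of that pattern with illustrative names (not the repo's API):

```python
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch

svc = SimpleNamespace(summarize=None)  # stand-in for the module attribute being patched

async def run():
    with patch.object(svc, "summarize", new_callable=AsyncMock,
                      side_effect=RuntimeError("no provider")):
        try:
            return await svc.summarize("query")
        except RuntimeError:
            return None  # caller treats a failed summarizer as "no summary"

assert asyncio.run(run()) is None
```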
@ -202,7 +202,7 @@ class TestHandleVisionAnalyze:
assert model == "custom/model-v1"

def test_falls_back_to_default_model(self):
"""Without AUXILIARY_VISION_MODEL, should use DEFAULT_VISION_MODEL or fallback."""
"""Without AUXILIARY_VISION_MODEL, model should be None (let call_llm resolve default)."""
with (
patch(
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
@ -218,9 +218,9 @@ class TestHandleVisionAnalyze:
coro.close()
call_args = mock_tool.call_args
model = call_args[0][2]
# Should be DEFAULT_VISION_MODEL or the hardcoded fallback
assert model is not None
assert len(model) > 0
# With no AUXILIARY_VISION_MODEL set, model should be None
# (the centralized call_llm router picks the default)
assert model is None

def test_empty_args_graceful(self):
"""Missing keys should default to empty strings, not raise."""
@ -277,8 +277,6 @@ class TestErrorLoggingExcInfo:
new_callable=AsyncMock,
side_effect=Exception("download boom"),
),
patch("tools.vision_tools._aux_async_client", MagicMock()),
patch("tools.vision_tools.DEFAULT_VISION_MODEL", "test/model"),
caplog.at_level(logging.ERROR, logger="tools.vision_tools"),
):
result = await vision_analyze_tool(
@ -311,25 +309,16 @@ class TestErrorLoggingExcInfo:
"tools.vision_tools._image_to_base64_data_url",
return_value="data:image/jpeg;base64,abc",
),
patch("agent.auxiliary_client.get_auxiliary_extra_body", return_value=None),
patch(
"agent.auxiliary_client.auxiliary_max_tokens_param",
return_value={"max_tokens": 2000},
),
caplog.at_level(logging.WARNING, logger="tools.vision_tools"),
):
# Mock the vision client
mock_client = AsyncMock()
# Mock the async_call_llm function to return a mock response
mock_response = MagicMock()
mock_choice = MagicMock()
mock_choice.message.content = "A test image description"
mock_response.choices = [mock_choice]
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)

# Patch module-level _aux_async_client so the tool doesn't bail early
with (
patch("tools.vision_tools._aux_async_client", mock_client),
patch("tools.vision_tools.DEFAULT_VISION_MODEL", "test/model"),
patch("tools.vision_tools.async_call_llm", new_callable=AsyncMock, return_value=mock_response),
):
# Make unlink fail to trigger cleanup warning
original_unlink = Path.unlink