The architecture has been updated
This commit is contained in:
parent
805f7a017e
commit
a01257ead9
1119 changed files with 226 additions and 352 deletions
0
hermes_code/tests/acp/__init__.py
Normal file
0
hermes_code/tests/acp/__init__.py
Normal file
56
hermes_code/tests/acp/test_auth.py
Normal file
56
hermes_code/tests/acp/test_auth.py
Normal file
|
|
@@ -0,0 +1,56 @@
|
|||
"""Tests for acp_adapter.auth — provider detection."""
|
||||
|
||||
from acp_adapter.auth import has_provider, detect_provider
|
||||
|
||||
|
||||
class TestHasProvider:
    """Checks for ``has_provider`` with the runtime-provider resolver stubbed out."""

    # Dotted path of the resolver that every test replaces via monkeypatch.
    _RESOLVER_PATH = "hermes_cli.runtime_provider.resolve_runtime_provider"

    def test_has_provider_with_resolved_runtime(self, monkeypatch):
        """A resolved runtime carrying a non-empty API key yields ``True``."""
        def _resolved():
            return {"provider": "openrouter", "api_key": "sk-or-test"}

        monkeypatch.setattr(self._RESOLVER_PATH, _resolved)

        outcome = has_provider()
        assert outcome is True

    def test_has_no_provider_when_runtime_has_no_key(self, monkeypatch):
        """A resolved runtime with an empty API key yields ``False``."""
        def _keyless():
            return {"provider": "openrouter", "api_key": ""}

        monkeypatch.setattr(self._RESOLVER_PATH, _keyless)

        outcome = has_provider()
        assert outcome is False

    def test_has_no_provider_when_runtime_resolution_fails(self, monkeypatch):
        """A resolver that raises ``RuntimeError`` yields ``False``."""
        def _boom():
            raise RuntimeError("no provider")

        monkeypatch.setattr(self._RESOLVER_PATH, _boom)

        outcome = has_provider()
        assert outcome is False
|
||||
|
||||
|
||||
class TestDetectProvider:
    """Checks for ``detect_provider`` with the runtime-provider resolver stubbed out."""

    # Dotted path of the resolver that every test replaces via monkeypatch.
    _RESOLVER_PATH = "hermes_cli.runtime_provider.resolve_runtime_provider"

    def test_detect_openrouter(self, monkeypatch):
        """An openrouter runtime with a key is reported as ``"openrouter"``."""
        def _openrouter():
            return {"provider": "openrouter", "api_key": "sk-or-test"}

        monkeypatch.setattr(self._RESOLVER_PATH, _openrouter)

        detected = detect_provider()
        assert detected == "openrouter"

    def test_detect_anthropic(self, monkeypatch):
        """An anthropic runtime with a key is reported as ``"anthropic"``."""
        def _anthropic():
            return {"provider": "anthropic", "api_key": "sk-ant-test"}

        monkeypatch.setattr(self._RESOLVER_PATH, _anthropic)

        detected = detect_provider()
        assert detected == "anthropic"

    def test_detect_none_when_no_key(self, monkeypatch):
        """A runtime whose API key is empty is reported as ``None``."""
        def _keyless():
            return {"provider": "kimi-coding", "api_key": ""}

        monkeypatch.setattr(self._RESOLVER_PATH, _keyless)

        detected = detect_provider()
        assert detected is None

    def test_detect_none_on_resolution_error(self, monkeypatch):
        """A resolver that raises ``RuntimeError`` is reported as ``None``."""
        def _boom():
            raise RuntimeError("broken")

        monkeypatch.setattr(self._RESOLVER_PATH, _boom)

        detected = detect_provider()
        assert detected is None
|
||||
239
hermes_code/tests/acp/test_events.py
Normal file
239
hermes_code/tests/acp/test_events.py
Normal file
|
|
@@ -0,0 +1,239 @@
|
|||
"""Tests for acp_adapter.events — callback factories for ACP notifications."""
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import Future
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import acp
|
||||
from acp.schema import ToolCallStart, ToolCallProgress, AgentThoughtChunk, AgentMessageChunk
|
||||
|
||||
from acp_adapter.events import (
|
||||
make_message_cb,
|
||||
make_step_cb,
|
||||
make_thinking_cb,
|
||||
make_tool_progress_cb,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
def mock_conn():
    """Provide a ``MagicMock`` spec'd on ``acp.Client`` with an ``AsyncMock`` ``session_update``."""
    client = MagicMock(spec=acp.Client)
    # Replace the spec'd attribute with an AsyncMock so calls to it can be recorded and awaited.
    setattr(client, "session_update", AsyncMock())
    return client
|
||||
|
||||
|
||||
@pytest.fixture()
def event_loop_fixture():
    """Yield a newly created asyncio event loop and close it once the test is done."""
    fresh_loop = asyncio.new_event_loop()
    yield fresh_loop
    # Teardown: everything after the yield runs when the fixture is finalized.
    fresh_loop.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool progress callback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToolProgressCallback:
    """Exercise ``make_tool_progress_cb`` with coroutine submission patched out."""

    # Patch target: run_coroutine_threadsafe as looked up through acp_adapter.events.
    _SUBMIT = "acp_adapter.events.asyncio.run_coroutine_threadsafe"

    @staticmethod
    def _finished_future():
        """Return a ``Future``-spec'd mock whose ``result()`` gives ``None``."""
        stub = MagicMock(spec=Future)
        stub.result.return_value = None
        return stub

    def test_emits_tool_call_start(self, mock_conn, event_loop_fixture):
        """A progress event tracks the tool name and submits exactly one coroutine."""
        tracked = {}
        on_progress = make_tool_progress_cb(mock_conn, "session-1", event_loop_fixture, tracked)

        with patch(self._SUBMIT) as submit:
            submit.return_value = self._finished_future()
            on_progress("terminal", "$ ls -la", {"command": "ls -la"})

        # The tool name must now be present in the shared tracking dict.
        assert "terminal" in tracked

        # One submission happened; either session_update was invoked or the
        # first positional argument handed to the submitter is not None.
        submit.assert_called_once()
        submitted = submit.call_args[0][0]
        assert mock_conn.session_update.called or submitted is not None

    def test_handles_string_args(self, mock_conn, event_loop_fixture):
        """A JSON-encoded string passed as args is accepted and the tool is tracked."""
        tracked = {}
        on_progress = make_tool_progress_cb(mock_conn, "session-1", event_loop_fixture, tracked)

        with patch(self._SUBMIT) as submit:
            submit.return_value = self._finished_future()
            on_progress("read_file", "Reading /etc/hosts", '{"path": "/etc/hosts"}')

        assert "read_file" in tracked

    def test_handles_non_dict_args(self, mock_conn, event_loop_fixture):
        """``None`` passed as args is accepted and the tool is tracked."""
        tracked = {}
        on_progress = make_tool_progress_cb(mock_conn, "session-1", event_loop_fixture, tracked)

        with patch(self._SUBMIT) as submit:
            submit.return_value = self._finished_future()
            on_progress("terminal", "$ echo hi", None)

        assert "terminal" in tracked

    def test_duplicate_same_name_tool_calls_use_fifo_ids(self, mock_conn, event_loop_fixture):
        """Two same-name calls are tracked separately and retired one per step."""
        tracked = {}
        on_progress = make_tool_progress_cb(mock_conn, "session-1", event_loop_fixture, tracked)
        on_step = make_step_cb(mock_conn, "session-1", event_loop_fixture, tracked)

        with patch(self._SUBMIT) as submit:
            submit.return_value = self._finished_future()

            # Two starts under the same tool name -> two tracked entries.
            on_progress("terminal", "$ ls", {"command": "ls"})
            on_progress("terminal", "$ pwd", {"command": "pwd"})
            assert len(tracked["terminal"]) == 2

            # The first completion removes one entry.
            on_step(1, [{"name": "terminal", "result": "ok-1"}])
            assert len(tracked["terminal"]) == 1

            # The second completion removes the key altogether.
            on_step(2, [{"name": "terminal", "result": "ok-2"}])
            assert "terminal" not in tracked
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thinking callback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestThinkingCallback:
    """Exercise ``make_thinking_cb`` with coroutine submission patched out."""

    # Patch target: run_coroutine_threadsafe as looked up through acp_adapter.events.
    _SUBMIT = "acp_adapter.events.asyncio.run_coroutine_threadsafe"

    def test_emits_thought_chunk(self, mock_conn, event_loop_fixture):
        """Non-empty thinking text results in exactly one submission."""
        on_thought = make_thinking_cb(mock_conn, "session-1", event_loop_fixture)

        finished = MagicMock(spec=Future)
        finished.result.return_value = None

        with patch(self._SUBMIT) as submit:
            submit.return_value = finished
            on_thought("Analyzing the code...")

        submit.assert_called_once()

    def test_ignores_empty_text(self, mock_conn, event_loop_fixture):
        """An empty string results in no submission at all."""
        on_thought = make_thinking_cb(mock_conn, "session-1", event_loop_fixture)

        with patch(self._SUBMIT) as submit:
            on_thought("")

        submit.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step callback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStepCallback:
    """Exercise ``make_step_cb`` with coroutine submission patched out."""

    # Patch target: run_coroutine_threadsafe as looked up through acp_adapter.events.
    _SUBMIT = "acp_adapter.events.asyncio.run_coroutine_threadsafe"

    @staticmethod
    def _finished_future():
        """Return a ``Future``-spec'd mock whose ``result()`` gives ``None``."""
        stub = MagicMock(spec=Future)
        stub.result.return_value = None
        return stub

    def test_completes_tracked_tool_calls(self, mock_conn, event_loop_fixture):
        """A step naming a tracked tool untracks it and submits once."""
        tracked = {"terminal": "tc-abc123"}
        on_step = make_step_cb(mock_conn, "session-1", event_loop_fixture, tracked)

        with patch(self._SUBMIT) as submit:
            submit.return_value = self._finished_future()
            on_step(1, [{"name": "terminal", "result": "success"}])

        # The tool is gone from the tracking dict after the step.
        assert "terminal" not in tracked
        submit.assert_called_once()

    def test_ignores_untracked_tools(self, mock_conn, event_loop_fixture):
        """A step naming a tool that is not tracked submits nothing."""
        tracked = {}
        on_step = make_step_cb(mock_conn, "session-1", event_loop_fixture, tracked)

        with patch(self._SUBMIT) as submit:
            on_step(1, [{"name": "unknown_tool", "result": "ok"}])

        submit.assert_not_called()

    def test_handles_string_tool_info(self, mock_conn, event_loop_fixture):
        """A bare tool-name string in the step list is handled like a dict entry."""
        tracked = {"read_file": "tc-def456"}
        on_step = make_step_cb(mock_conn, "session-1", event_loop_fixture, tracked)

        with patch(self._SUBMIT) as submit:
            submit.return_value = self._finished_future()
            on_step(2, ["read_file"])

        assert "read_file" not in tracked
        submit.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message callback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageCallback:
    """Exercise ``make_message_cb`` with coroutine submission patched out."""

    # Patch target: run_coroutine_threadsafe as looked up through acp_adapter.events.
    _SUBMIT = "acp_adapter.events.asyncio.run_coroutine_threadsafe"

    def test_emits_agent_message_chunk(self, mock_conn, event_loop_fixture):
        """Non-empty message text results in exactly one submission."""
        on_message = make_message_cb(mock_conn, "session-1", event_loop_fixture)

        finished = MagicMock(spec=Future)
        finished.result.return_value = None

        with patch(self._SUBMIT) as submit:
            submit.return_value = finished
            on_message("Here is your answer.")

        submit.assert_called_once()

    def test_ignores_empty_message(self, mock_conn, event_loop_fixture):
        """An empty string results in no submission at all."""
        on_message = make_message_cb(mock_conn, "session-1", event_loop_fixture)

        with patch(self._SUBMIT) as submit:
            on_message("")

        submit.assert_not_called()
|
||||
75
hermes_code/tests/acp/test_permissions.py
Normal file
75
hermes_code/tests/acp/test_permissions.py
Normal file
|
|
@@ -0,0 +1,75 @@
|
|||
"""Tests for acp_adapter.permissions — ACP approval bridging."""
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import Future
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from acp.schema import (
|
||||
AllowedOutcome,
|
||||
DeniedOutcome,
|
||||
RequestPermissionResponse,
|
||||
)
|
||||
from acp_adapter.permissions import make_approval_callback
|
||||
|
||||
|
||||
def _make_response(outcome):
    """Wrap *outcome* in a ``RequestPermissionResponse`` and return it."""
    wrapped = RequestPermissionResponse(outcome=outcome)
    return wrapped
|
||||
|
||||
|
||||
def _setup_callback(outcome, timeout=60.0):
    """Run an approval callback whose permission request resolves to *outcome*.

    ``asyncio.run_coroutine_threadsafe`` (as referenced from
    ``acp_adapter.permissions``) is patched to hand back a mock future whose
    ``result()`` is a ``RequestPermissionResponse`` wrapping *outcome*.  The
    callback built by ``make_approval_callback`` is then invoked once with a
    fixed command and description.

    Args:
        outcome: Outcome object to wrap in the mocked permission response.
        timeout: Forwarded unchanged to ``make_approval_callback``.

    Returns:
        The value returned by the approval callback itself (not the callback
        and not the mocked ``request_permission`` function).
    """
    loop = MagicMock(spec=asyncio.AbstractEventLoop)
    mock_rp = MagicMock(name="request_permission")

    response = _make_response(outcome)

    # The patched submitter returns a future that immediately yields the
    # prepared response, so no real event loop is needed.
    future = MagicMock(spec=Future)
    future.result.return_value = response

    with patch("acp_adapter.permissions.asyncio.run_coroutine_threadsafe", return_value=future):
        cb = make_approval_callback(mock_rp, loop, session_id="s1", timeout=timeout)
        result = cb("rm -rf /", "dangerous command")

    return result
|
||||
|
||||
|
||||
class TestApprovalMapping:
    """Check which string the approval callback returns for each permission outcome."""

    def test_approval_allow_once_maps_correctly(self):
        """An allowed outcome with option ``allow_once`` comes back as ``"once"``."""
        selected = AllowedOutcome(option_id="allow_once", outcome="selected")
        assert _setup_callback(selected) == "once"

    def test_approval_allow_always_maps_correctly(self):
        """An allowed outcome with option ``allow_always`` comes back as ``"always"``."""
        selected = AllowedOutcome(option_id="allow_always", outcome="selected")
        assert _setup_callback(selected) == "always"

    def test_approval_deny_maps_correctly(self):
        """A denied (cancelled) outcome comes back as ``"deny"``."""
        cancelled = DeniedOutcome(outcome="cancelled")
        assert _setup_callback(cancelled) == "deny"

    def test_approval_timeout_returns_deny(self):
        """A future whose ``result()`` raises ``TimeoutError`` comes back as ``"deny"``."""
        fake_loop = MagicMock(spec=asyncio.AbstractEventLoop)
        request_permission = MagicMock(name="request_permission")

        # The mocked future raises instead of returning a response.
        timed_out = MagicMock(spec=Future)
        timed_out.result.side_effect = TimeoutError("timed out")

        with patch("acp_adapter.permissions.asyncio.run_coroutine_threadsafe", return_value=timed_out):
            approve = make_approval_callback(request_permission, fake_loop, session_id="s1", timeout=0.01)
            verdict = approve("rm -rf /", "dangerous")

        assert verdict == "deny"
|
||||
436
hermes_code/tests/acp/test_server.py
Normal file
436
hermes_code/tests/acp/test_server.py
Normal file
|
|
@@ -0,0 +1,436 @@
|
|||
"""Tests for acp_adapter.server — HermesACPAgent ACP server."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import acp
|
||||
from acp.schema import (
|
||||
AgentCapabilities,
|
||||
AuthenticateResponse,
|
||||
Implementation,
|
||||
InitializeResponse,
|
||||
ListSessionsResponse,
|
||||
LoadSessionResponse,
|
||||
NewSessionResponse,
|
||||
PromptResponse,
|
||||
ResumeSessionResponse,
|
||||
SessionInfo,
|
||||
TextContentBlock,
|
||||
Usage,
|
||||
)
|
||||
from acp_adapter.server import HermesACPAgent, HERMES_VERSION
|
||||
from acp_adapter.session import SessionManager
|
||||
from hermes_state import SessionDB
|
||||
|
||||
|
||||
@pytest.fixture()
def mock_manager():
    """Provide a ``SessionManager`` whose agent factory returns a named ``MagicMock``."""
    def _fake_agent():
        return MagicMock(name="MockAIAgent")

    return SessionManager(agent_factory=_fake_agent)
|
||||
|
||||
|
||||
@pytest.fixture()
def agent(mock_manager):
    """Provide a ``HermesACPAgent`` built on the ``mock_manager`` fixture."""
    acp_agent = HermesACPAgent(session_manager=mock_manager)
    return acp_agent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# initialize
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInitialize:
    """Checks on the response returned by ``HermesACPAgent.initialize``."""

    @pytest.mark.asyncio
    async def test_initialize_returns_correct_protocol_version(self, agent):
        """The response is an ``InitializeResponse`` reporting ``acp.PROTOCOL_VERSION``."""
        reply = await agent.initialize(protocol_version=1)

        assert isinstance(reply, InitializeResponse)
        assert reply.protocol_version == acp.PROTOCOL_VERSION

    @pytest.mark.asyncio
    async def test_initialize_returns_agent_info(self, agent):
        """The response carries an ``Implementation`` naming hermes-agent and its version."""
        reply = await agent.initialize(protocol_version=1)
        info = reply.agent_info

        assert info is not None
        assert isinstance(info, Implementation)
        assert info.name == "hermes-agent"
        assert info.version == HERMES_VERSION

    @pytest.mark.asyncio
    async def test_initialize_returns_capabilities(self, agent):
        """The response advertises session capabilities with ``fork`` and ``list`` set."""
        reply = await agent.initialize(protocol_version=1)
        capabilities = reply.agent_capabilities

        assert isinstance(capabilities, AgentCapabilities)
        session_caps = capabilities.session_capabilities
        assert session_caps is not None
        assert session_caps.fork is not None
        assert session_caps.list is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# authenticate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAuthenticate:
    """Checks on ``HermesACPAgent.authenticate`` with ``has_provider`` stubbed."""

    # Dotted path of the has_provider name as bound inside the server module.
    _HAS_PROVIDER = "acp_adapter.server.has_provider"

    @pytest.mark.asyncio
    async def test_authenticate_with_provider_configured(self, agent, monkeypatch):
        """With ``has_provider`` returning True, an ``AuthenticateResponse`` comes back."""
        def _configured():
            return True

        monkeypatch.setattr(self._HAS_PROVIDER, _configured)

        reply = await agent.authenticate(method_id="openrouter")
        assert isinstance(reply, AuthenticateResponse)

    @pytest.mark.asyncio
    async def test_authenticate_without_provider(self, agent, monkeypatch):
        """With ``has_provider`` returning False, ``None`` comes back."""
        def _unconfigured():
            return False

        monkeypatch.setattr(self._HAS_PROVIDER, _unconfigured)

        reply = await agent.authenticate(method_id="openrouter")
        assert reply is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# new_session / cancel / load / resume
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSessionOps:
    """Checks on new / cancel / load / resume session operations of the agent."""

    @pytest.mark.asyncio
    async def test_new_session_creates_session(self, agent):
        """A new session gets an id and is retrievable with the requested cwd."""
        created = await agent.new_session(cwd="/home/user/project")

        assert isinstance(created, NewSessionResponse)
        assert created.session_id

        # The manager must know the session and remember its working directory.
        stored = agent.session_manager.get_session(created.session_id)
        assert stored is not None
        assert stored.cwd == "/home/user/project"

    @pytest.mark.asyncio
    async def test_cancel_sets_event(self, agent):
        """Cancelling a session flips its ``cancel_event`` from unset to set."""
        created = await agent.new_session(cwd=".")
        stored = agent.session_manager.get_session(created.session_id)
        assert not stored.cancel_event.is_set()

        await agent.cancel(session_id=created.session_id)
        assert stored.cancel_event.is_set()

    @pytest.mark.asyncio
    async def test_cancel_nonexistent_session_is_noop(self, agent):
        """Cancelling an unknown session id completes without raising."""
        await agent.cancel(session_id="does-not-exist")

    @pytest.mark.asyncio
    async def test_load_session_returns_response(self, agent):
        """Loading an existing session returns a ``LoadSessionResponse``."""
        created = await agent.new_session(cwd="/tmp")

        loaded = await agent.load_session(cwd="/tmp", session_id=created.session_id)
        assert isinstance(loaded, LoadSessionResponse)

    @pytest.mark.asyncio
    async def test_load_session_not_found_returns_none(self, agent):
        """Loading an unknown session id returns ``None``."""
        loaded = await agent.load_session(cwd="/tmp", session_id="bogus")
        assert loaded is None

    @pytest.mark.asyncio
    async def test_resume_session_returns_response(self, agent):
        """Resuming an existing session returns a ``ResumeSessionResponse``."""
        created = await agent.new_session(cwd="/tmp")

        resumed = await agent.resume_session(cwd="/tmp", session_id=created.session_id)
        assert isinstance(resumed, ResumeSessionResponse)

    @pytest.mark.asyncio
    async def test_resume_session_creates_new_if_missing(self, agent):
        """Resuming an unknown session id still returns a ``ResumeSessionResponse``."""
        resumed = await agent.resume_session(cwd="/tmp", session_id="nonexistent")
        assert isinstance(resumed, ResumeSessionResponse)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list / fork
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListAndFork:
    """Checks on listing and forking sessions through the agent."""

    @pytest.mark.asyncio
    async def test_list_sessions(self, agent):
        """Two created sessions show up as two entries in the listing."""
        for working_dir in ("/a", "/b"):
            await agent.new_session(cwd=working_dir)

        listing = await agent.list_sessions()

        assert isinstance(listing, ListSessionsResponse)
        assert len(listing.sessions) == 2

    @pytest.mark.asyncio
    async def test_fork_session(self, agent):
        """A fork gets a non-empty session id that differs from its source."""
        original = await agent.new_session(cwd="/original")
        forked = await agent.fork_session(cwd="/forked", session_id=original.session_id)

        assert forked.session_id
        assert forked.session_id != original.session_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPrompt:
    """Checks on ``HermesACPAgent.prompt`` with the underlying agent mocked."""

    @staticmethod
    def _attach_client(agent):
        """Install a mock ACP client with an ``AsyncMock`` ``session_update`` on *agent*."""
        client = MagicMock(spec=acp.Client)
        client.session_update = AsyncMock()
        agent._conn = client
        return client

    @pytest.mark.asyncio
    async def test_prompt_returns_refusal_for_unknown_session(self, agent):
        """Prompting an unknown session id stops with ``"refusal"``."""
        blocks = [TextContentBlock(type="text", text="hello")]

        reply = await agent.prompt(prompt=blocks, session_id="nonexistent")

        assert isinstance(reply, PromptResponse)
        assert reply.stop_reason == "refusal"

    @pytest.mark.asyncio
    async def test_prompt_returns_end_turn_for_empty_message(self, agent):
        """A whitespace-only prompt stops with ``"end_turn"``."""
        created = await agent.new_session(cwd=".")
        blocks = [TextContentBlock(type="text", text="   ")]

        reply = await agent.prompt(prompt=blocks, session_id=created.session_id)

        assert reply.stop_reason == "end_turn"

    @pytest.mark.asyncio
    async def test_prompt_runs_agent(self, agent):
        """A regular prompt calls ``run_conversation`` once and stops with ``"end_turn"``."""
        created = await agent.new_session(cwd=".")
        state = agent.session_manager.get_session(created.session_id)

        # Canned conversation result returned by the mocked agent.
        transcript = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "Hello! How can I help?"},
        ]
        state.agent.run_conversation = MagicMock(
            return_value={
                "final_response": "Hello! How can I help?",
                "messages": transcript,
            }
        )
        self._attach_client(agent)

        reply = await agent.prompt(
            prompt=[TextContentBlock(type="text", text="hello")],
            session_id=created.session_id,
        )

        assert isinstance(reply, PromptResponse)
        assert reply.stop_reason == "end_turn"
        state.agent.run_conversation.assert_called_once()

    @pytest.mark.asyncio
    async def test_prompt_updates_history(self, agent):
        """After a prompt, the session history equals the returned message list."""
        created = await agent.new_session(cwd=".")
        state = agent.session_manager.get_session(created.session_id)

        expected_history = [
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": "hey"},
        ]
        state.agent.run_conversation = MagicMock(
            return_value={"final_response": "hey", "messages": expected_history}
        )
        self._attach_client(agent)

        await agent.prompt(
            prompt=[TextContentBlock(type="text", text="hi")],
            session_id=created.session_id,
        )

        assert state.history == expected_history

    @pytest.mark.asyncio
    async def test_prompt_sends_final_message_update(self, agent):
        """The last ``session_update`` call carries an ``agent_message_chunk`` update."""
        created = await agent.new_session(cwd=".")
        state = agent.session_manager.get_session(created.session_id)

        state.agent.run_conversation = MagicMock(
            return_value={"final_response": "I can help with that!", "messages": []}
        )
        client = self._attach_client(agent)

        await agent.prompt(
            prompt=[TextContentBlock(type="text", text="help me")],
            session_id=created.session_id,
        )

        client.session_update.assert_called()
        # Accept the update either as a keyword or as the second positional argument.
        final_call = client.session_update.call_args_list[-1]
        update = final_call.kwargs.get("update") or final_call.args[1]
        assert update.session_update == "agent_message_chunk"

    @pytest.mark.asyncio
    async def test_prompt_cancelled_returns_cancelled_stop_reason(self, agent):
        """If the cancel event is set while the agent runs, the stop reason is ``"cancelled"``."""
        created = await agent.new_session(cwd=".")
        state = agent.session_manager.get_session(created.session_id)

        def _run_and_cancel(*args, **kwargs):
            # Simulate a cancel arriving in the middle of the conversation.
            state.cancel_event.set()
            return {"final_response": "interrupted", "messages": []}

        state.agent.run_conversation = _run_and_cancel
        self._attach_client(agent)

        reply = await agent.prompt(
            prompt=[TextContentBlock(type="text", text="do something")],
            session_id=created.session_id,
        )

        assert reply.stop_reason == "cancelled"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# on_connect
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOnConnect:
    """Checks on ``HermesACPAgent.on_connect``."""

    def test_on_connect_stores_client(self, agent):
        """The client passed to ``on_connect`` is kept as ``agent._conn``."""
        client = MagicMock(spec=acp.Client)

        agent.on_connect(client)

        assert agent._conn is client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slash commands
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSlashCommands:
    """Checks on slash-command handling in the ACP adapter."""

    def _make_state(self, mock_manager):
        """Create a session under /tmp whose agent and state report ``test-model``."""
        session = mock_manager.create_session(cwd="/tmp")
        session.agent.model = "test-model"
        session.agent.provider = "openrouter"
        session.model = "test-model"
        return session

    def test_help_lists_commands(self, agent, mock_manager):
        """``/help`` output mentions the help, model, tools and reset commands."""
        session = self._make_state(mock_manager)

        output = agent._handle_slash_command("/help", session)

        assert output is not None
        assert "/help" in output
        assert "/model" in output
        assert "/tools" in output
        assert "/reset" in output

    def test_model_shows_current(self, agent, mock_manager):
        """``/model`` without arguments mentions the current model name."""
        session = self._make_state(mock_manager)

        output = agent._handle_slash_command("/model", session)

        assert "test-model" in output

    def test_context_empty(self, agent, mock_manager):
        """``/context`` on an empty history says so."""
        session = self._make_state(mock_manager)
        session.history = []

        output = agent._handle_slash_command("/context", session)

        assert "empty" in output.lower()

    def test_context_with_messages(self, agent, mock_manager):
        """``/context`` reports the total message count and the per-role count."""
        session = self._make_state(mock_manager)
        session.history = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi"},
        ]

        output = agent._handle_slash_command("/context", session)

        assert "2 messages" in output
        assert "user: 1" in output

    def test_reset_clears_history(self, agent, mock_manager):
        """``/reset`` reports clearing and leaves the history empty."""
        session = self._make_state(mock_manager)
        session.history = [{"role": "user", "content": "hello"}]

        output = agent._handle_slash_command("/reset", session)

        assert "cleared" in output.lower()
        assert len(session.history) == 0

    def test_version(self, agent, mock_manager):
        """``/version`` output contains ``HERMES_VERSION``."""
        session = self._make_state(mock_manager)

        output = agent._handle_slash_command("/version", session)

        assert HERMES_VERSION in output

    def test_unknown_command_returns_none(self, agent, mock_manager):
        """An unrecognised slash command yields ``None``."""
        session = self._make_state(mock_manager)

        output = agent._handle_slash_command("/nonexistent", session)

        assert output is None

    @pytest.mark.asyncio
    async def test_slash_command_intercepted_in_prompt(self, agent, mock_manager):
        """A ``/help`` prompt ends the turn and triggers exactly one ``session_update``."""
        created = await agent.new_session(cwd="/tmp")
        client = AsyncMock(spec=acp.Client)
        agent._conn = client

        reply = await agent.prompt(
            prompt=[TextContentBlock(type="text", text="/help")],
            session_id=created.session_id,
        )

        assert reply.stop_reason == "end_turn"
        client.session_update.assert_called_once()

    @pytest.mark.asyncio
    async def test_unknown_slash_falls_through_to_llm(self, agent, mock_manager):
        """An unknown ``/command`` prompt still ends the turn with the executor mocked."""
        created = await agent.new_session(cwd="/tmp")
        client = AsyncMock(spec=acp.Client)
        agent._conn = client

        canned_result = {
            "final_response": "I processed /foo",
            "messages": [],
        }

        # Replace run_in_executor on the loop returned by get_running_loop so
        # the real agent is never run.
        with patch("asyncio.get_running_loop") as fake_get_loop:
            fake_get_loop.return_value.run_in_executor = AsyncMock(return_value=canned_result)
            blocks = [TextContentBlock(type="text", text="/foo bar")]
            reply = await agent.prompt(prompt=blocks, session_id=created.session_id)

        assert reply.stop_reason == "end_turn"

    def test_model_switch_uses_requested_provider(self, tmp_path, monkeypatch):
        """``_cmd_model`` with ``provider:model`` puts the session agent on that provider."""
        runtime_calls = []

        def fake_resolve_runtime_provider(requested=None, **kwargs):
            # Record every requested provider so the last one can be asserted on.
            runtime_calls.append(requested)
            provider = requested or "openrouter"
            if provider == "anthropic":
                api_mode = "anthropic_messages"
            else:
                api_mode = "chat_completions"
            return {
                "provider": provider,
                "api_mode": api_mode,
                "base_url": f"https://{provider}.example/v1",
                "api_key": f"{provider}-key",
                "command": None,
                "args": [],
            }

        def fake_agent(**kwargs):
            # Mirror only the constructor arguments the assertions look at.
            mirrored = ("model", "provider", "base_url", "api_mode")
            return SimpleNamespace(**{field: kwargs.get(field) for field in mirrored})

        def fake_load_config():
            return {"model": {"provider": "openrouter", "default": "openrouter/gpt-5"}}

        monkeypatch.setattr("hermes_cli.config.load_config", fake_load_config)
        monkeypatch.setattr(
            "hermes_cli.runtime_provider.resolve_runtime_provider",
            fake_resolve_runtime_provider,
        )
        manager = SessionManager(db=SessionDB(tmp_path / "state.db"))

        with patch("run_agent.AIAgent", side_effect=fake_agent):
            acp_agent = HermesACPAgent(session_manager=manager)
            session = manager.create_session(cwd="/tmp")
            output = acp_agent._cmd_model("anthropic:claude-sonnet-4-6", session)

        assert "Provider: anthropic" in output
        assert session.agent.provider == "anthropic"
        assert session.agent.base_url == "https://anthropic.example/v1"
        assert runtime_calls[-1] == "anthropic"
|
||||
331
hermes_code/tests/acp/test_session.py
Normal file
331
hermes_code/tests/acp/test_session.py
Normal file
|
|
@ -0,0 +1,331 @@
|
|||
"""Tests for acp_adapter.session — SessionManager and SessionState."""
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from acp_adapter.session import SessionManager, SessionState
|
||||
from hermes_state import SessionDB
|
||||
|
||||
|
||||
def _mock_agent():
    """Return a fresh ``MagicMock`` named "MockAIAgent" to stand in for an agent."""
    return MagicMock(name="MockAIAgent")
|
||||
|
||||
|
||||
@pytest.fixture()
def manager():
    """SessionManager with a mock agent factory (avoids needing API keys)."""
    return SessionManager(agent_factory=_mock_agent)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# create / get
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCreateSession:
    """Creating sessions and fetching them back by id."""

    def test_create_session_returns_state(self, manager):
        """A new session is a SessionState with cwd, id, empty history and an agent."""
        state = manager.create_session(cwd="/tmp/work")
        assert isinstance(state, SessionState)
        assert state.cwd == "/tmp/work"
        assert state.session_id
        assert state.history == []
        assert state.agent is not None

    def test_create_session_registers_task_cwd(self, manager, monkeypatch):
        """create_session calls _register_task_cwd once with (session_id, cwd)."""
        calls = []
        monkeypatch.setattr(
            "acp_adapter.session._register_task_cwd",
            lambda task_id, cwd: calls.append((task_id, cwd)),
        )
        state = manager.create_session(cwd="/tmp/work")
        assert calls == [(state.session_id, "/tmp/work")]

    def test_session_ids_are_unique(self, manager):
        """Two consecutive sessions get different ids."""
        s1 = manager.create_session()
        s2 = manager.create_session()
        assert s1.session_id != s2.session_id

    def test_get_session(self, manager):
        """get_session returns the very same state object that was created."""
        state = manager.create_session()
        fetched = manager.get_session(state.session_id)
        assert fetched is state

    def test_get_nonexistent_session_returns_none(self, manager):
        """Unknown ids yield None."""
        assert manager.get_session("does-not-exist") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# fork
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestForkSession:
    """Forking a session into a new, independent one."""

    def test_fork_session_deep_copies_history(self, manager):
        """The fork starts with the same history but does not share the list."""
        original = manager.create_session()
        original.history.append({"role": "user", "content": "hello"})
        original.history.append({"role": "assistant", "content": "hi"})

        forked = manager.fork_session(original.session_id, cwd="/new")
        assert forked is not None

        # History should be equal in content
        assert len(forked.history) == 2
        assert forked.history[0]["content"] == "hello"

        # But a deep copy — mutating one doesn't affect the other
        forked.history.append({"role": "user", "content": "extra"})
        assert len(original.history) == 2
        assert len(forked.history) == 3

    def test_fork_session_has_new_id(self, manager):
        """The fork gets its own session id."""
        original = manager.create_session()
        forked = manager.fork_session(original.session_id)
        assert forked is not None
        assert forked.session_id != original.session_id

    def test_fork_nonexistent_returns_none(self, manager):
        """Forking an unknown id yields None."""
        assert manager.fork_session("bogus-id") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list / cleanup / remove
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListAndCleanup:
    """Listing, bulk cleanup and single-session removal."""

    def test_list_sessions_empty(self, manager):
        """A fresh manager lists no sessions."""
        assert manager.list_sessions() == []

    def test_list_sessions_returns_created(self, manager):
        """Every created session shows up in the listing, and nothing else."""
        s1 = manager.create_session(cwd="/a")
        s2 = manager.create_session(cwd="/b")
        listing = manager.list_sessions()
        ids = {s["session_id"] for s in listing}
        assert s1.session_id in ids
        assert s2.session_id in ids
        assert len(listing) == 2

    def test_cleanup_clears_all(self, manager):
        """cleanup() leaves the listing empty."""
        manager.create_session()
        manager.create_session()
        assert len(manager.list_sessions()) == 2
        manager.cleanup()
        assert manager.list_sessions() == []

    def test_remove_session(self, manager):
        """remove_session returns True once, then False for the same id."""
        state = manager.create_session()
        assert manager.remove_session(state.session_id) is True
        assert manager.get_session(state.session_id) is None
        # Removing again returns False
        assert manager.remove_session(state.session_id) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# persistence — sessions survive process restarts (via SessionDB)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPersistence:
    """Verify that sessions are persisted to SessionDB and can be restored."""

    def test_create_session_writes_to_db(self, manager):
        """Creating a session writes an "acp" row whose model_config holds the cwd."""
        state = manager.create_session(cwd="/project")
        db = manager._get_db()
        assert db is not None
        row = db.get_session(state.session_id)
        assert row is not None
        assert row["source"] == "acp"
        # cwd stored in model_config JSON
        mc = json.loads(row["model_config"])
        assert mc["cwd"] == "/project"

    def test_get_session_restores_from_db(self, manager):
        """Simulate process restart: create session, drop from memory, get again."""
        state = manager.create_session(cwd="/work")
        state.history.append({"role": "user", "content": "hello"})
        state.history.append({"role": "assistant", "content": "hi there"})
        manager.save_session(state.session_id)

        sid = state.session_id

        # Drop from in-memory store (simulates process restart).
        with manager._lock:
            del manager._sessions[sid]

        # get_session should transparently restore from DB.
        restored = manager.get_session(sid)
        assert restored is not None
        assert restored.session_id == sid
        assert restored.cwd == "/work"
        assert len(restored.history) == 2
        assert restored.history[0]["content"] == "hello"
        assert restored.history[1]["content"] == "hi there"
        # Agent should have been recreated.
        assert restored.agent is not None

    def test_save_session_updates_db(self, manager):
        """save_session writes the in-memory history to the DB."""
        state = manager.create_session()
        state.history.append({"role": "user", "content": "test"})
        manager.save_session(state.session_id)

        db = manager._get_db()
        messages = db.get_messages_as_conversation(state.session_id)
        assert len(messages) == 1
        assert messages[0]["content"] == "test"

    def test_remove_session_deletes_from_db(self, manager):
        """Removing a session also removes its DB row."""
        state = manager.create_session()
        db = manager._get_db()
        assert db.get_session(state.session_id) is not None
        manager.remove_session(state.session_id)
        assert db.get_session(state.session_id) is None

    def test_cleanup_removes_all_from_db(self, manager):
        """cleanup() removes every session's DB row."""
        s1 = manager.create_session()
        s2 = manager.create_session()
        db = manager._get_db()
        assert db.get_session(s1.session_id) is not None
        assert db.get_session(s2.session_id) is not None
        manager.cleanup()
        assert db.get_session(s1.session_id) is None
        assert db.get_session(s2.session_id) is None

    def test_list_sessions_includes_db_only(self, manager):
        """Sessions only in DB (not in memory) appear in list_sessions."""
        state = manager.create_session(cwd="/db-only")
        sid = state.session_id

        # Drop from memory.
        with manager._lock:
            del manager._sessions[sid]

        listing = manager.list_sessions()
        ids = {s["session_id"] for s in listing}
        assert sid in ids

    def test_fork_restores_source_from_db(self, manager):
        """Forking a session that is only in DB should work."""
        original = manager.create_session()
        original.history.append({"role": "user", "content": "context"})
        manager.save_session(original.session_id)

        # Drop original from memory.
        with manager._lock:
            del manager._sessions[original.session_id]

        forked = manager.fork_session(original.session_id, cwd="/fork")
        assert forked is not None
        assert len(forked.history) == 1
        assert forked.history[0]["content"] == "context"
        assert forked.session_id != original.session_id

    def test_update_cwd_restores_from_db(self, manager):
        """update_cwd works on a DB-only session and persists the new cwd."""
        state = manager.create_session(cwd="/old")
        sid = state.session_id

        with manager._lock:
            del manager._sessions[sid]

        updated = manager.update_cwd(sid, "/new")
        assert updated is not None
        assert updated.cwd == "/new"

        # Should also be persisted in DB.
        db = manager._get_db()
        row = db.get_session(sid)
        mc = json.loads(row["model_config"])
        assert mc["cwd"] == "/new"

    def test_only_restores_acp_sessions(self, manager):
        """get_session should not restore non-ACP sessions from DB."""
        db = manager._get_db()
        # Manually create a CLI session in the DB.
        db.create_session(session_id="cli-session-123", source="cli", model="test")
        # Should not be found via ACP SessionManager.
        assert manager.get_session("cli-session-123") is None

    def test_sessions_searchable_via_fts(self, manager):
        """ACP sessions stored in SessionDB are searchable via FTS5."""
        state = manager.create_session()
        state.history.append({"role": "user", "content": "how do I configure nginx"})
        state.history.append({"role": "assistant", "content": "Here is the nginx config..."})
        manager.save_session(state.session_id)

        db = manager._get_db()
        results = db.search_messages("nginx")
        assert len(results) > 0
        session_ids = {r["session_id"] for r in results}
        assert state.session_id in session_ids

    def test_tool_calls_persisted(self, manager):
        """Messages with tool_calls should round-trip through the DB."""
        state = manager.create_session()
        state.history.append({
            "role": "assistant",
            "content": None,
            "tool_calls": [{"id": "tc_1", "type": "function",
                            "function": {"name": "terminal", "arguments": "{}"}}],
        })
        state.history.append({
            "role": "tool",
            "content": "output here",
            "tool_call_id": "tc_1",
            "name": "terminal",
        })
        manager.save_session(state.session_id)

        # Drop from memory, restore from DB.
        with manager._lock:
            del manager._sessions[state.session_id]

        restored = manager.get_session(state.session_id)
        assert restored is not None
        assert len(restored.history) == 2
        assert restored.history[0].get("tool_calls") is not None
        assert restored.history[1].get("tool_call_id") == "tc_1"

    def test_restore_preserves_persisted_provider_snapshot(self, tmp_path, monkeypatch):
        """Restored ACP sessions should keep their original runtime provider."""
        # Mutable holder so the "default" provider can change between create
        # and restore while the fakes below keep reading the current value.
        runtime_choice = {"provider": "anthropic"}

        def fake_resolve_runtime_provider(requested=None, **kwargs):
            provider = requested or runtime_choice["provider"]
            return {
                "provider": provider,
                "api_mode": "anthropic_messages" if provider == "anthropic" else "chat_completions",
                "base_url": f"https://{provider}.example/v1",
                "api_key": f"{provider}-key",
                "command": None,
                "args": [],
            }

        def fake_agent(**kwargs):
            # Stand-in for AIAgent that just echoes the routing kwargs it got.
            return SimpleNamespace(
                model=kwargs.get("model"),
                provider=kwargs.get("provider"),
                base_url=kwargs.get("base_url"),
                api_mode=kwargs.get("api_mode"),
            )

        monkeypatch.setattr("hermes_cli.config.load_config", lambda: {
            "model": {"provider": runtime_choice["provider"], "default": "test-model"}
        })
        monkeypatch.setattr(
            "hermes_cli.runtime_provider.resolve_runtime_provider",
            fake_resolve_runtime_provider,
        )
        db = SessionDB(tmp_path / "state.db")

        # The restore must happen while AIAgent is still patched, because
        # get_session builds the restored session's agent.
        with patch("run_agent.AIAgent", side_effect=fake_agent):
            manager = SessionManager(db=db)
            state = manager.create_session(cwd="/work")
            manager.save_session(state.session_id)

            with manager._lock:
                del manager._sessions[state.session_id]

            # Flip the ambient default; the restored session must ignore it.
            runtime_choice["provider"] = "openrouter"
            restored = manager.get_session(state.session_id)

        assert restored is not None
        assert restored.agent.provider == "anthropic"
        assert restored.agent.base_url == "https://anthropic.example/v1"
|
||||
236
hermes_code/tests/acp/test_tools.py
Normal file
236
hermes_code/tests/acp/test_tools.py
Normal file
|
|
@ -0,0 +1,236 @@
|
|||
"""Tests for acp_adapter.tools — tool kind mapping and ACP content building."""
|
||||
|
||||
import pytest
|
||||
|
||||
from acp_adapter.tools import (
|
||||
TOOL_KIND_MAP,
|
||||
build_tool_complete,
|
||||
build_tool_start,
|
||||
build_tool_title,
|
||||
extract_locations,
|
||||
get_tool_kind,
|
||||
make_tool_call_id,
|
||||
)
|
||||
from acp.schema import (
|
||||
FileEditToolCallContent,
|
||||
ContentToolCallContent,
|
||||
ToolCallLocation,
|
||||
ToolCallStart,
|
||||
ToolCallProgress,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TOOL_KIND_MAP coverage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Tool names that the TOOL_KIND_MAP coverage test expects to find in the map.
COMMON_HERMES_TOOLS = [
    "read_file",
    "search_files",
    "terminal",
    "patch",
    "write_file",
    "process",
]
|
||||
|
||||
|
||||
class TestToolKindMap:
    """Mapping from hermes tool names to ACP tool kinds."""

    def test_all_hermes_tools_have_kind(self):
        """Every common hermes tool should appear in TOOL_KIND_MAP."""
        for tool in COMMON_HERMES_TOOLS:
            assert tool in TOOL_KIND_MAP, f"{tool} missing from TOOL_KIND_MAP"

    def test_tool_kind_read_file(self):
        assert get_tool_kind("read_file") == "read"

    def test_tool_kind_terminal(self):
        assert get_tool_kind("terminal") == "execute"

    def test_tool_kind_patch(self):
        assert get_tool_kind("patch") == "edit"

    def test_tool_kind_write_file(self):
        assert get_tool_kind("write_file") == "edit"

    def test_tool_kind_web_search(self):
        assert get_tool_kind("web_search") == "fetch"

    def test_tool_kind_execute_code(self):
        assert get_tool_kind("execute_code") == "execute"

    def test_tool_kind_browser_navigate(self):
        assert get_tool_kind("browser_navigate") == "fetch"

    def test_unknown_tool_returns_other_kind(self):
        """Names that are not mapped resolve to the "other" kind."""
        assert get_tool_kind("nonexistent_tool_xyz") == "other"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# make_tool_call_id
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMakeToolCallId:
    """Shape and uniqueness of generated tool-call ids."""

    def test_returns_string(self):
        tc_id = make_tool_call_id()
        assert isinstance(tc_id, str)

    def test_starts_with_tc_prefix(self):
        tc_id = make_tool_call_id()
        assert tc_id.startswith("tc-")

    def test_ids_are_unique(self):
        """100 generated ids contain no duplicates."""
        ids = {make_tool_call_id() for _ in range(100)}
        assert len(ids) == 100
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_tool_title
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildToolTitle:
    """Human-readable titles built from a tool name and its arguments."""

    def test_terminal_title_includes_command(self):
        title = build_tool_title("terminal", {"command": "ls -la /tmp"})
        assert "ls -la /tmp" in title

    def test_terminal_title_truncates_long_command(self):
        """A 200-character command yields a title under 120 chars containing "..."."""
        long_cmd = "x" * 200
        title = build_tool_title("terminal", {"command": long_cmd})
        assert len(title) < 120
        assert "..." in title

    def test_read_file_title(self):
        title = build_tool_title("read_file", {"path": "/etc/hosts"})
        assert "/etc/hosts" in title

    def test_patch_title(self):
        title = build_tool_title("patch", {"path": "main.py", "mode": "replace"})
        assert "main.py" in title

    def test_search_title(self):
        title = build_tool_title("search_files", {"pattern": "TODO"})
        assert "TODO" in title

    def test_web_search_title(self):
        title = build_tool_title("web_search", {"query": "python asyncio"})
        assert "python asyncio" in title

    def test_unknown_tool_uses_name(self):
        """Tools without a dedicated title format fall back to the bare name."""
        title = build_tool_title("some_new_tool", {"foo": "bar"})
        assert title == "some_new_tool"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_tool_start
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildToolStart:
    """ToolCallStart events produced when a tool begins running."""

    def test_build_tool_start_for_patch(self):
        """patch should produce a FileEditToolCallContent (diff)."""
        args = {
            "path": "src/main.py",
            "old_string": "print('hello')",
            "new_string": "print('world')",
        }
        result = build_tool_start("tc-1", "patch", args)
        assert isinstance(result, ToolCallStart)
        assert result.kind == "edit"
        # The first content item should be a diff
        assert len(result.content) >= 1
        diff_item = result.content[0]
        assert isinstance(diff_item, FileEditToolCallContent)
        assert diff_item.path == "src/main.py"
        assert diff_item.new_text == "print('world')"
        assert diff_item.old_text == "print('hello')"

    def test_build_tool_start_for_write_file(self):
        """write_file should produce a FileEditToolCallContent (diff)."""
        args = {"path": "new_file.py", "content": "print('hello')"}
        result = build_tool_start("tc-w1", "write_file", args)
        assert isinstance(result, ToolCallStart)
        assert result.kind == "edit"
        assert len(result.content) >= 1
        diff_item = result.content[0]
        assert isinstance(diff_item, FileEditToolCallContent)
        assert diff_item.path == "new_file.py"

    def test_build_tool_start_for_terminal(self):
        """terminal should produce text content with the command."""
        args = {"command": "ls -la /tmp"}
        result = build_tool_start("tc-2", "terminal", args)
        assert isinstance(result, ToolCallStart)
        assert result.kind == "execute"
        assert len(result.content) >= 1
        content_item = result.content[0]
        assert isinstance(content_item, ContentToolCallContent)
        # The wrapped text block should contain the command
        text = content_item.content.text
        assert "ls -la /tmp" in text

    def test_build_tool_start_for_read_file(self):
        """read_file should include the path in content."""
        args = {"path": "/etc/hosts", "offset": 1, "limit": 50}
        result = build_tool_start("tc-3", "read_file", args)
        assert isinstance(result, ToolCallStart)
        assert result.kind == "read"
        assert len(result.content) >= 1
        content_item = result.content[0]
        assert isinstance(content_item, ContentToolCallContent)
        assert "/etc/hosts" in content_item.content.text

    def test_build_tool_start_for_search(self):
        """search_files should include pattern in content."""
        args = {"pattern": "TODO", "target": "content"}
        result = build_tool_start("tc-4", "search_files", args)
        assert isinstance(result, ToolCallStart)
        assert result.kind == "search"
        assert "TODO" in result.content[0].content.text

    def test_build_tool_start_generic_fallback(self):
        """Unknown tools should get a generic text representation."""
        args = {"foo": "bar", "baz": 42}
        result = build_tool_start("tc-5", "some_tool", args)
        assert isinstance(result, ToolCallStart)
        assert result.kind == "other"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_tool_complete
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildToolComplete:
    """ToolCallProgress events produced when a tool finishes."""

    def test_build_tool_complete_for_terminal(self):
        """Completed terminal call should include output text."""
        result = build_tool_complete("tc-2", "terminal", "total 42\ndrwxr-xr-x 2 root root 4096 ...")
        assert isinstance(result, ToolCallProgress)
        assert result.status == "completed"
        assert len(result.content) >= 1
        content_item = result.content[0]
        assert isinstance(content_item, ContentToolCallContent)
        assert "total 42" in content_item.content.text

    def test_build_tool_complete_truncates_large_output(self):
        """Very large outputs should be truncated."""
        big_output = "x" * 10000
        result = build_tool_complete("tc-6", "read_file", big_output)
        assert isinstance(result, ToolCallProgress)
        display_text = result.content[0].content.text
        assert len(display_text) < 6000
        assert "truncated" in display_text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extract_locations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtractLocations:
    """Deriving ToolCallLocation entries from tool arguments."""

    def test_extract_locations_with_path(self):
        """A "path" argument yields one location; "offset" becomes its line."""
        args = {"path": "src/app.py", "offset": 42}
        locs = extract_locations(args)
        assert len(locs) == 1
        assert isinstance(locs[0], ToolCallLocation)
        assert locs[0].path == "src/app.py"
        assert locs[0].line == 42

    def test_extract_locations_without_path(self):
        """Arguments without a "path" yield no locations."""
        args = {"command": "echo hi"}
        locs = extract_locations(args)
        assert locs == []
|
||||
Loading…
Add table
Add a link
Reference in a new issue