Restore the ACP editor-integration implementation that was present on the original PR branch but did not actually land in main. Includes: - acp_adapter/ server, session manager, event bridge, auth, permissions, and tool helpers - hermes acp subcommand and hermes-acp entry point - hermes-acp curated toolset - ACP registry manifest, setup guide, and ACP test suite - jupyter-live-kernel data science skill from the original branch Also updates the revived ACP code for current main by: - resolving runtime providers through the modern shared provider router - binding ACP sessions to per-session cwd task overrides - tracking duplicate same-name tool calls with FIFO IDs - restoring terminal approval callbacks after prompts - normalizing supporting docs/skill metadata Validated with tests/acp and the full pytest suite (-n0).
This commit is contained in:
parent
2fe853bcc9
commit
25481d4286
24 changed files with 2625 additions and 6 deletions
0
tests/acp/__init__.py
Normal file
0
tests/acp/__init__.py
Normal file
56
tests/acp/test_auth.py
Normal file
56
tests/acp/test_auth.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
"""Tests for acp_adapter.auth — provider detection."""
|
||||
|
||||
from acp_adapter.auth import has_provider, detect_provider
|
||||
|
||||
|
||||
class TestHasProvider:
|
||||
def test_has_provider_with_resolved_runtime(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
lambda: {"provider": "openrouter", "api_key": "sk-or-test"},
|
||||
)
|
||||
assert has_provider() is True
|
||||
|
||||
def test_has_no_provider_when_runtime_has_no_key(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
lambda: {"provider": "openrouter", "api_key": ""},
|
||||
)
|
||||
assert has_provider() is False
|
||||
|
||||
def test_has_no_provider_when_runtime_resolution_fails(self, monkeypatch):
|
||||
def _boom():
|
||||
raise RuntimeError("no provider")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", _boom)
|
||||
assert has_provider() is False
|
||||
|
||||
|
||||
class TestDetectProvider:
|
||||
def test_detect_openrouter(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
lambda: {"provider": "openrouter", "api_key": "sk-or-test"},
|
||||
)
|
||||
assert detect_provider() == "openrouter"
|
||||
|
||||
def test_detect_anthropic(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
lambda: {"provider": "anthropic", "api_key": "sk-ant-test"},
|
||||
)
|
||||
assert detect_provider() == "anthropic"
|
||||
|
||||
def test_detect_none_when_no_key(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
lambda: {"provider": "kimi-coding", "api_key": ""},
|
||||
)
|
||||
assert detect_provider() is None
|
||||
|
||||
def test_detect_none_on_resolution_error(self, monkeypatch):
|
||||
def _boom():
|
||||
raise RuntimeError("broken")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", _boom)
|
||||
assert detect_provider() is None
|
||||
239
tests/acp/test_events.py
Normal file
239
tests/acp/test_events.py
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
"""Tests for acp_adapter.events — callback factories for ACP notifications."""
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import Future
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import acp
|
||||
from acp.schema import ToolCallStart, ToolCallProgress, AgentThoughtChunk, AgentMessageChunk
|
||||
|
||||
from acp_adapter.events import (
|
||||
make_message_cb,
|
||||
make_step_cb,
|
||||
make_thinking_cb,
|
||||
make_tool_progress_cb,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_conn():
|
||||
"""Mock ACP Client connection."""
|
||||
conn = MagicMock(spec=acp.Client)
|
||||
conn.session_update = AsyncMock()
|
||||
return conn
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def event_loop_fixture():
|
||||
"""Create a real event loop for testing threadsafe coroutine submission."""
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool progress callback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToolProgressCallback:
|
||||
def test_emits_tool_call_start(self, mock_conn, event_loop_fixture):
|
||||
"""Tool progress should emit a ToolCallStart update."""
|
||||
tool_call_ids = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
|
||||
# Run callback in the event loop context
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
cb("terminal", "$ ls -la", {"command": "ls -la"})
|
||||
|
||||
# Should have tracked the tool call ID
|
||||
assert "terminal" in tool_call_ids
|
||||
|
||||
# Should have called run_coroutine_threadsafe
|
||||
mock_rcts.assert_called_once()
|
||||
coro = mock_rcts.call_args[0][0]
|
||||
# The coroutine should be conn.session_update
|
||||
assert mock_conn.session_update.called or coro is not None
|
||||
|
||||
def test_handles_string_args(self, mock_conn, event_loop_fixture):
|
||||
"""If args is a JSON string, it should be parsed."""
|
||||
tool_call_ids = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
cb("read_file", "Reading /etc/hosts", '{"path": "/etc/hosts"}')
|
||||
|
||||
assert "read_file" in tool_call_ids
|
||||
|
||||
def test_handles_non_dict_args(self, mock_conn, event_loop_fixture):
|
||||
"""If args is not a dict, it should be wrapped."""
|
||||
tool_call_ids = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
cb("terminal", "$ echo hi", None)
|
||||
|
||||
assert "terminal" in tool_call_ids
|
||||
|
||||
def test_duplicate_same_name_tool_calls_use_fifo_ids(self, mock_conn, event_loop_fixture):
|
||||
"""Multiple same-name tool calls should be tracked independently in order."""
|
||||
tool_call_ids = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
progress_cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
step_cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
progress_cb("terminal", "$ ls", {"command": "ls"})
|
||||
progress_cb("terminal", "$ pwd", {"command": "pwd"})
|
||||
assert len(tool_call_ids["terminal"]) == 2
|
||||
|
||||
step_cb(1, [{"name": "terminal", "result": "ok-1"}])
|
||||
assert len(tool_call_ids["terminal"]) == 1
|
||||
|
||||
step_cb(2, [{"name": "terminal", "result": "ok-2"}])
|
||||
assert "terminal" not in tool_call_ids
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thinking callback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestThinkingCallback:
|
||||
def test_emits_thought_chunk(self, mock_conn, event_loop_fixture):
|
||||
"""Thinking callback should emit AgentThoughtChunk."""
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_thinking_cb(mock_conn, "session-1", loop)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
cb("Analyzing the code...")
|
||||
|
||||
mock_rcts.assert_called_once()
|
||||
|
||||
def test_ignores_empty_text(self, mock_conn, event_loop_fixture):
|
||||
"""Empty text should not emit any update."""
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_thinking_cb(mock_conn, "session-1", loop)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
cb("")
|
||||
|
||||
mock_rcts.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step callback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStepCallback:
|
||||
def test_completes_tracked_tool_calls(self, mock_conn, event_loop_fixture):
|
||||
"""Step callback should mark tracked tools as completed."""
|
||||
tool_call_ids = {"terminal": "tc-abc123"}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
cb(1, [{"name": "terminal", "result": "success"}])
|
||||
|
||||
# Tool should have been removed from tracking
|
||||
assert "terminal" not in tool_call_ids
|
||||
mock_rcts.assert_called_once()
|
||||
|
||||
def test_ignores_untracked_tools(self, mock_conn, event_loop_fixture):
|
||||
"""Tools not in tool_call_ids should be silently ignored."""
|
||||
tool_call_ids = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
cb(1, [{"name": "unknown_tool", "result": "ok"}])
|
||||
|
||||
mock_rcts.assert_not_called()
|
||||
|
||||
def test_handles_string_tool_info(self, mock_conn, event_loop_fixture):
|
||||
"""Tool info as a string (just the name) should work."""
|
||||
tool_call_ids = {"read_file": "tc-def456"}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
cb(2, ["read_file"])
|
||||
|
||||
assert "read_file" not in tool_call_ids
|
||||
mock_rcts.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message callback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageCallback:
|
||||
def test_emits_agent_message_chunk(self, mock_conn, event_loop_fixture):
|
||||
"""Message callback should emit AgentMessageChunk."""
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_message_cb(mock_conn, "session-1", loop)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
cb("Here is your answer.")
|
||||
|
||||
mock_rcts.assert_called_once()
|
||||
|
||||
def test_ignores_empty_message(self, mock_conn, event_loop_fixture):
|
||||
"""Empty text should not emit any update."""
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_message_cb(mock_conn, "session-1", loop)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
cb("")
|
||||
|
||||
mock_rcts.assert_not_called()
|
||||
75
tests/acp/test_permissions.py
Normal file
75
tests/acp/test_permissions.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
"""Tests for acp_adapter.permissions — ACP approval bridging."""
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import Future
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from acp.schema import (
|
||||
AllowedOutcome,
|
||||
DeniedOutcome,
|
||||
RequestPermissionResponse,
|
||||
)
|
||||
from acp_adapter.permissions import make_approval_callback
|
||||
|
||||
|
||||
def _make_response(outcome):
|
||||
"""Helper to build a RequestPermissionResponse with the given outcome."""
|
||||
return RequestPermissionResponse(outcome=outcome)
|
||||
|
||||
|
||||
def _setup_callback(outcome, timeout=60.0):
|
||||
"""
|
||||
Create a callback wired to a mock request_permission coroutine
|
||||
that resolves to the given outcome.
|
||||
|
||||
Returns:
|
||||
(callback, mock_request_permission_fn)
|
||||
"""
|
||||
loop = MagicMock(spec=asyncio.AbstractEventLoop)
|
||||
mock_rp = MagicMock(name="request_permission")
|
||||
|
||||
response = _make_response(outcome)
|
||||
|
||||
# Patch asyncio.run_coroutine_threadsafe so it returns a future
|
||||
# that immediately yields the response.
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = response
|
||||
|
||||
with patch("acp_adapter.permissions.asyncio.run_coroutine_threadsafe", return_value=future):
|
||||
cb = make_approval_callback(mock_rp, loop, session_id="s1", timeout=timeout)
|
||||
result = cb("rm -rf /", "dangerous command")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class TestApprovalMapping:
|
||||
def test_approval_allow_once_maps_correctly(self):
|
||||
outcome = AllowedOutcome(option_id="allow_once", outcome="selected")
|
||||
result = _setup_callback(outcome)
|
||||
assert result == "once"
|
||||
|
||||
def test_approval_allow_always_maps_correctly(self):
|
||||
outcome = AllowedOutcome(option_id="allow_always", outcome="selected")
|
||||
result = _setup_callback(outcome)
|
||||
assert result == "always"
|
||||
|
||||
def test_approval_deny_maps_correctly(self):
|
||||
outcome = DeniedOutcome(outcome="cancelled")
|
||||
result = _setup_callback(outcome)
|
||||
assert result == "deny"
|
||||
|
||||
def test_approval_timeout_returns_deny(self):
|
||||
"""When the future times out, the callback should return 'deny'."""
|
||||
loop = MagicMock(spec=asyncio.AbstractEventLoop)
|
||||
mock_rp = MagicMock(name="request_permission")
|
||||
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.side_effect = TimeoutError("timed out")
|
||||
|
||||
with patch("acp_adapter.permissions.asyncio.run_coroutine_threadsafe", return_value=future):
|
||||
cb = make_approval_callback(mock_rp, loop, session_id="s1", timeout=0.01)
|
||||
result = cb("rm -rf /", "dangerous")
|
||||
|
||||
assert result == "deny"
|
||||
297
tests/acp/test_server.py
Normal file
297
tests/acp/test_server.py
Normal file
|
|
@ -0,0 +1,297 @@
|
|||
"""Tests for acp_adapter.server — HermesACPAgent ACP server."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import acp
|
||||
from acp.schema import (
|
||||
AgentCapabilities,
|
||||
AuthenticateResponse,
|
||||
Implementation,
|
||||
InitializeResponse,
|
||||
ListSessionsResponse,
|
||||
LoadSessionResponse,
|
||||
NewSessionResponse,
|
||||
PromptResponse,
|
||||
ResumeSessionResponse,
|
||||
SessionInfo,
|
||||
TextContentBlock,
|
||||
Usage,
|
||||
)
|
||||
from acp_adapter.server import HermesACPAgent, HERMES_VERSION
|
||||
from acp_adapter.session import SessionManager
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_manager():
|
||||
"""SessionManager with a mock agent factory."""
|
||||
return SessionManager(agent_factory=lambda: MagicMock(name="MockAIAgent"))
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def agent(mock_manager):
|
||||
"""HermesACPAgent backed by a mock session manager."""
|
||||
return HermesACPAgent(session_manager=mock_manager)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# initialize
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInitialize:
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_returns_correct_protocol_version(self, agent):
|
||||
resp = await agent.initialize(protocol_version=1)
|
||||
assert isinstance(resp, InitializeResponse)
|
||||
assert resp.protocol_version == acp.PROTOCOL_VERSION
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_returns_agent_info(self, agent):
|
||||
resp = await agent.initialize(protocol_version=1)
|
||||
assert resp.agent_info is not None
|
||||
assert isinstance(resp.agent_info, Implementation)
|
||||
assert resp.agent_info.name == "hermes-agent"
|
||||
assert resp.agent_info.version == HERMES_VERSION
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_returns_capabilities(self, agent):
|
||||
resp = await agent.initialize(protocol_version=1)
|
||||
caps = resp.agent_capabilities
|
||||
assert isinstance(caps, AgentCapabilities)
|
||||
assert caps.session_capabilities is not None
|
||||
assert caps.session_capabilities.fork is not None
|
||||
assert caps.session_capabilities.list is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# authenticate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAuthenticate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_with_provider_configured(self, agent, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"acp_adapter.server.has_provider",
|
||||
lambda: True,
|
||||
)
|
||||
resp = await agent.authenticate(method_id="openrouter")
|
||||
assert isinstance(resp, AuthenticateResponse)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_without_provider(self, agent, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"acp_adapter.server.has_provider",
|
||||
lambda: False,
|
||||
)
|
||||
resp = await agent.authenticate(method_id="openrouter")
|
||||
assert resp is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# new_session / cancel / load / resume
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSessionOps:
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_session_creates_session(self, agent):
|
||||
resp = await agent.new_session(cwd="/home/user/project")
|
||||
assert isinstance(resp, NewSessionResponse)
|
||||
assert resp.session_id
|
||||
# Session should be retrievable from the manager
|
||||
state = agent.session_manager.get_session(resp.session_id)
|
||||
assert state is not None
|
||||
assert state.cwd == "/home/user/project"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_sets_event(self, agent):
|
||||
resp = await agent.new_session(cwd=".")
|
||||
state = agent.session_manager.get_session(resp.session_id)
|
||||
assert not state.cancel_event.is_set()
|
||||
await agent.cancel(session_id=resp.session_id)
|
||||
assert state.cancel_event.is_set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_nonexistent_session_is_noop(self, agent):
|
||||
# Should not raise
|
||||
await agent.cancel(session_id="does-not-exist")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_session_returns_response(self, agent):
|
||||
resp = await agent.new_session(cwd="/tmp")
|
||||
load_resp = await agent.load_session(cwd="/tmp", session_id=resp.session_id)
|
||||
assert isinstance(load_resp, LoadSessionResponse)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_session_not_found_returns_none(self, agent):
|
||||
resp = await agent.load_session(cwd="/tmp", session_id="bogus")
|
||||
assert resp is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_session_returns_response(self, agent):
|
||||
resp = await agent.new_session(cwd="/tmp")
|
||||
resume_resp = await agent.resume_session(cwd="/tmp", session_id=resp.session_id)
|
||||
assert isinstance(resume_resp, ResumeSessionResponse)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_session_creates_new_if_missing(self, agent):
|
||||
resume_resp = await agent.resume_session(cwd="/tmp", session_id="nonexistent")
|
||||
assert isinstance(resume_resp, ResumeSessionResponse)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list / fork
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListAndFork:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions(self, agent):
|
||||
await agent.new_session(cwd="/a")
|
||||
await agent.new_session(cwd="/b")
|
||||
resp = await agent.list_sessions()
|
||||
assert isinstance(resp, ListSessionsResponse)
|
||||
assert len(resp.sessions) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fork_session(self, agent):
|
||||
new_resp = await agent.new_session(cwd="/original")
|
||||
fork_resp = await agent.fork_session(cwd="/forked", session_id=new_resp.session_id)
|
||||
assert fork_resp.session_id
|
||||
assert fork_resp.session_id != new_resp.session_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPrompt:
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_returns_refusal_for_unknown_session(self, agent):
|
||||
prompt = [TextContentBlock(type="text", text="hello")]
|
||||
resp = await agent.prompt(prompt=prompt, session_id="nonexistent")
|
||||
assert isinstance(resp, PromptResponse)
|
||||
assert resp.stop_reason == "refusal"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_returns_end_turn_for_empty_message(self, agent):
|
||||
new_resp = await agent.new_session(cwd=".")
|
||||
prompt = [TextContentBlock(type="text", text=" ")]
|
||||
resp = await agent.prompt(prompt=prompt, session_id=new_resp.session_id)
|
||||
assert resp.stop_reason == "end_turn"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_runs_agent(self, agent):
|
||||
"""The prompt method should call run_conversation on the agent."""
|
||||
new_resp = await agent.new_session(cwd=".")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
|
||||
# Mock the agent's run_conversation
|
||||
state.agent.run_conversation = MagicMock(return_value={
|
||||
"final_response": "Hello! How can I help?",
|
||||
"messages": [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "Hello! How can I help?"},
|
||||
],
|
||||
})
|
||||
|
||||
# Set up a mock connection
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
agent._conn = mock_conn
|
||||
|
||||
prompt = [TextContentBlock(type="text", text="hello")]
|
||||
resp = await agent.prompt(prompt=prompt, session_id=new_resp.session_id)
|
||||
|
||||
assert isinstance(resp, PromptResponse)
|
||||
assert resp.stop_reason == "end_turn"
|
||||
state.agent.run_conversation.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_updates_history(self, agent):
|
||||
"""After a prompt, session history should be updated."""
|
||||
new_resp = await agent.new_session(cwd=".")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
|
||||
expected_history = [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant", "content": "hey"},
|
||||
]
|
||||
state.agent.run_conversation = MagicMock(return_value={
|
||||
"final_response": "hey",
|
||||
"messages": expected_history,
|
||||
})
|
||||
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
agent._conn = mock_conn
|
||||
|
||||
prompt = [TextContentBlock(type="text", text="hi")]
|
||||
await agent.prompt(prompt=prompt, session_id=new_resp.session_id)
|
||||
|
||||
assert state.history == expected_history
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_sends_final_message_update(self, agent):
|
||||
"""The final response should be sent as an AgentMessageChunk."""
|
||||
new_resp = await agent.new_session(cwd=".")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
|
||||
state.agent.run_conversation = MagicMock(return_value={
|
||||
"final_response": "I can help with that!",
|
||||
"messages": [],
|
||||
})
|
||||
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
agent._conn = mock_conn
|
||||
|
||||
prompt = [TextContentBlock(type="text", text="help me")]
|
||||
await agent.prompt(prompt=prompt, session_id=new_resp.session_id)
|
||||
|
||||
# session_update should have been called with the final message
|
||||
mock_conn.session_update.assert_called()
|
||||
# Get the last call's update argument
|
||||
last_call = mock_conn.session_update.call_args_list[-1]
|
||||
update = last_call[1].get("update") or last_call[0][1]
|
||||
assert update.session_update == "agent_message_chunk"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_cancelled_returns_cancelled_stop_reason(self, agent):
|
||||
"""If cancel is called during prompt, stop_reason should be 'cancelled'."""
|
||||
new_resp = await agent.new_session(cwd=".")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
# Simulate cancel being set during execution
|
||||
state.cancel_event.set()
|
||||
return {"final_response": "interrupted", "messages": []}
|
||||
|
||||
state.agent.run_conversation = mock_run
|
||||
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
agent._conn = mock_conn
|
||||
|
||||
prompt = [TextContentBlock(type="text", text="do something")]
|
||||
resp = await agent.prompt(prompt=prompt, session_id=new_resp.session_id)
|
||||
|
||||
assert resp.stop_reason == "cancelled"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# on_connect
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOnConnect:
|
||||
def test_on_connect_stores_client(self, agent):
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
agent.on_connect(mock_conn)
|
||||
assert agent._conn is mock_conn
|
||||
112
tests/acp/test_session.py
Normal file
112
tests/acp/test_session.py
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
"""Tests for acp_adapter.session — SessionManager and SessionState."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from acp_adapter.session import SessionManager, SessionState
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def manager():
|
||||
"""SessionManager with a mock agent factory (avoids needing API keys)."""
|
||||
return SessionManager(agent_factory=lambda: MagicMock(name="MockAIAgent"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# create / get
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCreateSession:
|
||||
def test_create_session_returns_state(self, manager):
|
||||
state = manager.create_session(cwd="/tmp/work")
|
||||
assert isinstance(state, SessionState)
|
||||
assert state.cwd == "/tmp/work"
|
||||
assert state.session_id
|
||||
assert state.history == []
|
||||
assert state.agent is not None
|
||||
|
||||
def test_create_session_registers_task_cwd(self, manager, monkeypatch):
|
||||
calls = []
|
||||
monkeypatch.setattr("acp_adapter.session._register_task_cwd", lambda task_id, cwd: calls.append((task_id, cwd)))
|
||||
state = manager.create_session(cwd="/tmp/work")
|
||||
assert calls == [(state.session_id, "/tmp/work")]
|
||||
|
||||
def test_session_ids_are_unique(self, manager):
|
||||
s1 = manager.create_session()
|
||||
s2 = manager.create_session()
|
||||
assert s1.session_id != s2.session_id
|
||||
|
||||
def test_get_session(self, manager):
|
||||
state = manager.create_session()
|
||||
fetched = manager.get_session(state.session_id)
|
||||
assert fetched is state
|
||||
|
||||
def test_get_nonexistent_session_returns_none(self, manager):
|
||||
assert manager.get_session("does-not-exist") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# fork
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestForkSession:
|
||||
def test_fork_session_deep_copies_history(self, manager):
|
||||
original = manager.create_session()
|
||||
original.history.append({"role": "user", "content": "hello"})
|
||||
original.history.append({"role": "assistant", "content": "hi"})
|
||||
|
||||
forked = manager.fork_session(original.session_id, cwd="/new")
|
||||
assert forked is not None
|
||||
|
||||
# History should be equal in content
|
||||
assert len(forked.history) == 2
|
||||
assert forked.history[0]["content"] == "hello"
|
||||
|
||||
# But a deep copy — mutating one doesn't affect the other
|
||||
forked.history.append({"role": "user", "content": "extra"})
|
||||
assert len(original.history) == 2
|
||||
assert len(forked.history) == 3
|
||||
|
||||
def test_fork_session_has_new_id(self, manager):
|
||||
original = manager.create_session()
|
||||
forked = manager.fork_session(original.session_id)
|
||||
assert forked is not None
|
||||
assert forked.session_id != original.session_id
|
||||
|
||||
def test_fork_nonexistent_returns_none(self, manager):
|
||||
assert manager.fork_session("bogus-id") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list / cleanup / remove
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListAndCleanup:
|
||||
def test_list_sessions_empty(self, manager):
|
||||
assert manager.list_sessions() == []
|
||||
|
||||
def test_list_sessions_returns_created(self, manager):
|
||||
s1 = manager.create_session(cwd="/a")
|
||||
s2 = manager.create_session(cwd="/b")
|
||||
listing = manager.list_sessions()
|
||||
ids = {s["session_id"] for s in listing}
|
||||
assert s1.session_id in ids
|
||||
assert s2.session_id in ids
|
||||
assert len(listing) == 2
|
||||
|
||||
def test_cleanup_clears_all(self, manager):
|
||||
manager.create_session()
|
||||
manager.create_session()
|
||||
assert len(manager.list_sessions()) == 2
|
||||
manager.cleanup()
|
||||
assert manager.list_sessions() == []
|
||||
|
||||
def test_remove_session(self, manager):
|
||||
state = manager.create_session()
|
||||
assert manager.remove_session(state.session_id) is True
|
||||
assert manager.get_session(state.session_id) is None
|
||||
# Removing again returns False
|
||||
assert manager.remove_session(state.session_id) is False
|
||||
236
tests/acp/test_tools.py
Normal file
236
tests/acp/test_tools.py
Normal file
|
|
@ -0,0 +1,236 @@
|
|||
"""Tests for acp_adapter.tools — tool kind mapping and ACP content building."""
|
||||
|
||||
import pytest
|
||||
|
||||
from acp_adapter.tools import (
|
||||
TOOL_KIND_MAP,
|
||||
build_tool_complete,
|
||||
build_tool_start,
|
||||
build_tool_title,
|
||||
extract_locations,
|
||||
get_tool_kind,
|
||||
make_tool_call_id,
|
||||
)
|
||||
from acp.schema import (
|
||||
FileEditToolCallContent,
|
||||
ContentToolCallContent,
|
||||
ToolCallLocation,
|
||||
ToolCallStart,
|
||||
ToolCallProgress,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TOOL_KIND_MAP coverage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
COMMON_HERMES_TOOLS = ["read_file", "search_files", "terminal", "patch", "write_file", "process"]
|
||||
|
||||
|
||||
class TestToolKindMap:
|
||||
def test_all_hermes_tools_have_kind(self):
|
||||
"""Every common hermes tool should appear in TOOL_KIND_MAP."""
|
||||
for tool in COMMON_HERMES_TOOLS:
|
||||
assert tool in TOOL_KIND_MAP, f"{tool} missing from TOOL_KIND_MAP"
|
||||
|
||||
def test_tool_kind_read_file(self):
|
||||
assert get_tool_kind("read_file") == "read"
|
||||
|
||||
def test_tool_kind_terminal(self):
|
||||
assert get_tool_kind("terminal") == "execute"
|
||||
|
||||
def test_tool_kind_patch(self):
|
||||
assert get_tool_kind("patch") == "edit"
|
||||
|
||||
def test_tool_kind_write_file(self):
|
||||
assert get_tool_kind("write_file") == "edit"
|
||||
|
||||
def test_tool_kind_web_search(self):
|
||||
assert get_tool_kind("web_search") == "fetch"
|
||||
|
||||
def test_tool_kind_execute_code(self):
|
||||
assert get_tool_kind("execute_code") == "execute"
|
||||
|
||||
def test_tool_kind_browser_navigate(self):
|
||||
assert get_tool_kind("browser_navigate") == "fetch"
|
||||
|
||||
def test_unknown_tool_returns_other_kind(self):
|
||||
assert get_tool_kind("nonexistent_tool_xyz") == "other"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# make_tool_call_id
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMakeToolCallId:
|
||||
def test_returns_string(self):
|
||||
tc_id = make_tool_call_id()
|
||||
assert isinstance(tc_id, str)
|
||||
|
||||
def test_starts_with_tc_prefix(self):
|
||||
tc_id = make_tool_call_id()
|
||||
assert tc_id.startswith("tc-")
|
||||
|
||||
def test_ids_are_unique(self):
|
||||
ids = {make_tool_call_id() for _ in range(100)}
|
||||
assert len(ids) == 100
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_tool_title
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildToolTitle:
|
||||
def test_terminal_title_includes_command(self):
|
||||
title = build_tool_title("terminal", {"command": "ls -la /tmp"})
|
||||
assert "ls -la /tmp" in title
|
||||
|
||||
def test_terminal_title_truncates_long_command(self):
|
||||
long_cmd = "x" * 200
|
||||
title = build_tool_title("terminal", {"command": long_cmd})
|
||||
assert len(title) < 120
|
||||
assert "..." in title
|
||||
|
||||
def test_read_file_title(self):
|
||||
title = build_tool_title("read_file", {"path": "/etc/hosts"})
|
||||
assert "/etc/hosts" in title
|
||||
|
||||
def test_patch_title(self):
|
||||
title = build_tool_title("patch", {"path": "main.py", "mode": "replace"})
|
||||
assert "main.py" in title
|
||||
|
||||
def test_search_title(self):
|
||||
title = build_tool_title("search_files", {"pattern": "TODO"})
|
||||
assert "TODO" in title
|
||||
|
||||
def test_web_search_title(self):
|
||||
title = build_tool_title("web_search", {"query": "python asyncio"})
|
||||
assert "python asyncio" in title
|
||||
|
||||
def test_unknown_tool_uses_name(self):
|
||||
title = build_tool_title("some_new_tool", {"foo": "bar"})
|
||||
assert title == "some_new_tool"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_tool_start
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildToolStart:
|
||||
def test_build_tool_start_for_patch(self):
|
||||
"""patch should produce a FileEditToolCallContent (diff)."""
|
||||
args = {
|
||||
"path": "src/main.py",
|
||||
"old_string": "print('hello')",
|
||||
"new_string": "print('world')",
|
||||
}
|
||||
result = build_tool_start("tc-1", "patch", args)
|
||||
assert isinstance(result, ToolCallStart)
|
||||
assert result.kind == "edit"
|
||||
# The first content item should be a diff
|
||||
assert len(result.content) >= 1
|
||||
diff_item = result.content[0]
|
||||
assert isinstance(diff_item, FileEditToolCallContent)
|
||||
assert diff_item.path == "src/main.py"
|
||||
assert diff_item.new_text == "print('world')"
|
||||
assert diff_item.old_text == "print('hello')"
|
||||
|
||||
def test_build_tool_start_for_write_file(self):
|
||||
"""write_file should produce a FileEditToolCallContent (diff)."""
|
||||
args = {"path": "new_file.py", "content": "print('hello')"}
|
||||
result = build_tool_start("tc-w1", "write_file", args)
|
||||
assert isinstance(result, ToolCallStart)
|
||||
assert result.kind == "edit"
|
||||
assert len(result.content) >= 1
|
||||
diff_item = result.content[0]
|
||||
assert isinstance(diff_item, FileEditToolCallContent)
|
||||
assert diff_item.path == "new_file.py"
|
||||
|
||||
def test_build_tool_start_for_terminal(self):
|
||||
"""terminal should produce text content with the command."""
|
||||
args = {"command": "ls -la /tmp"}
|
||||
result = build_tool_start("tc-2", "terminal", args)
|
||||
assert isinstance(result, ToolCallStart)
|
||||
assert result.kind == "execute"
|
||||
assert len(result.content) >= 1
|
||||
content_item = result.content[0]
|
||||
assert isinstance(content_item, ContentToolCallContent)
|
||||
# The wrapped text block should contain the command
|
||||
text = content_item.content.text
|
||||
assert "ls -la /tmp" in text
|
||||
|
||||
def test_build_tool_start_for_read_file(self):
|
||||
"""read_file should include the path in content."""
|
||||
args = {"path": "/etc/hosts", "offset": 1, "limit": 50}
|
||||
result = build_tool_start("tc-3", "read_file", args)
|
||||
assert isinstance(result, ToolCallStart)
|
||||
assert result.kind == "read"
|
||||
assert len(result.content) >= 1
|
||||
content_item = result.content[0]
|
||||
assert isinstance(content_item, ContentToolCallContent)
|
||||
assert "/etc/hosts" in content_item.content.text
|
||||
|
||||
def test_build_tool_start_for_search(self):
|
||||
"""search_files should include pattern in content."""
|
||||
args = {"pattern": "TODO", "target": "content"}
|
||||
result = build_tool_start("tc-4", "search_files", args)
|
||||
assert isinstance(result, ToolCallStart)
|
||||
assert result.kind == "search"
|
||||
assert "TODO" in result.content[0].content.text
|
||||
|
||||
def test_build_tool_start_generic_fallback(self):
|
||||
"""Unknown tools should get a generic text representation."""
|
||||
args = {"foo": "bar", "baz": 42}
|
||||
result = build_tool_start("tc-5", "some_tool", args)
|
||||
assert isinstance(result, ToolCallStart)
|
||||
assert result.kind == "other"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_tool_complete
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildToolComplete:
|
||||
def test_build_tool_complete_for_terminal(self):
|
||||
"""Completed terminal call should include output text."""
|
||||
result = build_tool_complete("tc-2", "terminal", "total 42\ndrwxr-xr-x 2 root root 4096 ...")
|
||||
assert isinstance(result, ToolCallProgress)
|
||||
assert result.status == "completed"
|
||||
assert len(result.content) >= 1
|
||||
content_item = result.content[0]
|
||||
assert isinstance(content_item, ContentToolCallContent)
|
||||
assert "total 42" in content_item.content.text
|
||||
|
||||
def test_build_tool_complete_truncates_large_output(self):
|
||||
"""Very large outputs should be truncated."""
|
||||
big_output = "x" * 10000
|
||||
result = build_tool_complete("tc-6", "read_file", big_output)
|
||||
assert isinstance(result, ToolCallProgress)
|
||||
display_text = result.content[0].content.text
|
||||
assert len(display_text) < 6000
|
||||
assert "truncated" in display_text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extract_locations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtractLocations:
|
||||
def test_extract_locations_with_path(self):
|
||||
args = {"path": "src/app.py", "offset": 42}
|
||||
locs = extract_locations(args)
|
||||
assert len(locs) == 1
|
||||
assert isinstance(locs[0], ToolCallLocation)
|
||||
assert locs[0].path == "src/app.py"
|
||||
assert locs[0].line == 42
|
||||
|
||||
def test_extract_locations_without_path(self):
|
||||
args = {"command": "echo hi"}
|
||||
locs = extract_locations(args)
|
||||
assert locs == []
|
||||
Loading…
Add table
Add a link
Reference in a new issue