"""Tests for the agent-session client and the platform-agent websocket bridge.

The client side (``sdk.agent_session``) is exercised against a throwaway
aiohttp websocket server that plays a fixed script. The server side
(``src.api.external.process_message``) is imported from the
``external/platform-agent`` checkout, with ``fastapi`` and ``src.agent``
replaced by minimal in-process stubs.
"""

import sys
from pathlib import Path
from types import ModuleType

import pytest
from aiohttp import web

from sdk.interface import MessageChunk, MessageResponse
from sdk.agent_session import AgentSessionClient, AgentSessionConfig, build_thread_key

# Make the vendored agent service and its API client importable.
AGENT_ROOT = Path(__file__).resolve().parents[2] / "external" / "platform-agent"
AGENT_API_ROOT = Path(__file__).resolve().parents[2] / "external" / "platform-agent_api"
for path in (AGENT_ROOT, AGENT_API_ROOT):
    if str(path) not in sys.path:
        sys.path.insert(0, str(path))

# Stub ``fastapi`` with just APIRouter / WebSocket / WebSocketDisconnect /
# Depends. Presumably these are all ``src.api.external`` needs at import time
# (TODO confirm). The stub is only installed if no ``fastapi`` is loaded yet.
if "fastapi" not in sys.modules:
    fastapi = ModuleType("fastapi")

    class _Router:
        def websocket(self, _path: str):
            # Route registration is a no-op: the decorated function is
            # returned unchanged so it can be called directly in tests.
            def decorator(fn):
                return fn

            return decorator

    class _WebSocketDisconnect(Exception):
        pass

    def _depends(value):
        return value

    fastapi.APIRouter = _Router
    fastapi.WebSocket = object
    fastapi.WebSocketDisconnect = _WebSocketDisconnect
    fastapi.Depends = _depends
    sys.modules["fastapi"] = fastapi

# Stub ``src.agent`` with an agent service that simply echoes its input.
if "src.agent" not in sys.modules:
    agent_module = ModuleType("src.agent")

    class _AgentService:
        async def astream(self, text: str, thread_id: str):
            yield text

    def _get_agent_service():
        return _AgentService()

    agent_module.AgentService = _AgentService
    agent_module.get_agent_service = _get_agent_service
    sys.modules["src.agent"] = agent_module

from lambda_agent_api.client import MsgUserMessage  # noqa: E402
from src.api.external import process_message  # noqa: E402

# Expected encoding of ("matrix", "@alice:example.org", "C1"): each part is
# prefixed with its length and a colon.
ALICE_THREAD_KEY = "6:matrix18:@alice:example.org2:C1"


def _scripted_agent_app(
    text_chunks: list[str], tokens_used: int
) -> tuple[web.Application, dict[str, object]]:
    """Build an app whose ``/agent_ws/`` endpoint plays a fixed agent script.

    The handler sends ``STATUS``, reads one JSON message, sends one
    ``AGENT_EVENT_TEXT_CHUNK`` per entry of *text_chunks*, then
    ``AGENT_EVENT_END`` with *tokens_used*, and closes the socket.

    The ``thread_id`` query parameter and the received message are recorded
    into the returned dict (keys ``"thread_id"`` and ``"message"``) instead of
    being asserted inside the handler. An exception raised in an aiohttp
    handler is dealt with by the server rather than propagated into the
    test, which would hide the real failure.

    Returns:
        ``(app, received)``: the application and the dict the handler fills in.
    """
    received: dict[str, object] = {}

    async def handler(request):
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        received["thread_id"] = request.query.get("thread_id")
        await ws.send_json({"type": "STATUS"})
        received["message"] = await ws.receive_json()
        for text in text_chunks:
            await ws.send_json({"type": "AGENT_EVENT_TEXT_CHUNK", "text": text})
        await ws.send_json({"type": "AGENT_EVENT_END", "tokens_used": tokens_used})
        await ws.close()
        return ws

    app = web.Application()
    app.router.add_get("/agent_ws/", handler)
    return app, received


def _client_for(server) -> AgentSessionClient:
    """Return an AgentSessionClient pointed at *server*'s ``/agent_ws/`` endpoint."""
    return AgentSessionClient(AgentSessionConfig(base_ws_url=str(server.make_url("/agent_ws/"))))


def test_build_thread_key_uses_platform_user_and_chat_id():
    assert build_thread_key("matrix", "@alice:example.org", "C1") == ALICE_THREAD_KEY


def test_build_thread_key_does_not_collide_when_user_id_contains_colons():
    left = build_thread_key("matrix", "@alice:example.org", "C1")
    right = build_thread_key("matrix", "@alice", "example.org:C1")
    assert left != right


@pytest.mark.asyncio
async def test_stream_message_yields_text_chunks_and_end(aiohttp_server):
    thread_key = build_thread_key("matrix", "@alice:example.org", "C1")
    app, received = _scripted_agent_app(["hel", "lo"], tokens_used=7)
    server = await aiohttp_server(app)
    client = _client_for(server)

    chunks = [
        chunk
        async for chunk in client.stream_message(
            thread_key=thread_key,
            text="hello",
        )
    ]

    assert received.get("thread_id") == thread_key
    assert received.get("message") == {"type": "USER_MESSAGE", "text": "hello"}
    assert chunks == [
        MessageChunk(message_id=thread_key, delta="hel", finished=False, tokens_used=0),
        MessageChunk(message_id=thread_key, delta="lo", finished=False, tokens_used=0),
        MessageChunk(message_id=thread_key, delta="", finished=True, tokens_used=7),
    ]


@pytest.mark.asyncio
async def test_send_message_collects_streamed_chunks_and_tokens(aiohttp_server):
    thread_key = build_thread_key("matrix", "@alice:example.org", "C1")
    app, received = _scripted_agent_app(["hello ", "world"], tokens_used=11)
    server = await aiohttp_server(app)
    client = _client_for(server)

    result = await client.send_message(
        thread_key=thread_key,
        text="hello world",
    )

    assert received.get("thread_id") == thread_key
    assert received.get("message") == {"type": "USER_MESSAGE", "text": "hello world"}
    assert result == MessageResponse(
        message_id=thread_key,
        response="hello world",
        tokens_used=11,
        finished=True,
    )


@pytest.mark.asyncio
async def test_process_message_requires_thread_id_query_param():
    class FakeWebSocket:
        def __init__(self) -> None:
            self.query_params: dict[str, str] = {}

        async def send_text(self, text: str) -> None:
            raise AssertionError(f"send_text should not be called: {text}")

    class FakeAgentService:
        async def astream(self, text: str, thread_id: str):
            yield text

    with pytest.raises(ValueError, match="thread_id query parameter is required"):
        await process_message(
            FakeWebSocket(),
            MsgUserMessage(text="hello"),
            FakeAgentService(),
        )


@pytest.mark.asyncio
async def test_process_message_passes_thread_id_to_agent_service():
    class FakeWebSocket:
        def __init__(self) -> None:
            self.query_params = {"thread_id": ALICE_THREAD_KEY}
            self.sent_messages: list[str] = []

        async def send_text(self, text: str) -> None:
            self.sent_messages.append(text)

    class FakeAgentService:
        def __init__(self) -> None:
            self.calls: list[tuple[str, str]] = []

        async def astream(self, text: str, thread_id: str):
            self.calls.append((text, thread_id))
            yield "hello"

    ws = FakeWebSocket()
    agent_service = FakeAgentService()

    await process_message(ws, MsgUserMessage(text="hello"), agent_service)

    assert agent_service.calls == [("hello", ALICE_THREAD_KEY)]
    assert any("AGENT_EVENT_TEXT_CHUNK" in message for message in ws.sent_messages)
    assert any("AGENT_EVENT_END" in message for message in ws.sent_messages)