diff --git a/sdk/agent_session.py b/sdk/agent_session.py new file mode 100644 index 0000000..6f90e3f --- /dev/null +++ b/sdk/agent_session.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import AsyncIterator +from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit + +import aiohttp + +from sdk.interface import MessageChunk, MessageResponse, PlatformError + + +def build_thread_key(platform: str, user_id: str, chat_id: str) -> str: + return f"{platform}:{user_id}:{chat_id}" + + +@dataclass(frozen=True, slots=True) +class AgentSessionConfig: + base_ws_url: str + timeout_seconds: float = 30.0 + + +class AgentSessionClient: + def __init__(self, config: AgentSessionConfig) -> None: + self._config = config + + async def send_message(self, *, thread_key: str, text: str) -> MessageResponse: + response_parts: list[str] = [] + tokens_used = 0 + + async for chunk in self.stream_message(thread_key=thread_key, text=text): + if chunk.delta: + response_parts.append(chunk.delta) + if chunk.finished: + tokens_used = chunk.tokens_used + + return MessageResponse( + message_id=thread_key, + response="".join(response_parts), + tokens_used=tokens_used, + finished=True, + ) + + async def stream_message(self, *, thread_key: str, text: str) -> AsyncIterator[MessageChunk]: + async with aiohttp.ClientSession() as session: + async with session.ws_connect( + self._ws_url(thread_key), + heartbeat=30, + ) as ws: + status = await ws.receive_json(timeout=self._config.timeout_seconds) + if status.get("type") != "STATUS": + raise PlatformError("Agent did not send STATUS", code="AGENT_PROTOCOL_ERROR") + + await ws.send_json({"type": "USER_MESSAGE", "text": text}) + + while True: + payload = await ws.receive_json(timeout=self._config.timeout_seconds) + msg_type = payload.get("type") + + if msg_type == "AGENT_EVENT_TEXT_CHUNK": + yield MessageChunk( + message_id=thread_key, + delta=payload["text"], + finished=False, + ) + elif msg_type == "AGENT_EVENT_END": + yield MessageChunk( + message_id=thread_key, + delta="", + finished=True, + tokens_used=payload.get("tokens_used", 0), + ) + return + elif msg_type == "ERROR": + raise PlatformError( + payload.get("details", "Agent error"), + code=payload.get("code", "AGENT_ERROR"), + ) + else: + raise PlatformError( + f"Unexpected agent message: {payload}", + code="AGENT_PROTOCOL_ERROR", + ) + + def _ws_url(self, thread_key: str) -> str: + parts = urlsplit(self._config.base_ws_url) + query = dict(parse_qsl(parts.query, keep_blank_values=True)) + query["thread_id"] = thread_key + return urlunsplit(parts._replace(query=urlencode(query))) diff --git a/tests/platform/test_agent_session.py b/tests/platform/test_agent_session.py new file mode 100644 index 0000000..a1d9dd6 --- /dev/null +++ b/tests/platform/test_agent_session.py @@ -0,0 +1,86 @@ +import pytest +from aiohttp import web + +from sdk.interface import MessageChunk, MessageResponse +from sdk.agent_session import AgentSessionClient, AgentSessionConfig, build_thread_key + + +def test_build_thread_key_uses_platform_user_and_chat_id(): + assert build_thread_key("matrix", "@alice:example.org", "C1") == "matrix:@alice:example.org:C1" + + +@pytest.mark.asyncio +async def test_stream_message_yields_text_chunks_and_end(aiohttp_server): + async def handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + assert request.query["thread_id"] == "matrix:@alice:example.org:C1" + + await ws.send_json({"type": "STATUS"}) + + message = await ws.receive_json() + assert message == {"type": "USER_MESSAGE", "text": "hello"} + + await ws.send_json({"type": "AGENT_EVENT_TEXT_CHUNK", "text": "hel"}) + await ws.send_json({"type": "AGENT_EVENT_TEXT_CHUNK", "text": "lo"}) + await ws.send_json({"type": "AGENT_EVENT_END", "tokens_used": 7}) + await ws.close() + return ws + + app = web.Application() + app.router.add_get("/agent_ws/", handler) + server = await aiohttp_server(app) + + client = AgentSessionClient(AgentSessionConfig(base_ws_url=str(server.make_url("/agent_ws/")))) + + chunks = [] + async for chunk in client.stream_message( + thread_key="matrix:@alice:example.org:C1", + text="hello", + ): + chunks.append(chunk) + + assert chunks == [ + MessageChunk(message_id="matrix:@alice:example.org:C1", delta="hel", finished=False, tokens_used=0), + MessageChunk(message_id="matrix:@alice:example.org:C1", delta="lo", finished=False, tokens_used=0), + MessageChunk(message_id="matrix:@alice:example.org:C1", delta="", finished=True, tokens_used=7), + ] + + +@pytest.mark.asyncio +async def test_send_message_collects_streamed_chunks_and_tokens(aiohttp_server): + async def handler(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + assert request.query["thread_id"] == "matrix:@alice:example.org:C1" + + await ws.send_json({"type": "STATUS"}) + + message = await ws.receive_json() + assert message == {"type": "USER_MESSAGE", "text": "hello world"} + + await ws.send_json({"type": "AGENT_EVENT_TEXT_CHUNK", "text": "hello "}) + await ws.send_json({"type": "AGENT_EVENT_TEXT_CHUNK", "text": "world"}) + await ws.send_json({"type": "AGENT_EVENT_END", "tokens_used": 11}) + await ws.close() + return ws + + app = web.Application() + app.router.add_get("/agent_ws/", handler) + server = await aiohttp_server(app) + + client = AgentSessionClient(AgentSessionConfig(base_ws_url=str(server.make_url("/agent_ws/")))) + + result = await client.send_message( + thread_key="matrix:@alice:example.org:C1", + text="hello world", + ) + + assert result == MessageResponse( + message_id="matrix:@alice:example.org:C1", + response="hello world", + tokens_used=11, + finished=True, + )