diff --git a/pyproject.toml b/pyproject.toml index 8f4978b..ccc6309 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,12 +15,14 @@ dependencies = [ "structlog>=24.1", "python-dotenv>=1.0", "httpx>=0.27", + "aiohttp>=3.9", ] [project.optional-dependencies] dev = [ "pytest>=8.0", "pytest-asyncio>=0.23", + "pytest-aiohttp>=1.0", "pytest-cov>=4.1", "ruff>=0.3", "mypy>=1.8", diff --git a/sdk/__init__.py b/sdk/__init__.py index 36f81c1..f7939f7 100644 --- a/sdk/__init__.py +++ b/sdk/__init__.py @@ -1,3 +1,9 @@ -from sdk.real import RealPlatformClient - __all__ = ["RealPlatformClient"] + + +def __getattr__(name: str): + if name == "RealPlatformClient": + from sdk.real import RealPlatformClient + + return RealPlatformClient + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/sdk/agent_session.py b/sdk/agent_session.py index 8b7c7b3..0f959a1 100644 --- a/sdk/agent_session.py +++ b/sdk/agent_session.py @@ -4,8 +4,6 @@ from dataclasses import dataclass from typing import AsyncIterator from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit -import aiohttp - from sdk.interface import MessageChunk, MessageResponse, PlatformError @@ -41,6 +39,8 @@ class AgentSessionClient: ) async def stream_message(self, *, thread_key: str, text: str) -> AsyncIterator[MessageChunk]: + import aiohttp + async with aiohttp.ClientSession() as session: async with session.ws_connect( self._ws_url(thread_key), diff --git a/sdk/prototype_state.py b/sdk/prototype_state.py index 3423701..ccb75f1 100644 --- a/sdk/prototype_state.py +++ b/sdk/prototype_state.py @@ -34,7 +34,6 @@ class PrototypeStateStore: async def get_or_create_user( self, - *, external_id: str, platform: str, display_name: str | None = None, @@ -54,14 +53,14 @@ class PrototypeStateStore: created_at=datetime.now(UTC), is_new=True, ) - self._users[key] = user + self._users[key] = user.model_copy(update={"is_new": False}) return user.model_copy() async def get_settings(self, user_id: str) -> UserSettings: stored = self._settings.get(user_id, {}) return UserSettings( skills={**DEFAULT_SKILLS, **stored.get("skills", {})}, - connectors=stored.get("connectors", {}), + connectors=dict(stored.get("connectors", {})), soul={**DEFAULT_SOUL, **stored.get("soul", {})}, safety={**DEFAULT_SAFETY, **stored.get("safety", {})}, plan={**DEFAULT_PLAN, **stored.get("plan", {})}, diff --git a/sdk/real.py b/sdk/real.py index cd38cc2..7da48c8 100644 --- a/sdk/real.py +++ b/sdk/real.py @@ -1,18 +1,21 @@ from __future__ import annotations -from typing import AsyncIterator +from typing import TYPE_CHECKING, AsyncIterator -from sdk.agent_session import AgentSessionClient, build_thread_key +from sdk.agent_session import build_thread_key from sdk.interface import Attachment, MessageChunk, MessageResponse, PlatformClient, User, UserSettings from sdk.prototype_state import PrototypeStateStore +if TYPE_CHECKING: + from sdk.agent_session import AgentSessionClient + class RealPlatformClient(PlatformClient): def __init__( self, agent_sessions: AgentSessionClient, prototype_state: PrototypeStateStore, - platform: str, + platform: str = "matrix", ) -> None: self._agent_sessions = agent_sessions self._prototype_state = prototype_state diff --git a/tests/platform/test_agent_session.py b/tests/platform/test_agent_session.py index bd38b27..2d085c3 100644 --- a/tests/platform/test_agent_session.py +++ b/tests/platform/test_agent_session.py @@ -1,9 +1,58 @@ +import sys +from pathlib import Path +from types import ModuleType + import pytest from aiohttp import web from sdk.interface import MessageChunk, MessageResponse from sdk.agent_session import AgentSessionClient, AgentSessionConfig, build_thread_key +AGENT_ROOT = Path(__file__).resolve().parents[2] / "external" / "platform-agent" +AGENT_API_ROOT = Path(__file__).resolve().parents[2] / "external" / "platform-agent_api" +for path in (AGENT_ROOT, AGENT_API_ROOT): + if str(path) not in sys.path: + sys.path.insert(0, str(path)) + +if "fastapi" not in sys.modules: + fastapi = ModuleType("fastapi") + + class _Router: + def websocket(self, _path: str): + def decorator(fn): + return fn + + return decorator + + class _WebSocketDisconnect(Exception): + pass + + def _depends(value): + return value + + fastapi.APIRouter = _Router + fastapi.WebSocket = object + fastapi.WebSocketDisconnect = _WebSocketDisconnect + fastapi.Depends = _depends + sys.modules["fastapi"] = fastapi + +if "src.agent" not in sys.modules: + agent_module = ModuleType("src.agent") + + class _AgentService: + async def astream(self, text: str, thread_id: str): + yield text + + def _get_agent_service(): + return _AgentService() + + agent_module.AgentService = _AgentService + agent_module.get_agent_service = _get_agent_service + sys.modules["src.agent"] = agent_module + +from lambda_agent_api.client import MsgUserMessage # noqa: E402 +from src.api.external import process_message # noqa: E402 + def test_build_thread_key_uses_platform_user_and_chat_id(): assert build_thread_key("matrix", "@alice:example.org", "C1") == "6:matrix18:@alice:example.org2:C1" @@ -18,11 +67,13 @@ def test_build_thread_key_does_not_collide_when_user_id_contains_colons(): @pytest.mark.asyncio async def test_stream_message_yields_text_chunks_and_end(aiohttp_server): + thread_key = build_thread_key("matrix", "@alice:example.org", "C1") + async def handler(request): ws = web.WebSocketResponse() await ws.prepare(request) - assert request.query["thread_id"] == "matrix:@alice:example.org:C1" + assert request.query["thread_id"] == thread_key await ws.send_json({"type": "STATUS"}) @@ -43,25 +94,27 @@ async def test_stream_message_yields_text_chunks_and_end(aiohttp_server): chunks = [] async for chunk in client.stream_message( - thread_key="matrix:@alice:example.org:C1", + thread_key=thread_key, text="hello", ): chunks.append(chunk) assert chunks == [ - MessageChunk(message_id="matrix:@alice:example.org:C1", delta="hel", finished=False, tokens_used=0), - MessageChunk(message_id="matrix:@alice:example.org:C1", delta="lo", finished=False, tokens_used=0), - MessageChunk(message_id="matrix:@alice:example.org:C1", delta="", finished=True, tokens_used=7), + MessageChunk(message_id=thread_key, delta="hel", finished=False, tokens_used=0), + MessageChunk(message_id=thread_key, delta="lo", finished=False, tokens_used=0), + MessageChunk(message_id=thread_key, delta="", finished=True, tokens_used=7), ] @pytest.mark.asyncio async def test_send_message_collects_streamed_chunks_and_tokens(aiohttp_server): + thread_key = build_thread_key("matrix", "@alice:example.org", "C1") + async def handler(request): ws = web.WebSocketResponse() await ws.prepare(request) - assert request.query["thread_id"] == "matrix:@alice:example.org:C1" + assert request.query["thread_id"] == thread_key await ws.send_json({"type": "STATUS"}) @@ -81,13 +134,60 @@ async def test_send_message_collects_streamed_chunks_and_tokens(aiohttp_server): client = AgentSessionClient(AgentSessionConfig(base_ws_url=str(server.make_url("/agent_ws/")))) result = await client.send_message( - thread_key="matrix:@alice:example.org:C1", + thread_key=thread_key, text="hello world", ) assert result == MessageResponse( - message_id="matrix:@alice:example.org:C1", + message_id=thread_key, response="hello world", tokens_used=11, finished=True, ) + + +@pytest.mark.asyncio +async def test_process_message_requires_thread_id_query_param(): + class FakeWebSocket: + query_params = {} + + async def send_text(self, text: str) -> None: + raise AssertionError(f"send_text should not be called: {text}") + + class FakeAgentService: + async def astream(self, text: str, thread_id: str): + yield text + + with pytest.raises(ValueError, match="thread_id query parameter is required"): + await process_message( + FakeWebSocket(), + MsgUserMessage(text="hello"), + FakeAgentService(), + ) + + +@pytest.mark.asyncio +async def test_process_message_passes_thread_id_to_agent_service(): + class FakeWebSocket: + def __init__(self) -> None: + self.query_params = {"thread_id": "6:matrix18:@alice:example.org2:C1"} + self.sent_messages: list[str] = [] + + async def send_text(self, text: str) -> None: + self.sent_messages.append(text) + + class FakeAgentService: + def __init__(self) -> None: + self.calls: list[tuple[str, str]] = [] + + async def astream(self, text: str, thread_id: str): + self.calls.append((text, thread_id)) + yield "hello" + + ws = FakeWebSocket() + agent_service = FakeAgentService() + await process_message(ws, MsgUserMessage(text="hello"), agent_service) + + assert agent_service.calls == [("hello", "6:matrix18:@alice:example.org2:C1")] + assert any("AGENT_EVENT_TEXT_CHUNK" in message for message in ws.sent_messages) + assert any("AGENT_EVENT_END" in message for message in ws.sent_messages) diff --git a/tests/platform/test_prototype_state.py b/tests/platform/test_prototype_state.py index c1a2d73..b5f5dc3 100644 --- a/tests/platform/test_prototype_state.py +++ b/tests/platform/test_prototype_state.py @@ -9,18 +9,12 @@ from sdk.prototype_state import PrototypeStateStore async def test_get_or_create_user_is_stable_per_surface_identity(): store = PrototypeStateStore() - first = await store.get_or_create_user( - external_id="@alice:example.org", - platform="matrix", - display_name="Alice", - ) - second = await store.get_or_create_user( - external_id="@alice:example.org", - platform="matrix", - ) + first = await store.get_or_create_user("@alice:example.org", "matrix", "Alice") + second = await store.get_or_create_user("@alice:example.org", "matrix") assert first.user_id == "usr-matrix-@alice:example.org" assert first.is_new is True + assert store._users["matrix:@alice:example.org"].is_new is False first.display_name = "Mallory" first.is_new = False @@ -56,6 +50,22 @@ async def test_settings_defaults_match_existing_mock_shape(): assert settings.plan == {"name": "Beta", "tokens_used": 0, "tokens_limit": 1000} +@pytest.mark.asyncio +async def test_get_settings_returns_connectors_copy(): + store = PrototypeStateStore() + store._settings["usr-matrix-@alice:example.org"] = { + "connectors": {"github": {"enabled": True}}, + } + + settings = await store.get_settings("usr-matrix-@alice:example.org") + settings.connectors["github"]["enabled"] = False + settings.connectors["slack"] = {"enabled": True} + + assert store._settings["usr-matrix-@alice:example.org"]["connectors"] == { + "github": {"enabled": True}, + } + + @pytest.mark.asyncio async def test_update_settings_supports_toggle_skill_and_setters(): store = PrototypeStateStore() diff --git a/tests/platform/test_real.py b/tests/platform/test_real.py index f10e2c0..7225cfd 100644 --- a/tests/platform/test_real.py +++ b/tests/platform/test_real.py @@ -1,6 +1,7 @@ import pytest from core.protocol import SettingsAction +from sdk.agent_session import build_thread_key from sdk.interface import MessageChunk, MessageResponse, UserSettings from sdk.prototype_state import PrototypeStateStore from sdk.real import RealPlatformClient @@ -31,13 +32,12 @@ async def test_real_platform_client_get_or_create_user_uses_local_state(): client = RealPlatformClient( agent_sessions=FakeAgentSessionClient(), prototype_state=PrototypeStateStore(), - platform="telegram", ) - first = await client.get_or_create_user("u1", "telegram", "Alice") - second = await client.get_or_create_user("u1", "telegram") + first = await client.get_or_create_user("u1", "matrix", "Alice") + second = await client.get_or_create_user("u1", "matrix") - assert first.user_id == "usr-telegram-u1" + assert first.user_id == "usr-matrix-u1" assert first.is_new is True assert second.user_id == first.user_id assert second.is_new is False @@ -45,57 +45,55 @@ async def test_real_platform_client_get_or_create_user_uses_local_state(): @pytest.mark.asyncio -async def test_real_platform_client_send_message_uses_configured_platform(): +async def test_real_platform_client_send_message_uses_surface_user_thread_identity(): agent_sessions = FakeAgentSessionClient() client = RealPlatformClient( agent_sessions=agent_sessions, prototype_state=PrototypeStateStore(), - platform="telegram", + platform="matrix", ) - result = await client.send_message("usr-telegram-u1", "C1", "hello") + thread_key = build_thread_key("matrix", "@alice:example.org", "C1") + result = await client.send_message("@alice:example.org", "C1", "hello") assert result == MessageResponse( - message_id="8:telegram15:usr-telegram-u12:C1", + message_id=thread_key, response="echo:hello", tokens_used=3, finished=True, ) - assert agent_sessions.send_calls == [ - ("8:telegram15:usr-telegram-u12:C1", "hello") - ] + assert agent_sessions.send_calls == [(thread_key, "hello")] @pytest.mark.asyncio -async def test_real_platform_client_stream_message_uses_configured_platform(): +async def test_real_platform_client_stream_message_uses_surface_user_thread_identity(): agent_sessions = FakeAgentSessionClient() client = RealPlatformClient( agent_sessions=agent_sessions, prototype_state=PrototypeStateStore(), - platform="telegram", + platform="matrix", ) + thread_key = build_thread_key("matrix", "@alice:example.org", "C1") chunks = [] - async for chunk in client.stream_message("usr-telegram-u1", "C1", "hello"): + async for chunk in client.stream_message("@alice:example.org", "C1", "hello"): chunks.append(chunk) assert chunks == [ MessageChunk( - message_id="8:telegram15:usr-telegram-u12:C1", + message_id=thread_key, delta="he", finished=False, tokens_used=0, ), MessageChunk( - message_id="8:telegram15:usr-telegram-u12:C1", + message_id=thread_key, delta="llo", finished=True, tokens_used=3, ), ] - assert agent_sessions.stream_calls == [ - ("8:telegram15:usr-telegram-u12:C1", "hello") - ] + assert agent_sessions.stream_calls == [(thread_key, "hello")] @pytest.mark.asyncio diff --git a/uv.lock b/uv.lock index 0c37403..35c8460 100644 --- a/uv.lock +++ b/uv.lock @@ -1095,6 +1095,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, ] +[[package]] +name = "pytest-aiohttp" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "pytest" }, + { name = "pytest-asyncio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/4b/d326890c153f2c4ce1bf45d07683c08c10a1766058a22934620bc6ac6592/pytest_aiohttp-1.1.0.tar.gz", hash = "sha256:147de8cb164f3fc9d7196967f109ab3c0b93ea3463ab50631e56438eab7b5adc", size = 12842, upload-time = "2025-01-23T12:44:04.465Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/0f/e6af71c02e0f1098eaf7d2dbf3ffdf0a69fc1e0ef174f96af05cef161f1b/pytest_aiohttp-1.1.0-py3-none-any.whl", hash = "sha256:f39a11693a0dce08dd6c542d241e199dd8047a6e6596b2bcfa60d373f143456d", size = 8932, upload-time = "2025-01-23T12:44:03.27Z" }, +] + [[package]] name = "pytest-asyncio" version = "1.3.0" @@ -1302,6 +1316,7 @@ version = "0.1.0" source = { editable = "." } dependencies = [ { name = "aiogram" }, + { name = "aiohttp" }, { name = "httpx" }, { name = "matrix-nio" }, { name = "pydantic" }, @@ -1313,6 +1328,7 @@ dependencies = [ dev = [ { name = "mypy" }, { name = "pytest" }, + { name = "pytest-aiohttp" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, { name = "ruff" }, @@ -1321,11 +1337,13 @@ dev = [ [package.metadata] requires-dist = [ { name = "aiogram", specifier = ">=3.4,<4" }, + { name = "aiohttp", specifier = ">=3.9" }, { name = "httpx", specifier = ">=0.27" }, { name = "matrix-nio", specifier = ">=0.21" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8" }, { name = "pydantic", specifier = ">=2.5" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0" }, + { name = "pytest-aiohttp", marker = "extra == 'dev'", specifier = ">=1.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.1" }, { name = "python-dotenv", specifier = ">=1.0" },