diff --git a/sdk/real.py b/sdk/real.py index 16b62a4..8e2dba6 100644 --- a/sdk/real.py +++ b/sdk/real.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from typing import AsyncIterator from sdk.agent_api_wrapper import AgentApiWrapper @@ -18,18 +19,26 @@ class RealPlatformClient(PlatformClient): self._prototype_state = prototype_state self._platform = platform self._chat_apis: dict[str, AgentApiWrapper] = {} + self._chat_api_lock = asyncio.Lock() @property def agent_api(self) -> AgentApiWrapper: return self._agent_api - async def _get_chat_api(self, chat_id: str) -> AgentApiWrapper: + async def _get_chat_api(self, chat_id: str): chat_key = str(chat_id) chat_api = self._chat_apis.get(chat_key) if chat_api is None: - chat_api = self._agent_api.for_chat(chat_key) - await chat_api.connect() - self._chat_apis[chat_key] = chat_api + chat_api_factory = getattr(self._agent_api, "for_chat", None) + if not callable(chat_api_factory): + return self._agent_api + + async with self._chat_api_lock: + chat_api = self._chat_apis.get(chat_key) + if chat_api is None: + chat_api = chat_api_factory(chat_key) + await chat_api.connect() + self._chat_apis[chat_key] = chat_api return chat_api async def get_or_create_user( @@ -77,7 +86,8 @@ class RealPlatformClient(PlatformClient): attachments: list[Attachment] | None = None, ) -> AsyncIterator[MessageChunk]: chat_api = await self._get_chat_api(chat_id) - chat_api.last_tokens_used = 0 + if hasattr(chat_api, "last_tokens_used"): + chat_api.last_tokens_used = 0 async for event in chat_api.send_message(text): yield MessageChunk( message_id=user_id, @@ -88,7 +98,7 @@ class RealPlatformClient(PlatformClient): message_id=user_id, delta="", finished=True, - tokens_used=chat_api.last_tokens_used, + tokens_used=getattr(chat_api, "last_tokens_used", 0), ) async def get_settings(self, user_id: str) -> UserSettings: @@ -99,5 +109,11 @@ class RealPlatformClient(PlatformClient): async def close(self) -> None: for chat_api in list(self._chat_apis.values()): - await chat_api.close() + close = getattr(chat_api, "close", None) + if callable(close): + await close() self._chat_apis.clear() + if not callable(getattr(self._agent_api, "for_chat", None)): + close = getattr(self._agent_api, "close", None) + if callable(close): + await close() diff --git a/tests/platform/test_real.py b/tests/platform/test_real.py index 1097937..94b9520 100644 --- a/tests/platform/test_real.py +++ b/tests/platform/test_real.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from core.protocol import SettingsAction @@ -45,6 +47,18 @@ class FakeAgentApiFactory: return chat_api +class LegacyAgentApi: + def __init__(self) -> None: + self.calls: list[str] = [] + self.last_tokens_used = 0 + + async def send_message(self, text: str): + self.calls.append(text) + yield FakeChunk(text[:2]) + yield FakeChunk(text[2:]) + self.last_tokens_used = 7 + + @pytest.mark.asyncio async def test_real_platform_client_get_or_create_user_uses_local_state(): client = RealPlatformClient( @@ -85,6 +99,26 @@ async def test_real_platform_client_send_message_uses_chat_bound_client(): assert agent_api.instances["chat-7"].connect_calls == 1 +@pytest.mark.asyncio +async def test_real_platform_client_works_with_legacy_agent_api_without_for_chat(): + legacy_api = LegacyAgentApi() + client = RealPlatformClient( + agent_api=legacy_api, + prototype_state=PrototypeStateStore(), + platform="matrix", + ) + + result = await client.send_message("@alice:example.org", "chat-legacy", "hello") + + assert result == MessageResponse( + message_id="@alice:example.org", + response="hello", + tokens_used=7, + finished=True, + ) + assert legacy_api.calls == ["hello"] + + @pytest.mark.asyncio async def test_real_platform_client_reuses_cached_chat_client(): agent_api = FakeAgentApiFactory() @@ -102,6 +136,26 @@ async def test_real_platform_client_reuses_cached_chat_client(): assert agent_api.instances["chat-1"].connect_calls == 1 +@pytest.mark.asyncio +async def test_real_platform_client_creates_chat_client_atomically_for_concurrent_requests(): + agent_api = FakeAgentApiFactory() + client = RealPlatformClient( + agent_api=agent_api, + prototype_state=PrototypeStateStore(), + platform="matrix", + ) + + results = await asyncio.gather( + client.send_message("@alice:example.org", "chat-1", "hello"), + client.send_message("@alice:example.org", "chat-1", "again"), + ) + + assert [result.response for result in results] == ["hello", "again"] + assert agent_api.created_chat_ids == ["chat-1"] + assert agent_api.instances["chat-1"].connect_calls == 1 + assert agent_api.instances["chat-1"].calls == ["hello", "again"] + + @pytest.mark.asyncio async def test_real_platform_client_creates_distinct_clients_per_chat(): agent_api = FakeAgentApiFactory()