From 414a8645bdb4f90d05a120168811b676d605a35f Mon Sep 17 00:00:00 2001 From: Mikhail Putilovskij Date: Sun, 19 Apr 2026 17:03:48 +0300 Subject: [PATCH] Add per-chat real client routing --- sdk/agent_api_wrapper.py | 42 ++++++++++++++++- sdk/real.py | 22 +++++++-- tests/platform/test_real.py | 92 +++++++++++++++++++++++++++++++------ 3 files changed, 138 insertions(+), 18 deletions(-) diff --git a/sdk/agent_api_wrapper.py b/sdk/agent_api_wrapper.py index 206f6c3..bf21f09 100644 --- a/sdk/agent_api_wrapper.py +++ b/sdk/agent_api_wrapper.py @@ -3,6 +3,8 @@ from __future__ import annotations import asyncio import logging import sys +import re +from urllib.parse import urlsplit, urlunsplit from pathlib import Path import aiohttp @@ -26,10 +28,46 @@ logger = logging.getLogger(__name__) class AgentApiWrapper(AgentApi): """Capture tokens_used from MsgEventEnd without patching upstream code.""" - def __init__(self, agent_id: str, url: str, **kwargs) -> None: - super().__init__(agent_id=agent_id, url=url, **kwargs) + def __init__( + self, + agent_id: str, + base_url: str | None = None, + *, + chat_id: int | str = 0, + url: str | None = None, + **kwargs, + ) -> None: + if base_url is None and url is None: + raise TypeError("AgentApiWrapper requires base_url or url") + + self._base_url = self._normalize_base_url(base_url or url or "") + self._init_kwargs = dict(kwargs) + self.chat_id = chat_id + super().__init__( + agent_id=agent_id, + url=self._build_ws_url(self._base_url, chat_id), + **kwargs, + ) self.last_tokens_used = 0 + @staticmethod + def _normalize_base_url(base_url: str) -> str: + parsed = urlsplit(base_url) + path = re.sub(r"(?:/v1)?/agent_ws(?:/[^/]+)?$", "", parsed.path.rstrip("/")) + return urlunsplit((parsed.scheme, parsed.netloc, path, "", "")) + + @staticmethod + def _build_ws_url(base_url: str, chat_id: int | str) -> str: + return base_url.rstrip("/") + f"/v1/agent_ws/{chat_id}/" + + def for_chat(self, chat_id: int | str) -> "AgentApiWrapper": + return type(self)( + agent_id=self.id, + base_url=self._base_url, + chat_id=chat_id, + **self._init_kwargs, + ) + async def _listen(self): try: async for msg in self._ws: diff --git a/sdk/real.py b/sdk/real.py index 4492b46..16b62a4 100644 --- a/sdk/real.py +++ b/sdk/real.py @@ -17,11 +17,21 @@ class RealPlatformClient(PlatformClient): self._agent_api = agent_api self._prototype_state = prototype_state self._platform = platform + self._chat_apis: dict[str, AgentApiWrapper] = {} @property def agent_api(self) -> AgentApiWrapper: return self._agent_api + async def _get_chat_api(self, chat_id: str) -> AgentApiWrapper: + chat_key = str(chat_id) + chat_api = self._chat_apis.get(chat_key) + if chat_api is None: + chat_api = self._agent_api.for_chat(chat_key) + await chat_api.connect() + self._chat_apis[chat_key] = chat_api + return chat_api + async def get_or_create_user( self, external_id: str, @@ -66,8 +76,9 @@ class RealPlatformClient(PlatformClient): text: str, attachments: list[Attachment] | None = None, ) -> AsyncIterator[MessageChunk]: - self._agent_api.last_tokens_used = 0 - async for event in self._agent_api.send_message(text): + chat_api = await self._get_chat_api(chat_id) + chat_api.last_tokens_used = 0 + async for event in chat_api.send_message(text): yield MessageChunk( message_id=user_id, delta=event.text, @@ -77,7 +88,7 @@ class RealPlatformClient(PlatformClient): message_id=user_id, delta="", finished=True, - tokens_used=self._agent_api.last_tokens_used, + tokens_used=chat_api.last_tokens_used, ) async def get_settings(self, user_id: str) -> UserSettings: @@ -85,3 +96,8 @@ class RealPlatformClient(PlatformClient): async def update_settings(self, user_id: str, action) -> None: await self._prototype_state.update_settings(user_id, action) + + async def close(self) -> None: + for chat_api in list(self._chat_apis.values()): + await chat_api.close() + self._chat_apis.clear() diff --git a/tests/platform/test_real.py b/tests/platform/test_real.py index 1255888..1097937 100644 --- a/tests/platform/test_real.py +++ b/tests/platform/test_real.py @@ -6,22 +6,49 @@ from sdk.prototype_state import PrototypeStateStore from sdk.real import RealPlatformClient -class FakeAgentApi: - def __init__(self) -> None: +class FakeChunk: + def __init__(self, text: str) -> None: + self.text = text + + +class FakeChatAgentApi: + def __init__(self, chat_id: str) -> None: + self.chat_id = chat_id self.calls: list[str] = [] + self.connect_calls = 0 + self.close_calls = 0 self.last_tokens_used = 0 + async def connect(self) -> None: + self.connect_calls += 1 + + async def close(self) -> None: + self.close_calls += 1 + async def send_message(self, text: str): self.calls.append(text) - yield type("Chunk", (), {"text": text[:2]})() - yield type("Chunk", (), {"text": text[2:]})() + midpoint = len(text) // 2 + yield FakeChunk(text[:midpoint]) + yield FakeChunk(text[midpoint:]) self.last_tokens_used = 3 +class FakeAgentApiFactory: + def __init__(self) -> None: + self.created_chat_ids: list[str] = [] + self.instances: dict[str, FakeChatAgentApi] = {} + + def for_chat(self, chat_id: str) -> FakeChatAgentApi: + chat_api = FakeChatAgentApi(chat_id) + self.created_chat_ids.append(chat_id) + self.instances[chat_id] = chat_api + return chat_api + + @pytest.mark.asyncio async def test_real_platform_client_get_or_create_user_uses_local_state(): client = RealPlatformClient( - agent_api=FakeAgentApi(), + agent_api=FakeAgentApiFactory(), prototype_state=PrototypeStateStore(), ) @@ -36,15 +63,15 @@ async def test_real_platform_client_get_or_create_user_uses_local_state(): @pytest.mark.asyncio -async def test_real_platform_client_send_message_collects_stream_output(): - agent_api = FakeAgentApi() +async def test_real_platform_client_send_message_uses_chat_bound_client(): + agent_api = FakeAgentApiFactory() client = RealPlatformClient( agent_api=agent_api, prototype_state=PrototypeStateStore(), platform="matrix", ) - result = await client.send_message("@alice:example.org", "C1", "hello") + result = await client.send_message("@alice:example.org", "chat-7", "hello") assert result == MessageResponse( message_id="@alice:example.org", @@ -52,12 +79,50 @@ async def test_real_platform_client_send_message_collects_stream_output(): tokens_used=3, finished=True, ) - assert agent_api.calls == ["hello"] + assert agent_api.created_chat_ids == ["chat-7"] + assert agent_api.instances["chat-7"].chat_id == "chat-7" + assert agent_api.instances["chat-7"].calls == ["hello"] + assert agent_api.instances["chat-7"].connect_calls == 1 + + +@pytest.mark.asyncio +async def test_real_platform_client_reuses_cached_chat_client(): + agent_api = FakeAgentApiFactory() + client = RealPlatformClient( + agent_api=agent_api, + prototype_state=PrototypeStateStore(), + platform="matrix", + ) + + await client.send_message("@alice:example.org", "chat-1", "hello") + await client.send_message("@alice:example.org", "chat-1", "again") + + assert agent_api.created_chat_ids == ["chat-1"] + assert agent_api.instances["chat-1"].calls == ["hello", "again"] + assert agent_api.instances["chat-1"].connect_calls == 1 + + +@pytest.mark.asyncio +async def test_real_platform_client_creates_distinct_clients_per_chat(): + agent_api = FakeAgentApiFactory() + client = RealPlatformClient( + agent_api=agent_api, + prototype_state=PrototypeStateStore(), + platform="matrix", + ) + + await client.send_message("@alice:example.org", "chat-1", "hello") + await client.send_message("@alice:example.org", "chat-2", "world") + + assert agent_api.created_chat_ids == ["chat-1", "chat-2"] + assert agent_api.instances["chat-1"] is not agent_api.instances["chat-2"] + assert agent_api.instances["chat-1"].calls == ["hello"] + assert agent_api.instances["chat-2"].calls == ["world"] @pytest.mark.asyncio async def test_real_platform_client_stream_message_emits_final_tokens_chunk(): - agent_api = FakeAgentApi() + agent_api = FakeAgentApiFactory() client = RealPlatformClient( agent_api=agent_api, prototype_state=PrototypeStateStore(), @@ -65,7 +130,7 @@ async def test_real_platform_client_stream_message_emits_final_tokens_chunk(): ) chunks = [] - async for chunk in client.stream_message("@alice:example.org", "C1", "hello"): + async for chunk in client.stream_message("@alice:example.org", "chat-1", "hello"): chunks.append(chunk) assert chunks == [ @@ -88,13 +153,14 @@ async def test_real_platform_client_stream_message_emits_final_tokens_chunk(): tokens_used=3, ), ] - assert agent_api.calls == ["hello"] + assert agent_api.created_chat_ids == ["chat-1"] + assert agent_api.instances["chat-1"].calls == ["hello"] @pytest.mark.asyncio async def test_real_platform_client_settings_are_local(): client = RealPlatformClient( - agent_api=FakeAgentApi(), + agent_api=FakeAgentApiFactory(), prototype_state=PrototypeStateStore(), platform="matrix", )