Add per-chat real client routing

This commit is contained in:
Mikhail Putilovskij 2026-04-19 17:03:48 +03:00
parent 5782001d3d
commit 414a8645bd
3 changed files with 138 additions and 18 deletions

View file

@ -3,6 +3,8 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import sys import sys
import re
from urllib.parse import urlsplit, urlunsplit
from pathlib import Path from pathlib import Path
import aiohttp import aiohttp
@ -26,10 +28,46 @@ logger = logging.getLogger(__name__)
class AgentApiWrapper(AgentApi): class AgentApiWrapper(AgentApi):
"""Capture tokens_used from MsgEventEnd without patching upstream code.""" """Capture tokens_used from MsgEventEnd without patching upstream code."""
def __init__(self, agent_id: str, url: str, **kwargs) -> None: def __init__(
super().__init__(agent_id=agent_id, url=url, **kwargs) self,
agent_id: str,
base_url: str | None = None,
*,
chat_id: int | str = 0,
url: str | None = None,
**kwargs,
) -> None:
if base_url is None and url is None:
raise TypeError("AgentApiWrapper requires base_url or url")
self._base_url = self._normalize_base_url(base_url or url or "")
self._init_kwargs = dict(kwargs)
self.chat_id = chat_id
super().__init__(
agent_id=agent_id,
url=self._build_ws_url(self._base_url, chat_id),
**kwargs,
)
self.last_tokens_used = 0 self.last_tokens_used = 0
@staticmethod
def _normalize_base_url(base_url: str) -> str:
parsed = urlsplit(base_url)
path = re.sub(r"(?:/v1)?/agent_ws(?:/[^/]+)?$", "", parsed.path.rstrip("/"))
return urlunsplit((parsed.scheme, parsed.netloc, path, "", ""))
@staticmethod
def _build_ws_url(base_url: str, chat_id: int | str) -> str:
return base_url.rstrip("/") + f"/v1/agent_ws/{chat_id}/"
def for_chat(self, chat_id: int | str) -> "AgentApiWrapper":
    """Build a new wrapper of the same class for another chat.

    The new instance reuses this wrapper's agent id, normalized base URL
    and the extra keyword options captured at construction; only
    ``chat_id`` differs. Nothing is connected here.
    """
    wrapper_cls = type(self)
    sibling = wrapper_cls(
        agent_id=self.id,
        base_url=self._base_url,
        chat_id=chat_id,
        **self._init_kwargs,
    )
    return sibling
async def _listen(self): async def _listen(self):
try: try:
async for msg in self._ws: async for msg in self._ws:

View file

@ -17,11 +17,21 @@ class RealPlatformClient(PlatformClient):
self._agent_api = agent_api self._agent_api = agent_api
self._prototype_state = prototype_state self._prototype_state = prototype_state
self._platform = platform self._platform = platform
self._chat_apis: dict[str, AgentApiWrapper] = {}
@property @property
def agent_api(self) -> AgentApiWrapper: def agent_api(self) -> AgentApiWrapper:
return self._agent_api return self._agent_api
async def _get_chat_api(self, chat_id: str) -> AgentApiWrapper:
chat_key = str(chat_id)
chat_api = self._chat_apis.get(chat_key)
if chat_api is None:
chat_api = self._agent_api.for_chat(chat_key)
await chat_api.connect()
self._chat_apis[chat_key] = chat_api
return chat_api
async def get_or_create_user( async def get_or_create_user(
self, self,
external_id: str, external_id: str,
@ -66,8 +76,9 @@ class RealPlatformClient(PlatformClient):
text: str, text: str,
attachments: list[Attachment] | None = None, attachments: list[Attachment] | None = None,
) -> AsyncIterator[MessageChunk]: ) -> AsyncIterator[MessageChunk]:
self._agent_api.last_tokens_used = 0 chat_api = await self._get_chat_api(chat_id)
async for event in self._agent_api.send_message(text): chat_api.last_tokens_used = 0
async for event in chat_api.send_message(text):
yield MessageChunk( yield MessageChunk(
message_id=user_id, message_id=user_id,
delta=event.text, delta=event.text,
@ -77,7 +88,7 @@ class RealPlatformClient(PlatformClient):
message_id=user_id, message_id=user_id,
delta="", delta="",
finished=True, finished=True,
tokens_used=self._agent_api.last_tokens_used, tokens_used=chat_api.last_tokens_used,
) )
async def get_settings(self, user_id: str) -> UserSettings: async def get_settings(self, user_id: str) -> UserSettings:
@ -85,3 +96,8 @@ class RealPlatformClient(PlatformClient):
async def update_settings(self, user_id: str, action) -> None: async def update_settings(self, user_id: str, action) -> None:
await self._prototype_state.update_settings(user_id, action) await self._prototype_state.update_settings(user_id, action)
async def close(self) -> None:
for chat_api in list(self._chat_apis.values()):
await chat_api.close()
self._chat_apis.clear()

View file

@ -6,22 +6,49 @@ from sdk.prototype_state import PrototypeStateStore
from sdk.real import RealPlatformClient from sdk.real import RealPlatformClient
class FakeChunk:
    """Minimal stand-in for a streamed event; carries only ``text``."""

    def __init__(self, text: str) -> None:
        # The only attribute the code under test reads from an event.
        self.text = text
class FakeChatAgentApi:
    """Test double for a chat-bound agent client that records how it is used."""

    def __init__(self, chat_id: str) -> None:
        self.chat_id = chat_id
        # Texts passed to send_message, in call order.
        self.calls: list[str] = []
        # Counters for the lifecycle methods.
        self.connect_calls = 0
        self.close_calls = 0
        self.last_tokens_used = 0

    async def connect(self) -> None:
        """Count the connection attempt; no I/O happens."""
        self.connect_calls += 1

    async def close(self) -> None:
        """Count the close request; no I/O happens."""
        self.close_calls += 1

    async def send_message(self, text: str):
        """Record ``text`` and stream it back as two chunks split at the midpoint.

        ``last_tokens_used`` becomes 3 once both chunks have been consumed.
        """
        self.calls.append(text)
        half = len(text) // 2
        for piece in (text[:half], text[half:]):
            yield FakeChunk(piece)
        self.last_tokens_used = 3
class FakeAgentApiFactory:
def __init__(self) -> None:
self.created_chat_ids: list[str] = []
self.instances: dict[str, FakeChatAgentApi] = {}
def for_chat(self, chat_id: str) -> FakeChatAgentApi:
chat_api = FakeChatAgentApi(chat_id)
self.created_chat_ids.append(chat_id)
self.instances[chat_id] = chat_api
return chat_api
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_real_platform_client_get_or_create_user_uses_local_state(): async def test_real_platform_client_get_or_create_user_uses_local_state():
client = RealPlatformClient( client = RealPlatformClient(
agent_api=FakeAgentApi(), agent_api=FakeAgentApiFactory(),
prototype_state=PrototypeStateStore(), prototype_state=PrototypeStateStore(),
) )
@ -36,15 +63,15 @@ async def test_real_platform_client_get_or_create_user_uses_local_state():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_real_platform_client_send_message_collects_stream_output(): async def test_real_platform_client_send_message_uses_chat_bound_client():
agent_api = FakeAgentApi() agent_api = FakeAgentApiFactory()
client = RealPlatformClient( client = RealPlatformClient(
agent_api=agent_api, agent_api=agent_api,
prototype_state=PrototypeStateStore(), prototype_state=PrototypeStateStore(),
platform="matrix", platform="matrix",
) )
result = await client.send_message("@alice:example.org", "C1", "hello") result = await client.send_message("@alice:example.org", "chat-7", "hello")
assert result == MessageResponse( assert result == MessageResponse(
message_id="@alice:example.org", message_id="@alice:example.org",
@ -52,12 +79,50 @@ async def test_real_platform_client_send_message_collects_stream_output():
tokens_used=3, tokens_used=3,
finished=True, finished=True,
) )
assert agent_api.calls == ["hello"] assert agent_api.created_chat_ids == ["chat-7"]
assert agent_api.instances["chat-7"].chat_id == "chat-7"
assert agent_api.instances["chat-7"].calls == ["hello"]
assert agent_api.instances["chat-7"].connect_calls == 1
@pytest.mark.asyncio
async def test_real_platform_client_reuses_cached_chat_client():
    """Two messages to the same chat must share one connected client."""
    factory = FakeAgentApiFactory()
    platform_client = RealPlatformClient(
        agent_api=factory,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    for text in ("hello", "again"):
        await platform_client.send_message("@alice:example.org", "chat-1", text)

    # One factory call and one connect, yet both messages were delivered.
    assert factory.created_chat_ids == ["chat-1"]
    chat_api = factory.instances["chat-1"]
    assert chat_api.calls == ["hello", "again"]
    assert chat_api.connect_calls == 1
@pytest.mark.asyncio
async def test_real_platform_client_creates_distinct_clients_per_chat():
    """Messages to different chats must go through separate clients."""
    factory = FakeAgentApiFactory()
    platform_client = RealPlatformClient(
        agent_api=factory,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    await platform_client.send_message("@alice:example.org", "chat-1", "hello")
    await platform_client.send_message("@alice:example.org", "chat-2", "world")

    assert factory.created_chat_ids == ["chat-1", "chat-2"]
    first_api = factory.instances["chat-1"]
    second_api = factory.instances["chat-2"]
    assert first_api is not second_api
    # Each client saw only the message addressed to its own chat.
    assert first_api.calls == ["hello"]
    assert second_api.calls == ["world"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_real_platform_client_stream_message_emits_final_tokens_chunk(): async def test_real_platform_client_stream_message_emits_final_tokens_chunk():
agent_api = FakeAgentApi() agent_api = FakeAgentApiFactory()
client = RealPlatformClient( client = RealPlatformClient(
agent_api=agent_api, agent_api=agent_api,
prototype_state=PrototypeStateStore(), prototype_state=PrototypeStateStore(),
@ -65,7 +130,7 @@ async def test_real_platform_client_stream_message_emits_final_tokens_chunk():
) )
chunks = [] chunks = []
async for chunk in client.stream_message("@alice:example.org", "C1", "hello"): async for chunk in client.stream_message("@alice:example.org", "chat-1", "hello"):
chunks.append(chunk) chunks.append(chunk)
assert chunks == [ assert chunks == [
@ -88,13 +153,14 @@ async def test_real_platform_client_stream_message_emits_final_tokens_chunk():
tokens_used=3, tokens_used=3,
), ),
] ]
assert agent_api.calls == ["hello"] assert agent_api.created_chat_ids == ["chat-1"]
assert agent_api.instances["chat-1"].calls == ["hello"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_real_platform_client_settings_are_local(): async def test_real_platform_client_settings_are_local():
client = RealPlatformClient( client = RealPlatformClient(
agent_api=FakeAgentApi(), agent_api=FakeAgentApiFactory(),
prototype_state=PrototypeStateStore(), prototype_state=PrototypeStateStore(),
platform="matrix", platform="matrix",
) )