from __future__ import annotations

import asyncio
from typing import AsyncIterator

from sdk.agent_api_wrapper import AgentApiWrapper
from sdk.interface import (
    Attachment,
    MessageChunk,
    MessageResponse,
    PlatformClient,
    User,
    UserSettings,
)
from sdk.prototype_state import PrototypeStateStore


class RealPlatformClient(PlatformClient):
    """PlatformClient backed by an AgentApiWrapper and a PrototypeStateStore.

    User and settings operations are delegated to the prototype state store.
    Messages go through the agent API: when the agent API exposes a callable
    ``for_chat`` factory, one connected client is created and cached per chat
    id; otherwise the shared agent API instance is used for every chat.
    """

    def __init__(
        self,
        agent_api: AgentApiWrapper,
        prototype_state: PrototypeStateStore,
        platform: str = "matrix",
    ) -> None:
        self._agent_api = agent_api
        self._prototype_state = prototype_state
        # Stored for callers/subclasses; not read by any method of this class.
        self._platform = platform
        # Per-chat clients produced by ``agent_api.for_chat``, keyed by str(chat_id).
        self._chat_apis: dict[str, AgentApiWrapper] = {}
        # Guards creation + connect() of per-chat clients so each chat connects once.
        self._chat_api_lock = asyncio.Lock()
        # One lock per chat so sends to the same chat run one at a time.
        self._chat_send_locks: dict[str, asyncio.Lock] = {}

    @property
    def agent_api(self) -> AgentApiWrapper:
        """The shared agent API wrapper passed to the constructor."""
        return self._agent_api

    async def _get_chat_api(self, chat_id: str) -> AgentApiWrapper:
        """Return the agent API client to use for ``chat_id``.

        If the shared agent API has no callable ``for_chat`` attribute, the
        shared instance itself is returned (and is not cached). Otherwise a
        per-chat client is created via ``for_chat``, connected, cached, and
        reused on later calls.
        """
        chat_key = str(chat_id)
        chat_api = self._chat_apis.get(chat_key)
        if chat_api is not None:
            return chat_api

        chat_api_factory = getattr(self._agent_api, "for_chat", None)
        if not callable(chat_api_factory):
            return self._agent_api

        # Re-check under the lock: another task may have created the client
        # while this one was waiting. The lock is shared by all chats, so
        # connect() calls for different chats are serialized.
        async with self._chat_api_lock:
            chat_api = self._chat_apis.get(chat_key)
            if chat_api is None:
                chat_api = chat_api_factory(chat_key)
                await chat_api.connect()
                # Cache only after connect() succeeds so a failed connect is retried.
                self._chat_apis[chat_key] = chat_api
        return chat_api

    def _get_chat_send_lock(self, chat_id: str) -> asyncio.Lock:
        """Return the send lock for ``chat_id``, creating it on first use."""
        chat_key = str(chat_id)
        lock = self._chat_send_locks.get(chat_key)
        if lock is None:
            # No await between the lookup and the insert, so this cannot race
            # with another task on the same event loop.
            lock = asyncio.Lock()
            self._chat_send_locks[chat_key] = lock
        return lock

    async def get_or_create_user(
        self,
        external_id: str,
        platform: str,
        display_name: str | None = None,
    ) -> User:
        """Delegate user lookup/creation to the prototype state store."""
        return await self._prototype_state.get_or_create_user(
            external_id=external_id,
            platform=platform,
            display_name=display_name,
        )

    async def send_message(
        self,
        user_id: str,
        chat_id: str,
        text: str,
        attachments: list[Attachment] | None = None,
    ) -> MessageResponse:
        """Send ``text`` and collect the whole streamed reply into one response.

        Concatenates every non-empty chunk delta; ``tokens_used`` is taken from
        the final (``finished``) chunk and ``message_id`` from the last chunk.
        """
        response_parts: list[str] = []
        tokens_used = 0
        message_id = user_id
        async for chunk in self.stream_message(user_id, chat_id, text, attachments=attachments):
            message_id = chunk.message_id
            if chunk.delta:
                response_parts.append(chunk.delta)
            if chunk.finished:
                tokens_used = chunk.tokens_used
        return MessageResponse(
            message_id=message_id,
            response="".join(response_parts),
            tokens_used=tokens_used,
            finished=True,
        )

    async def stream_message(
        self,
        user_id: str,
        chat_id: str,
        text: str,
        attachments: list[Attachment] | None = None,
    ) -> AsyncIterator[MessageChunk]:
        """Stream the agent's reply to ``text`` as MessageChunk objects.

        Yields one unfinished chunk per event from the chat's agent API, then a
        final chunk with ``finished=True`` and the token count, which is also
        recorded in the prototype state store for this chat.

        NOTE(review): ``attachments`` is accepted but not forwarded to the
        agent API. ``message_id`` on every chunk is the ``user_id`` argument.
        """
        lock = self._get_chat_send_lock(chat_id)
        # The per-chat lock is held across every yield below. If a consumer
        # stops iterating early, the lock stays held until this generator is
        # closed (aclose() or finalization).
        async with lock:
            chat_api = await self._get_chat_api(chat_id)
            # Reset the counter so the value read after streaming belongs to this send.
            if hasattr(chat_api, "last_tokens_used"):
                chat_api.last_tokens_used = 0
            async for event in chat_api.send_message(text):
                yield MessageChunk(
                    message_id=user_id,
                    delta=event.text,
                    finished=False,
                )
            tokens_used = getattr(chat_api, "last_tokens_used", 0)
            await self._prototype_state.set_last_tokens_used(str(chat_id), tokens_used)
            yield MessageChunk(
                message_id=user_id,
                delta="",
                finished=True,
                tokens_used=tokens_used,
            )

    async def get_settings(self, user_id: str) -> UserSettings:
        """Delegate settings lookup to the prototype state store."""
        return await self._prototype_state.get_settings(user_id)

    async def update_settings(self, user_id: str, action) -> None:
        """Delegate a settings update ``action`` to the prototype state store."""
        await self._prototype_state.update_settings(user_id, action)

    async def close(self) -> None:
        """Close all cached per-chat clients and, if it was used directly, the shared one.

        Every client is closed even if an earlier ``close()`` raises; the
        cached clients and send locks are always cleared. The first exception
        raised by any ``close()`` is re-raised once cleanup has finished.
        """
        # Detach cached state first so it is cleared even if a close() fails.
        chat_apis = list(self._chat_apis.values())
        self._chat_apis.clear()
        self._chat_send_locks.clear()

        closers = []
        for chat_api in chat_apis:
            close = getattr(chat_api, "close", None)
            if callable(close):
                closers.append(close)
        # Without a ``for_chat`` factory the shared agent API is handed out
        # directly (never cached above), so it has to be closed here.
        if not callable(getattr(self._agent_api, "for_chat", None)):
            close = getattr(self._agent_api, "close", None)
            if callable(close):
                closers.append(close)

        first_error: Exception | None = None
        for close in closers:
            try:
                await close()
            except Exception as exc:  # keep closing the rest; report afterwards
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error