from __future__ import annotations

import asyncio
import inspect
from collections.abc import AsyncIterator
from pathlib import Path

from sdk.agent_api_wrapper import AgentApiWrapper
from sdk.interface import (
    Attachment,
    MessageChunk,
    MessageResponse,
    PlatformClient,
    PlatformError,
    User,
    UserSettings,
)
from sdk.prototype_state import PrototypeStateStore


class RealPlatformClient(PlatformClient):
    """``PlatformClient`` backed by an ``AgentApiWrapper``.

    Users, settings and per-chat token counts are delegated to a
    ``PrototypeStateStore``. When the wrapper exposes a callable ``for_chat``,
    each chat gets its own connected API instance (cached in ``_chat_apis``);
    otherwise every chat shares the wrapper itself. Sends to the same chat are
    serialized with a per-chat ``asyncio.Lock``.
    """

    def __init__(
        self,
        agent_api: AgentApiWrapper,
        prototype_state: PrototypeStateStore,
        platform: str = "matrix",
    ) -> None:
        self._agent_api = agent_api
        self._prototype_state = prototype_state
        self._platform = platform
        # chat_id -> per-chat API created via ``agent_api.for_chat``.
        self._chat_apis: dict[str, AgentApiWrapper] = {}
        # Guards creation/connection of per-chat APIs.
        self._chat_api_lock = asyncio.Lock()
        # chat_id -> lock serializing sends to that chat.
        self._chat_send_locks: dict[str, asyncio.Lock] = {}

    @property
    def agent_api(self) -> AgentApiWrapper:
        """The shared agent API wrapper this client was built with."""
        return self._agent_api

    async def _get_chat_api(self, chat_id: str):
        """Return the agent API to use for *chat_id*.

        If the wrapper has a callable ``for_chat``, a per-chat instance is
        created, connected and cached on first use. Otherwise the shared
        wrapper is returned unchanged.
        """
        chat_key = str(chat_id)
        chat_api = self._chat_apis.get(chat_key)
        if chat_api is not None:
            return chat_api
        chat_api_factory = getattr(self._agent_api, "for_chat", None)
        if not callable(chat_api_factory):
            return self._agent_api
        async with self._chat_api_lock:
            # Re-check under the lock: another coroutine may have created the
            # instance while this one was waiting.
            chat_api = self._chat_apis.get(chat_key)
            if chat_api is None:
                chat_api = chat_api_factory(chat_key)
                await chat_api.connect()
                # Cache only after a successful connect.
                self._chat_apis[chat_key] = chat_api
        return chat_api

    def _get_chat_send_lock(self, chat_id: str) -> asyncio.Lock:
        """Return (creating if needed) the lock that serializes sends to *chat_id*."""
        chat_key = str(chat_id)
        lock = self._chat_send_locks.get(chat_key)
        if lock is None:
            # No await between lookup and insert, so this cannot interleave
            # with another coroutine on the same event loop.
            lock = asyncio.Lock()
            self._chat_send_locks[chat_key] = lock
        return lock

    async def get_or_create_user(
        self,
        external_id: str,
        platform: str,
        display_name: str | None = None,
    ) -> User:
        """Delegate user lookup/creation to the prototype state store."""
        return await self._prototype_state.get_or_create_user(
            external_id=external_id,
            platform=platform,
            display_name=display_name,
        )

    async def send_message(
        self,
        user_id: str,
        chat_id: str,
        text: str,
        attachments: list[Attachment] | None = None,
    ) -> MessageResponse:
        """Send *text* to the chat's agent and return the aggregated reply.

        Non-empty text chunks are concatenated into ``response``. Files
        announced by SEND_FILE events are returned as ``attachments`` when
        ``MessageResponse`` declares that field. The token count comes from
        the END event, or from the wrapper's ``last_tokens_used`` if no END
        event arrived, and is stored in the prototype state for the chat.

        Raises:
            PlatformError: if streaming from the agent fails. The chat's
                cached connection is dropped first.
        """
        response_parts: list[str] = []
        sent_attachments: list[Attachment] = []
        tokens_used = 0
        saw_end_event = False

        async with self._get_chat_send_lock(chat_id):
            chat_api = await self._get_chat_api(chat_id)
            if hasattr(chat_api, "last_tokens_used"):
                # Reset before streaming so the no-END fallback below does not
                # pick up a stale value from an earlier send.
                chat_api.last_tokens_used = 0
            try:
                async for event in self._stream_agent_events(
                    chat_api, text, attachments=attachments
                ):
                    if self._is_text_event(event):
                        chunk_text = getattr(event, "text", "")
                        if chunk_text:
                            response_parts.append(chunk_text)
                    elif self._is_end_event(event):
                        tokens_used = getattr(event, "tokens_used", tokens_used)
                        saw_end_event = True
                    elif self._is_send_file_event(event):
                        attachment = self._attachment_from_send_file_event(event)
                        if attachment is not None:
                            sent_attachments.append(attachment)
            except Exception as exc:
                await self._handle_chat_api_failure(chat_id, exc)
            if not saw_end_event:
                tokens_used = getattr(chat_api, "last_tokens_used", tokens_used)
            await self._prototype_state.set_last_tokens_used(str(chat_id), tokens_used)

        response_kwargs = {
            # NOTE(review): message_id is the caller's user_id, not an id
            # issued by the agent -- confirm this is intended.
            "message_id": user_id,
            "response": "".join(response_parts),
            "tokens_used": tokens_used,
            "finished": True,
        }
        if self._message_response_accepts_attachments():
            response_kwargs["attachments"] = sent_attachments
        return MessageResponse(**response_kwargs)

    async def stream_message(
        self,
        user_id: str,
        chat_id: str,
        text: str,
        attachments: list[Attachment] | None = None,
    ) -> AsyncIterator[MessageChunk]:
        """Send *text* and yield the agent's reply incrementally.

        Yields one ``finished=False`` chunk per text event, then a single
        ``finished=True`` chunk carrying the token count. The count is also
        stored in the prototype state. The chat's send lock is held for the
        whole iteration, including while the consumer handles yielded chunks.

        Raises:
            PlatformError: if streaming from the agent fails.
        """
        async with self._get_chat_send_lock(chat_id):
            chat_api = await self._get_chat_api(chat_id)
            if hasattr(chat_api, "last_tokens_used"):
                chat_api.last_tokens_used = 0
            saw_end_event = False
            try:
                async for event in self._stream_agent_events(
                    chat_api, text, attachments=attachments
                ):
                    if self._is_text_event(event):
                        yield MessageChunk(
                            message_id=user_id,
                            delta=getattr(event, "text", ""),
                            finished=False,
                        )
                    elif self._is_end_event(event):
                        tokens_used = getattr(event, "tokens_used", 0)
                        saw_end_event = True
                        await self._prototype_state.set_last_tokens_used(
                            str(chat_id), tokens_used
                        )
                        yield MessageChunk(
                            message_id=user_id,
                            delta="",
                            finished=True,
                            tokens_used=tokens_used,
                        )
                    # Any other event (including SEND_FILE) is not surfaced
                    # by the streaming API.
            except Exception as exc:
                await self._handle_chat_api_failure(chat_id, exc)
            if not saw_end_event:
                # The stream ended without an END event: still emit a final
                # chunk, using the wrapper's own token count if it has one.
                tokens_used = getattr(chat_api, "last_tokens_used", 0)
                await self._prototype_state.set_last_tokens_used(
                    str(chat_id), tokens_used
                )
                yield MessageChunk(
                    message_id=user_id,
                    delta="",
                    finished=True,
                    tokens_used=tokens_used,
                )

    async def get_settings(self, user_id: str) -> UserSettings:
        """Return the user's settings from the prototype state store."""
        return await self._prototype_state.get_settings(user_id)

    async def update_settings(self, user_id: str, action) -> None:
        """Apply *action* to the user's settings in the prototype state store."""
        await self._prototype_state.update_settings(user_id, action)

    async def disconnect_chat(self, chat_id: str) -> None:
        """Forget and close the cached per-chat API for *chat_id*, if any.

        The chat's send lock is forgotten only when nobody holds it. When
        this runs from a send's failure path, the lock is still held, and
        dropping it would let the next caller create a fresh lock and run
        concurrently with coroutines already queued on the old one.
        """
        chat_key = str(chat_id)
        chat_api = self._chat_apis.pop(chat_key, None)
        send_lock = self._chat_send_locks.get(chat_key)
        if send_lock is not None and not send_lock.locked():
            del self._chat_send_locks[chat_key]
        if chat_api is not None:
            await self._close_api(chat_api)

    async def close(self) -> None:
        """Close all cached per-chat APIs and, if applicable, the shared wrapper.

        The shared wrapper is closed only when it has no callable
        ``for_chat``, i.e. when it was used directly for every chat. Every
        close is attempted even if an earlier one fails; the first failure is
        re-raised once all have been tried.
        """
        apis_to_close = list(self._chat_apis.values())
        self._chat_apis.clear()
        self._chat_send_locks.clear()
        if not callable(getattr(self._agent_api, "for_chat", None)):
            apis_to_close.append(self._agent_api)

        first_error: Exception | None = None
        for api in apis_to_close:
            try:
                await self._close_api(api)
            except Exception as exc:
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error

    @staticmethod
    async def _close_api(api: object) -> None:
        """Await ``api.close()`` if the object has a callable ``close``."""
        close = getattr(api, "close", None)
        if callable(close):
            await close()

    async def _stream_agent_events(
        self,
        chat_api,
        text: str,
        attachments: list[Attachment] | None = None,
    ) -> AsyncIterator[object]:
        """Yield raw events from ``chat_api.send_message``.

        Attachment workspace paths are forwarded only when the wrapper's
        ``send_message`` accepts an ``attachments`` keyword (or ``**kwargs``).
        """
        send_message = chat_api.send_message
        attachment_paths = self._attachment_paths(attachments)
        if attachment_paths and self._send_message_accepts_attachments(send_message):
            event_stream = send_message(text, attachments=attachment_paths)
        else:
            event_stream = send_message(text)
        async for event in event_stream:
            yield event

    async def _handle_chat_api_failure(self, chat_id: str, exc: Exception) -> None:
        """Drop the chat's connection and re-raise *exc* as a ``PlatformError``.

        Always raises. The error code is taken from ``exc.code`` when present,
        otherwise ``"PLATFORM_CONNECTION_ERROR"``.
        """
        try:
            await self.disconnect_chat(chat_id)
        except Exception:
            # Cleanup is best-effort: an error while closing an already
            # broken connection must not mask the error that broke it.
            pass
        code = getattr(exc, "code", None) or "PLATFORM_CONNECTION_ERROR"
        raise PlatformError(str(exc), code=code) from exc

    @staticmethod
    def _attachment_paths(attachments: list[Attachment] | None) -> list[str]:
        """Return the non-empty ``workspace_path`` of each attachment, in order."""
        if not attachments:
            return []
        return [
            attachment.workspace_path
            for attachment in attachments
            if attachment.workspace_path
        ]

    @staticmethod
    def _send_message_accepts_attachments(send_message) -> bool:
        """Whether *send_message* accepts an ``attachments`` keyword argument."""
        try:
            parameters = inspect.signature(send_message).parameters
        except (TypeError, ValueError):
            return False
        return "attachments" in parameters or any(
            parameter.kind == inspect.Parameter.VAR_KEYWORD
            for parameter in parameters.values()
        )

    @staticmethod
    def _event_kind(event: object) -> str:
        """Normalize an event's kind to UPPER_SNAKE_CASE.

        Uses ``event.type`` (unwrapping ``.value`` for enum-like objects), or
        the class name if there is no ``type``. Examples: ``"text-chunk"`` ->
        ``"TEXT_CHUNK"``, ``"TextChunk"`` -> ``"TEXT_CHUNK"``.
        """
        raw_kind = getattr(event, "type", None)
        if hasattr(raw_kind, "value"):
            raw_kind = raw_kind.value
        if raw_kind is None:
            raw_kind = event.__class__.__name__
        kind = str(raw_kind).replace("-", "_")
        if "_" in kind:
            return kind.upper()
        # CamelCase -> CAMEL_CASE: insert "_" before an uppercase letter that
        # follows a non-uppercase one.
        normalized = []
        for index, char in enumerate(kind):
            if index and char.isupper() and not kind[index - 1].isupper():
                normalized.append("_")
            normalized.append(char)
        return "".join(normalized).upper()

    @classmethod
    def _is_text_event(cls, event: object) -> bool:
        """True for any event with a ``text`` attribute or a TEXT_CHUNK kind.

        Callers test this before the END/SEND_FILE checks, so an event
        carrying ``text`` is always treated as text.
        """
        return hasattr(event, "text") or "TEXT_CHUNK" in cls._event_kind(event)

    @classmethod
    def _is_end_event(cls, event: object) -> bool:
        """True when the normalized kind is ``END`` or ends with ``_END``."""
        kind = cls._event_kind(event)
        return kind == "END" or kind.endswith("_END")

    @classmethod
    def _is_send_file_event(cls, event: object) -> bool:
        """True when the normalized kind contains ``SEND_FILE``."""
        kind = cls._event_kind(event)
        return "SEND_FILE" in kind

    @staticmethod
    def _attachment_from_send_file_event(event: object) -> Attachment | None:
        """Build an ``Attachment`` from a SEND_FILE event.

        The location is the first truthy attribute among ``url``,
        ``workspace_path``, ``path``, ``file_path`` and ``uri``. Returns
        ``None`` if none is set. A leading ``/workspace/`` prefix is stripped
        to form ``workspace_path``.
        """
        location = None
        for attr in ("url", "workspace_path", "path", "file_path", "uri"):
            value = getattr(event, attr, None)
            if value:
                location = str(value)
                break
        if location is None:
            return None

        mime_type = getattr(event, "mime_type", None) or "application/octet-stream"
        filename = getattr(event, "filename", None) or Path(location).name or None
        size = getattr(event, "size", None)

        # NOTE(review): when the location came from ``url``, workspace_path is
        # that same string -- confirm a URL is acceptable here.
        workspace_path = location
        if workspace_path.startswith("/workspace/"):
            workspace_path = workspace_path[len("/workspace/") :]
        elif workspace_path == "/workspace":
            workspace_path = ""
        return Attachment(
            url=location,
            mime_type=mime_type,
            size=size,
            filename=filename,
            workspace_path=workspace_path or None,
        )

    @staticmethod
    def _message_response_accepts_attachments() -> bool:
        """Whether ``MessageResponse`` declares an ``attachments`` field.

        Checks a ``model_fields`` dict first (e.g. a pydantic v2 model), then
        falls back to the constructor signature.
        """
        fields = getattr(MessageResponse, "model_fields", None)
        if isinstance(fields, dict):
            return "attachments" in fields
        try:
            return "attachments" in inspect.signature(MessageResponse).parameters
        except (TypeError, ValueError):
            return False