from __future__ import annotations

import asyncio
import os
import re
from collections.abc import AsyncIterator
from pathlib import Path
from urllib.parse import urljoin, urlsplit, urlunsplit

import structlog

from sdk.interface import (
    Attachment,
    MessageChunk,
    MessageResponse,
    PlatformClient,
    PlatformError,
    User,
    UserSettings,
)
from sdk.prototype_state import PrototypeStateStore
from sdk.upstream_agent_api import AgentApi, MsgEventSendFile, MsgEventTextChunk

logger = structlog.get_logger(__name__)

# Values of SURFACES_DEBUG_WS (after strip/lower) that enable verbose WS logging.
_WS_DEBUG_TRUTHY = frozenset({"1", "true", "yes", "on"})

# Trailing "/agent_ws" (optionally preceded by "/v1" and followed by one more
# path segment) that is removed from a configured base URL, because
# _build_chat_api appends "v1/agent_ws/<chat_id>/" itself.
_AGENT_WS_SUFFIX_RE = re.compile(r"(?:/v1)?/agent_ws(?:/[^/]+)?/?$")


def _ws_debug_enabled() -> bool:
    """Return True when the ``SURFACES_DEBUG_WS`` env var holds a truthy value."""
    value = os.environ.get("SURFACES_DEBUG_WS", "")
    return value.strip().lower() in _WS_DEBUG_TRUTHY


class RealPlatformClient(PlatformClient):
    """``PlatformClient`` that relays chat messages to an upstream agent.

    A fresh ``agent_api_cls`` instance is built, connected and closed for every
    send. User and settings state is delegated to a ``PrototypeStateStore``.
    Sends to the same chat are serialized by a per-chat ``asyncio.Lock``.
    """

    def __init__(
        self,
        agent_id: str,
        agent_base_url: str,
        prototype_state: PrototypeStateStore,
        platform: str = "matrix",
        agent_api_cls=AgentApi,
    ) -> None:
        """Create the client.

        Args:
            agent_id: Identifier of the upstream agent.
            agent_base_url: Agent server URL. A trailing ``/agent_ws`` style
                suffix is stripped (see ``_normalize_agent_base_url``).
            prototype_state: Store backing users, settings and token counters.
            platform: Platform label kept on the instance (used in debug logs).
            agent_api_cls: Factory for per-chat agent connections; injectable
                for testing.
        """
        self._agent_id = agent_id
        self._raw_agent_base_url = agent_base_url
        self._agent_base_url = self._normalize_agent_base_url(agent_base_url)
        self._agent_api_cls = agent_api_cls
        self._prototype_state = prototype_state
        self._platform = platform
        # chat_id (as str) -> lock serializing sends to that chat.
        self._chat_send_locks: dict[str, asyncio.Lock] = {}
        if _ws_debug_enabled():
            logger.warning(
                "agent_client_initialized",
                agent_id=self._agent_id,
                platform=self._platform,
                raw_base_url=self._raw_agent_base_url,
                normalized_base_url=self._agent_base_url,
            )

    @property
    def agent_id(self) -> str:
        """Identifier of the upstream agent."""
        return self._agent_id

    @property
    def agent_base_url(self) -> str:
        """Normalized agent base URL (see ``_normalize_agent_base_url``)."""
        return self._agent_base_url

    def _get_chat_send_lock(self, chat_id: str) -> asyncio.Lock:
        """Return the lock for ``chat_id``, creating it on first use."""
        chat_key = str(chat_id)
        lock = self._chat_send_locks.get(chat_key)
        if lock is None:
            lock = asyncio.Lock()
            self._chat_send_locks[chat_key] = lock
        return lock

    async def get_or_create_user(
        self,
        external_id: str,
        platform: str,
        display_name: str | None = None,
    ) -> User:
        """Delegate user lookup/creation to the prototype state store."""
        return await self._prototype_state.get_or_create_user(
            external_id=external_id,
            platform=platform,
            display_name=display_name,
        )

    async def send_message(
        self,
        user_id: str,
        chat_id: str,
        text: str,
        attachments: list[Attachment] | None = None,
    ) -> MessageResponse:
        """Send ``text`` to the agent and collect the complete reply.

        Text chunks are concatenated into ``response``. Files the agent sends
        back become ``Attachment`` entries. The chat's token counter is reset
        to 0 after a successful exchange.

        Raises:
            PlatformError: if connecting to or streaming from the agent fails.
        """
        response_parts: list[str] = []
        sent_attachments: list[Attachment] = []

        async with self._get_chat_send_lock(chat_id):
            chat_api = self._build_chat_api(chat_id)
            try:
                await chat_api.connect()
                async for event in self._stream_agent_events(
                    chat_api, text, attachments=attachments
                ):
                    if isinstance(event, MsgEventTextChunk) and event.text:
                        response_parts.append(event.text)
                    elif isinstance(event, MsgEventSendFile):
                        sent_attachments.append(
                            self._attachment_from_send_file_event(event)
                        )
            except Exception as exc:
                raise self._to_platform_error(exc) from exc
            finally:
                # Always release the per-call connection, even on failure.
                await self._close_chat_api(chat_api)
            await self._prototype_state.set_last_tokens_used(str(chat_id), 0)

        # NOTE(review): no message id is taken from the agent stream here, so
        # user_id is used as the message id - confirm this is what consumers
        # of MessageResponse expect.
        return MessageResponse(
            message_id=user_id,
            response="".join(response_parts),
            tokens_used=0,
            finished=True,
            attachments=sent_attachments,
        )

    async def stream_message(
        self,
        user_id: str,
        chat_id: str,
        text: str,
        attachments: list[Attachment] | None = None,
    ) -> AsyncIterator[MessageChunk]:
        """Stream the agent's reply to ``text`` as incremental chunks.

        Yields one non-final ``MessageChunk`` per agent text chunk, then one
        closing chunk with an empty delta and ``finished=True``. File events
        are not surfaced in streaming mode.

        The per-chat lock stays held while text chunks are being yielded.
        Callers should exhaust or ``aclose()`` this generator so the lock is
        released promptly.

        Raises:
            PlatformError: if connecting to or streaming from the agent fails.
        """
        async with self._get_chat_send_lock(chat_id):
            chat_api = self._build_chat_api(chat_id)
            try:
                await chat_api.connect()
                async for event in self._stream_agent_events(
                    chat_api, text, attachments=attachments
                ):
                    # Only text is forwarded; MsgEventSendFile and any other
                    # event types are skipped.
                    if isinstance(event, MsgEventTextChunk):
                        yield MessageChunk(
                            message_id=user_id,
                            delta=event.text,
                            finished=False,
                        )
            except Exception as exc:
                raise self._to_platform_error(exc) from exc
            finally:
                await self._close_chat_api(chat_api)
            await self._prototype_state.set_last_tokens_used(str(chat_id), 0)

        # The terminal chunk does not need the lock; emit it after release.
        yield MessageChunk(
            message_id=user_id,
            delta="",
            finished=True,
            tokens_used=0,
        )

    async def get_settings(self, user_id: str) -> UserSettings:
        """Return the stored settings for ``user_id``."""
        return await self._prototype_state.get_settings(user_id)

    async def update_settings(self, user_id: str, action) -> None:
        """Apply ``action`` to the stored settings for ``user_id``."""
        await self._prototype_state.update_settings(user_id, action)

    async def disconnect_chat(self, chat_id: str) -> None:
        """Forget the send lock for ``chat_id``.

        NOTE(review): if a send is in flight, a subsequent send creates a new
        lock and is no longer serialized against it - confirm acceptable.
        """
        self._chat_send_locks.pop(str(chat_id), None)

    async def close(self) -> None:
        """Drop all per-chat locks.

        Agent connections are opened per call and closed in ``finally``, so no
        connection is held between calls.
        """
        self._chat_send_locks.clear()

    async def _stream_agent_events(
        self,
        chat_api,
        text: str,
        attachments: list[Attachment] | None = None,
    ) -> AsyncIterator[object]:
        """Send ``text`` via ``chat_api`` and re-yield every agent event.

        Attachments are passed as workspace-relative paths. ``None`` is passed
        when there are none. Each event is logged at debug level.
        """
        attachment_paths = self._attachment_paths(attachments)
        event_stream = chat_api.send_message(text, attachments=attachment_paths or None)
        chunk_index = 0
        async for event in event_stream:
            if isinstance(event, MsgEventTextChunk):
                logger.debug("agent_chunk", index=chunk_index, text=repr(event.text[:40]))
                chunk_index += 1
            else:
                logger.debug("agent_event", index=chunk_index, type=type(event).__name__)
            yield event

    def _build_chat_api(self, chat_id: str):
        """Instantiate a (not yet connected) agent API bound to ``chat_id``."""
        if _ws_debug_enabled():
            logger.warning(
                "agent_chat_api_build",
                agent_id=self._agent_id,
                chat_id=str(chat_id),
                normalized_base_url=self._agent_base_url,
                ws_url=urljoin(self._agent_base_url, f"v1/agent_ws/{chat_id}/"),
            )
        return self._agent_api_cls(
            agent_id=self._agent_id,
            base_url=self._agent_base_url,
            chat_id=str(chat_id),
        )

    @staticmethod
    def _normalize_agent_base_url(base_url: str) -> str:
        """Strip any ``[/v1]/agent_ws[/<segment>]`` suffix, query and fragment.

        A non-empty remaining path gets exactly one trailing slash. That way
        ``urljoin`` appends to it instead of replacing its last segment.
        """
        parsed = urlsplit(base_url)
        path = _AGENT_WS_SUFFIX_RE.sub("", parsed.path.rstrip("/"))
        if path:
            path = f"{path}/"
        return urlunsplit((parsed.scheme, parsed.netloc, path, "", ""))

    @staticmethod
    async def _close_chat_api(chat_api) -> None:
        """Best-effort ``await chat_api.close()``; errors are logged, not raised."""
        close = getattr(chat_api, "close", None)
        if not callable(close):
            return
        try:
            await close()
        except Exception:
            # Cleanup runs in ``finally`` blocks. Raising here would mask the
            # caller's real result or exception, so only leave a trace.
            logger.debug("agent_chat_api_close_failed", exc_info=True)

    @staticmethod
    def _to_platform_error(exc: Exception) -> PlatformError:
        """Wrap ``exc`` in a ``PlatformError``, keeping its ``code`` if it has one."""
        code = getattr(exc, "code", None) or "PLATFORM_CONNECTION_ERROR"
        return PlatformError(str(exc), code=code)

    @staticmethod
    def _posix_or_none(path: Path) -> str | None:
        """Return ``path`` in POSIX form, or ``None`` if it denotes the root itself."""
        normalized = path.as_posix()
        # Path() / Path("") render as "."; that is the workspace root, not a file.
        if normalized in ("", "."):
            return None
        return normalized

    @staticmethod
    def _normalize_workspace_path(location: str) -> str | None:
        """Convert an agent-side file location to a workspace-relative path.

        - Relative paths are returned in POSIX form.
        - ``/workspace/<rel>`` returns ``<rel>``.
        - ``/agents/<id>/<rel>`` returns ``<rel>``.
        - Any other absolute path has its leading ``/`` removed.

        Returns ``None`` for empty input or when the location is the root
        itself (e.g. ``/workspace`` or ``/agents/<id>``).
        """
        if not location:
            return None
        path = Path(location)
        if not path.is_absolute():
            return RealPlatformClient._posix_or_none(path)
        parts = path.parts
        if len(parts) >= 2 and parts[1] == "workspace":
            return RealPlatformClient._posix_or_none(Path(*parts[2:]))
        if len(parts) >= 3 and parts[1] == "agents":
            return RealPlatformClient._posix_or_none(Path(*parts[3:]))
        return path.as_posix().lstrip("/") or None

    @staticmethod
    def _attachment_paths(attachments: list[Attachment] | None) -> list[str]:
        """Return normalized workspace paths of attachments that have one.

        Attachments without a ``workspace_path`` are skipped. So are those
        whose path normalizes to nothing.
        """
        if not attachments:
            return []
        paths: list[str] = []
        for attachment in attachments:
            if not attachment.workspace_path:
                continue
            normalized = RealPlatformClient._normalize_workspace_path(
                attachment.workspace_path
            )
            if normalized:
                paths.append(normalized)
        return paths

    @staticmethod
    def _attachment_from_send_file_event(event: MsgEventSendFile) -> Attachment:
        """Build an ``Attachment`` describing a file the agent sent back.

        The raw location is kept as ``url``. The MIME type is always
        ``application/octet-stream`` and the size is unknown (``None``).
        """
        location = str(event.path)
        return Attachment(
            url=location,
            mime_type="application/octet-stream",
            size=None,
            filename=Path(location).name or None,
            workspace_path=RealPlatformClient._normalize_workspace_path(location),
        )