from __future__ import annotations import asyncio import inspect import logging import os import re import sys from collections.abc import AsyncIterator from pathlib import Path from urllib.parse import urlsplit, urlunsplit import aiohttp _api_root = Path(__file__).resolve().parents[1] / "external" / "platform-agent_api" if str(_api_root) not in sys.path: sys.path.insert(0, str(_api_root)) from lambda_agent_api.agent_api import AgentApi, AgentBusyException, AgentException # noqa: E402 from lambda_agent_api.client import EClientMessage, MsgUserMessage # noqa: E402 from lambda_agent_api.server import AgentEventUnion, MsgEventEnd, ServerMessage # noqa: E402 logger = logging.getLogger(__name__) _DEBUG_STREAM = os.environ.get("SURFACES_AGENT_DEBUG_STREAM", "").strip().lower() in { "1", "true", "yes", } _POST_END_DRAIN_MS = int(os.environ.get("SURFACES_AGENT_POST_END_DRAIN_MS", "120")) _STREAM_IDLE_TIMEOUT_MS = int(os.environ.get("SURFACES_AGENT_IDLE_TIMEOUT_MS", "60000")) class AgentApiWrapper(AgentApi): """Capture tokens_used from MsgEventEnd without patching upstream code.""" def __init__( self, agent_id: str, base_url: str | None = None, *, chat_id: int | str = 0, url: str | None = None, **kwargs, ) -> None: if base_url is None and url is None: raise TypeError("AgentApiWrapper requires base_url or url") self._base_url = self._normalize_base_url(base_url or url or "") self._init_kwargs = dict(kwargs) self.chat_id = chat_id if self._supports_modern_constructor(): super().__init__( agent_id=agent_id, base_url=self._base_url, chat_id=chat_id, **kwargs, ) else: super().__init__( agent_id=agent_id, url=self._build_ws_url(self._base_url, chat_id), **kwargs, ) self.last_tokens_used = 0 @staticmethod def _supports_modern_constructor() -> bool: try: parameters = inspect.signature(AgentApi.__init__).parameters except (TypeError, ValueError): return False return "base_url" in parameters and "chat_id" in parameters @staticmethod def _normalize_base_url(base_url: str) -> str: parsed = urlsplit(base_url) path = re.sub(r"(?:/v1)?/agent_ws(?:/[^/]+)?$", "", parsed.path.rstrip("/")) return urlunsplit((parsed.scheme, parsed.netloc, path, "", "")) @staticmethod def _build_ws_url(base_url: str, chat_id: int | str) -> str: return base_url.rstrip("/") + f"/agent_ws/?thread_id={chat_id}" def for_chat(self, chat_id: int | str) -> AgentApiWrapper: return type(self)( agent_id=self.id, base_url=self._base_url, chat_id=chat_id, **self._init_kwargs, ) @staticmethod def _event_kind(event: object) -> str: raw_kind = getattr(event, "type", None) if hasattr(raw_kind, "value"): raw_kind = raw_kind.value if raw_kind is None: raw_kind = event.__class__.__name__ kind = str(raw_kind).replace("-", "_") if "_" in kind: return kind.upper() normalized = [] for index, char in enumerate(kind): if index and char.isupper() and not kind[index - 1].isupper(): normalized.append("_") normalized.append(char) return "".join(normalized).upper() @classmethod def _is_kind(cls, event: object, *needles: str) -> bool: kind = cls._event_kind(event) return any(needle in kind for needle in needles) @classmethod def _is_text_event(cls, event: object) -> bool: return hasattr(event, "text") or cls._is_kind(event, "TEXT_CHUNK") @classmethod def _is_end_event(cls, event: object) -> bool: kind = cls._event_kind(event) return kind == "END" or kind.endswith("_END") @classmethod def _is_send_file_event(cls, event: object) -> bool: return "SEND_FILE" in cls._event_kind(event) async def _publish_event(self, event: object, *, queue_event: object | None = None) -> None: if self.callback: self.callback(event) if self._current_queue: await self._current_queue.put(queue_event if queue_event is not None else event) async def _publish_error(self, event: object) -> None: if self.callback: self.callback(event) if self._current_queue and hasattr(event, "code") and hasattr(event, "details"): await self._current_queue.put(AgentException(event.code, event.details)) async def _listen(self): try: async for msg in self._ws: if msg.type == aiohttp.WSMsgType.TEXT: try: outgoing_msg = ServerMessage.validate_json(msg.data) if self._is_text_event(outgoing_msg): if _DEBUG_STREAM: logger.warning( "[%s] text chunk queue=%s text=%r", self.id, self._current_queue is not None, getattr(outgoing_msg, "text", "")[:80], ) if self._current_queue: await self._current_queue.put(outgoing_msg) elif self.callback: self.callback(outgoing_msg) else: logger.warning("[%s] AgentEvent without active request", self.id) elif self._is_end_event(outgoing_msg): self.last_tokens_used = outgoing_msg.tokens_used if _DEBUG_STREAM: logger.warning( "[%s] end event queue=%s tokens=%s", self.id, self._current_queue is not None, getattr(outgoing_msg, "tokens_used", None), ) await self._publish_event(outgoing_msg) elif self._is_kind(outgoing_msg, "ERROR"): error = AgentException(outgoing_msg.code, outgoing_msg.details) logger.error("[%s] Agent error: %s", self.id, error) await self._publish_error(outgoing_msg) elif self._is_kind(outgoing_msg, "GRACEFUL_DISCONNECT"): await self._publish_event(outgoing_msg) logger.info("[%s] Gracefully disconnecting", self.id) break else: await self._publish_event(outgoing_msg) except Exception as exc: logger.error("[%s] Failed to deserialize message: %s", self.id, exc) if self._current_queue: await self._current_queue.put( AgentException("PARSE_ERROR", f"Validation failed: {exc}") ) elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED): logger.error("[%s] WebSocket closed/error: %s", self.id, msg.type) break except asyncio.CancelledError: pass except Exception as exc: logger.error("[%s] Error in listen loop: %s", self.id, exc) finally: await self._cleanup() async def send_message( self, text: str, attachments: list[str] | None = None ) -> AsyncIterator[AgentEventUnion]: if not self._connected or not self._ws: raise AgentException( code="NOT_CONNECTED", details="Not connected. Call connect() first." ) if self._request_lock.locked(): raise AgentBusyException("Agent is currently processing another request") await self._request_lock.acquire() try: self._current_queue = asyncio.Queue() message = MsgUserMessage( type=EClientMessage.USER_MESSAGE, text=text, attachments=attachments or [], ) await self._ws.send_str(message.model_dump_json()) logger.debug("[%s] Sent message: %s...", self.id, text[:50]) while True: try: chunk = await asyncio.wait_for( self._current_queue.get(), timeout=max(_STREAM_IDLE_TIMEOUT_MS, 0) / 1000, ) except TimeoutError as exc: raise AgentException( "TIMEOUT", ( "Timed out waiting for the next agent stream event " f"after {max(_STREAM_IDLE_TIMEOUT_MS, 0)}ms" ), ) from exc if isinstance(chunk, Exception): raise chunk if isinstance(chunk, MsgEventEnd): self.last_tokens_used = chunk.tokens_used async for late_chunk in self._drain_post_end_events(): yield late_chunk break yield chunk finally: if self._current_queue: orphan_queue = self._current_queue self._current_queue = None while not orphan_queue.empty(): try: orphan_msg = orphan_queue.get_nowait() if isinstance(orphan_msg, Exception): logger.debug( "[%s] Dropped exception from queue during cleanup: %s", self.id, orphan_msg, ) continue if self.callback: self.callback(orphan_msg) else: logger.debug("[%s] Dropped orphaned message during cleanup", self.id) except asyncio.QueueEmpty: break if self._request_lock.locked(): self._request_lock.release() async def _drain_post_end_events(self) -> AsyncIterator[AgentEventUnion]: if self._current_queue is None: return timeout_s = max(_POST_END_DRAIN_MS, 0) / 1000 while True: try: chunk = await asyncio.wait_for(self._current_queue.get(), timeout=timeout_s) except TimeoutError: break if isinstance(chunk, Exception): logger.warning("[%s] dropping post-END exception: %s", self.id, chunk) continue if isinstance(chunk, MsgEventEnd): self.last_tokens_used = chunk.tokens_used if _DEBUG_STREAM: logger.warning( "[%s] dropped duplicate END tokens=%s", self.id, chunk.tokens_used, ) continue if _DEBUG_STREAM and self._is_text_event(chunk): logger.warning( "[%s] recovered post-END text chunk=%r", self.id, getattr(chunk, "text", "")[:80], ) yield chunk