from __future__ import annotations import asyncio import inspect import logging import sys import re from urllib.parse import urlsplit, urlunsplit from pathlib import Path import aiohttp _api_root = Path(__file__).resolve().parents[1] / "external" / "platform-agent_api" if str(_api_root) not in sys.path: sys.path.insert(0, str(_api_root)) from lambda_agent_api.agent_api import AgentApi, AgentException from lambda_agent_api.server import ( MsgError, MsgEventEnd, MsgEventTextChunk, MsgGracefulDisconnect, ServerMessage, ) logger = logging.getLogger(__name__) class AgentApiWrapper(AgentApi): """Capture tokens_used from MsgEventEnd without patching upstream code.""" def __init__( self, agent_id: str, base_url: str | None = None, *, chat_id: int | str = 0, url: str | None = None, **kwargs, ) -> None: if base_url is None and url is None: raise TypeError("AgentApiWrapper requires base_url or url") self._base_url = self._normalize_base_url(base_url or url or "") self._init_kwargs = dict(kwargs) self.chat_id = chat_id if self._supports_modern_constructor(): super().__init__( agent_id=agent_id, base_url=self._base_url, chat_id=chat_id, **kwargs, ) else: super().__init__( agent_id=agent_id, url=self._build_ws_url(self._base_url, chat_id), **kwargs, ) self.last_tokens_used = 0 @staticmethod def _supports_modern_constructor() -> bool: try: parameters = inspect.signature(AgentApi.__init__).parameters except (TypeError, ValueError): return False return "base_url" in parameters and "chat_id" in parameters @staticmethod def _normalize_base_url(base_url: str) -> str: parsed = urlsplit(base_url) path = re.sub(r"(?:/v1)?/agent_ws(?:/[^/]+)?$", "", parsed.path.rstrip("/")) return urlunsplit((parsed.scheme, parsed.netloc, path, "", "")) @staticmethod def _build_ws_url(base_url: str, chat_id: int | str) -> str: return base_url.rstrip("/") + f"/v1/agent_ws/{chat_id}/" def for_chat(self, chat_id: int | str) -> "AgentApiWrapper": return type(self)( agent_id=self.id, base_url=self._base_url, chat_id=chat_id, **self._init_kwargs, ) async def _listen(self): try: async for msg in self._ws: if msg.type == aiohttp.WSMsgType.TEXT: try: outgoing_msg = ServerMessage.validate_json(msg.data) if isinstance(outgoing_msg, MsgEventTextChunk): if self._current_queue: await self._current_queue.put(outgoing_msg) elif self.callback: self.callback(outgoing_msg) else: logger.warning("[%s] AgentEvent without active request", self.id) elif isinstance(outgoing_msg, MsgEventEnd): self.last_tokens_used = outgoing_msg.tokens_used if self._current_queue: await self._current_queue.put(outgoing_msg) elif isinstance(outgoing_msg, MsgError): if self.callback: self.callback(outgoing_msg) error = AgentException(outgoing_msg.code, outgoing_msg.details) logger.error("[%s] Agent error: %s", self.id, error) if self._current_queue: await self._current_queue.put(error) elif isinstance(outgoing_msg, MsgGracefulDisconnect): if self.callback: self.callback(outgoing_msg) logger.info("[%s] Gracefully disconnecting", self.id) break else: logger.warning("[%s] Unknown message type: %s", self.id, outgoing_msg.type) if self.callback: self.callback(outgoing_msg) except Exception as exc: logger.error("[%s] Failed to deserialize message: %s", self.id, exc) if self._current_queue: await self._current_queue.put( AgentException("PARSE_ERROR", f"Validation failed: {exc}") ) elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED): logger.error("[%s] WebSocket closed/error: %s", self.id, msg.type) break except asyncio.CancelledError: pass except Exception as exc: logger.error("[%s] Error in listen loop: %s", self.id, exc) finally: await self._cleanup()