from __future__ import annotations import asyncio import logging import sys from pathlib import Path import aiohttp _api_root = Path(__file__).resolve().parents[1] / "external" / "platform-agent_api" if str(_api_root) not in sys.path: sys.path.insert(0, str(_api_root)) from lambda_agent_api.agent_api import AgentApi, AgentException from lambda_agent_api.server import ( MsgError, MsgEventEnd, MsgEventTextChunk, MsgGracefulDisconnect, ServerMessage, ) logger = logging.getLogger(__name__) class AgentApiWrapper(AgentApi): """Capture tokens_used from MsgEventEnd without patching upstream code.""" def __init__(self, agent_id: str, url: str, **kwargs) -> None: super().__init__(agent_id=agent_id, url=url, **kwargs) self.last_tokens_used = 0 async def _listen(self): try: async for msg in self._ws: if msg.type == aiohttp.WSMsgType.TEXT: try: outgoing_msg = ServerMessage.validate_json(msg.data) if isinstance(outgoing_msg, MsgEventTextChunk): if self._current_queue: await self._current_queue.put(outgoing_msg) elif self.callback: self.callback(outgoing_msg) else: logger.warning("[%s] AgentEvent without active request", self.id) elif isinstance(outgoing_msg, MsgEventEnd): self.last_tokens_used = outgoing_msg.tokens_used if self._current_queue: await self._current_queue.put(outgoing_msg) elif isinstance(outgoing_msg, MsgError): if self.callback: self.callback(outgoing_msg) error = AgentException(outgoing_msg.code, outgoing_msg.details) logger.error("[%s] Agent error: %s", self.id, error) if self._current_queue: await self._current_queue.put(error) elif isinstance(outgoing_msg, MsgGracefulDisconnect): if self.callback: self.callback(outgoing_msg) logger.info("[%s] Gracefully disconnecting", self.id) break else: logger.warning("[%s] Unknown message type: %s", self.id, outgoing_msg.type) if self.callback: self.callback(outgoing_msg) except Exception as exc: logger.error("[%s] Failed to deserialize message: %s", self.id, exc) if self._current_queue: await self._current_queue.put( AgentException("PARSE_ERROR", f"Validation failed: {exc}") ) elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED): logger.error("[%s] WebSocket closed/error: %s", self.id, msg.type) break except asyncio.CancelledError: pass except Exception as exc: logger.error("[%s] Error in listen loop: %s", self.id, exc) finally: await self._cleanup()