# surfaces/sdk/agent_api_wrapper.py
#
# 88 lines
# 3.5 KiB
# Python

from __future__ import annotations
import asyncio
import logging
import sys
from pathlib import Path
import aiohttp
# Make the vendored agent API checkout importable: it lives under
# "<two directories above this file>/external/platform-agent_api".
_api_root = Path(__file__).resolve().parents[1].joinpath("external", "platform-agent_api")
# Prepend only when the entry is absent so repeated imports do not stack duplicates.
if sys.path.count(str(_api_root)) == 0:
    sys.path.insert(0, str(_api_root))
from lambda_agent_api.agent_api import AgentApi, AgentException
from lambda_agent_api.server import (
MsgError,
MsgEventEnd,
MsgEventTextChunk,
MsgGracefulDisconnect,
ServerMessage,
)
# Module-level logger; records are emitted under this module's import name.
logger = logging.getLogger(__name__)
class AgentApiWrapper(AgentApi):
    """Capture tokens_used from MsgEventEnd without patching upstream code.

    ``_listen`` is overridden so that every ``MsgEventEnd`` read from the
    WebSocket stores its ``tokens_used`` value in ``last_tokens_used`` before
    the message is forwarded to the active request queue.

    NOTE(review): the remainder of the loop presumably mirrors the upstream
    ``AgentApi._listen`` implementation -- confirm and keep both in sync when
    the upstream package changes.
    """

    def __init__(self, agent_id: str, url: str, **kwargs) -> None:
        """Initialise the upstream client and the token counter.

        Args:
            agent_id: Forwarded to ``AgentApi.__init__`` as ``agent_id``.
            url: Forwarded to ``AgentApi.__init__`` as ``url``.
            **kwargs: Forwarded unchanged to ``AgentApi.__init__``.
        """
        super().__init__(agent_id=agent_id, url=url, **kwargs)
        # ``tokens_used`` of the most recent MsgEventEnd. Starts at 0 and is
        # only ever overwritten by ``_listen``; it is never reset in between.
        self.last_tokens_used = 0

    async def _listen(self) -> None:
        """Read WebSocket messages and route them until the socket stops.

        Each TEXT frame is validated with ``ServerMessage.validate_json`` and
        handled by type:

        * ``MsgEventTextChunk``: put on ``_current_queue`` if one is set,
          otherwise passed to ``callback`` if set, otherwise a warning is
          logged.
        * ``MsgEventEnd``: ``tokens_used`` is stored in ``last_tokens_used``,
          then the message is put on ``_current_queue`` if one is set.
        * ``MsgError``: passed to ``callback`` if set, logged, and an
          ``AgentException(code, details)`` is put on ``_current_queue`` if
          one is set.
        * ``MsgGracefulDisconnect``: passed to ``callback`` if set, then the
          loop stops.
        * anything else: a warning is logged and the message is passed to
          ``callback`` if set.

        An ERROR or CLOSED frame stops the loop; other non-TEXT frames are
        ignored. Cancellation is swallowed, any other exception escaping the
        loop is logged, and ``_cleanup()`` is always awaited on exit.
        """
        try:
            async for msg in self._ws:
                if msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
                    logger.error("[%s] WebSocket closed/error: %s", self.id, msg.type)
                    break
                if msg.type != aiohttp.WSMsgType.TEXT:
                    # Non-text frames carry no server message; skip them.
                    continue

                try:
                    outgoing_msg = ServerMessage.validate_json(msg.data)
                    if isinstance(outgoing_msg, MsgEventTextChunk):
                        if self._current_queue:
                            await self._current_queue.put(outgoing_msg)
                        elif self.callback:
                            self.callback(outgoing_msg)
                        else:
                            logger.warning("[%s] AgentEvent without active request", self.id)
                    elif isinstance(outgoing_msg, MsgEventEnd):
                        # The reason this override exists: remember the token
                        # count before handing the event to the waiting request.
                        self.last_tokens_used = outgoing_msg.tokens_used
                        if self._current_queue:
                            await self._current_queue.put(outgoing_msg)
                    elif isinstance(outgoing_msg, MsgError):
                        if self.callback:
                            self.callback(outgoing_msg)
                        error = AgentException(outgoing_msg.code, outgoing_msg.details)
                        logger.error("[%s] Agent error: %s", self.id, error)
                        if self._current_queue:
                            await self._current_queue.put(error)
                    elif isinstance(outgoing_msg, MsgGracefulDisconnect):
                        if self.callback:
                            self.callback(outgoing_msg)
                        logger.info("[%s] Gracefully disconnecting", self.id)
                        break
                    else:
                        logger.warning(
                            "[%s] Unknown message type: %s", self.id, outgoing_msg.type
                        )
                        if self.callback:
                            self.callback(outgoing_msg)
                except Exception as exc:
                    # This handler also catches failures raised by the callback
                    # or the queue, not just validation errors, so log the
                    # traceback to make the real origin visible. The loop keeps
                    # running and the waiting request is unblocked with an error.
                    logger.exception("[%s] Failed to deserialize message: %s", self.id, exc)
                    if self._current_queue:
                        await self._current_queue.put(
                            AgentException("PARSE_ERROR", f"Validation failed: {exc}")
                        )
        except asyncio.CancelledError:
            # Cancellation is treated as a normal shutdown of the listener.
            pass
        except Exception as exc:
            logger.exception("[%s] Error in listen loop: %s", self.id, exc)
        finally:
            await self._cleanup()