surfaces/sdk/agent_api_wrapper.py

186 lines
6.6 KiB
Python

from __future__ import annotations

import asyncio
import inspect
import logging
import re
import sys
from pathlib import Path
from urllib.parse import quote, urlsplit, urlunsplit

import aiohttp
# Make the vendored upstream package importable: resolve
# <parent of this module's directory>/external/platform-agent_api and put it
# at the front of sys.path, but only if it is not already listed.
_api_root = Path(__file__).resolve().parents[1].joinpath("external", "platform-agent_api")
if str(_api_root) not in sys.path:
    sys.path.insert(0, str(_api_root))
from lambda_agent_api.agent_api import AgentApi, AgentException
from lambda_agent_api.server import (
MsgError,
MsgEventEnd,
MsgEventTextChunk,
MsgGracefulDisconnect,
ServerMessage,
)
# Module-level logger, named after this module's dotted import path.
logger: logging.Logger = logging.getLogger(name=__name__)
class AgentApiWrapper(AgentApi):
    """Capture tokens_used from MsgEventEnd without patching upstream code.

    Wraps the upstream ``AgentApi`` client and:

    * accepts either ``base_url`` or the legacy ``url`` keyword and normalizes
      it to a base URL (a trailing ``[/v1]/agent_ws[/<segment>]`` path, query
      string and fragment are dropped);
    * calls whichever upstream constructor signature is installed
      (``base_url``/``chat_id`` keywords, or a prebuilt websocket ``url``);
    * overrides the websocket listen loop so that the ``tokens_used`` value of
      end events is stored in ``last_tokens_used``.

    Event types are recognised heuristically by name (see ``_event_kind``) so
    the wrapper tolerates differences between upstream message versions.

    NOTE(review): ``id``, ``callback``, ``_current_queue``, ``_ws`` and
    ``_cleanup`` are used here but not defined in this class; presumably they
    come from ``AgentApi`` — confirm against the upstream version in use.
    """

    # Optional "/v1", then "/agent_ws", then at most one trailing path segment,
    # anchored at the end of the (slash-stripped) URL path.
    _WS_PATH_SUFFIX = re.compile(r"(?:/v1)?/agent_ws(?:/[^/]+)?$")

    def __init__(
        self,
        agent_id: str,
        base_url: str | None = None,
        *,
        chat_id: int | str = 0,
        url: str | None = None,
        **kwargs,
    ) -> None:
        """Create a client for ``agent_id`` bound to ``chat_id``.

        Args:
            agent_id: Agent identifier forwarded to ``AgentApi``.
            base_url: Service base URL; normalized by ``_normalize_base_url``.
            chat_id: Conversation/thread identifier.
            url: Legacy alias for ``base_url``; used when ``base_url`` is empty.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``AgentApi.__init__`` and reused by ``for_chat``.

        Raises:
            TypeError: If neither ``base_url`` nor ``url`` is a non-empty string.
        """
        raw_url = base_url or url
        if not raw_url:
            # An empty URL can never produce a usable endpoint; fail early.
            raise TypeError("AgentApiWrapper requires base_url or url")
        self._base_url = self._normalize_base_url(raw_url)
        # Kept so for_chat() can rebuild an equivalent client.
        self._init_kwargs = dict(kwargs)
        self.chat_id = chat_id
        if self._supports_modern_constructor():
            super().__init__(
                agent_id=agent_id,
                base_url=self._base_url,
                chat_id=chat_id,
                **kwargs,
            )
        else:
            # Older upstream versions take a complete websocket URL instead.
            super().__init__(
                agent_id=agent_id,
                url=self._build_ws_url(self._base_url, chat_id),
                **kwargs,
            )
        self.last_tokens_used = 0

    @staticmethod
    def _supports_modern_constructor() -> bool:
        """Return True if ``AgentApi.__init__`` accepts ``base_url`` and ``chat_id``.

        Returns False when the signature cannot be introspected.
        """
        try:
            parameters = inspect.signature(AgentApi.__init__).parameters
        except (TypeError, ValueError):
            return False
        return "base_url" in parameters and "chat_id" in parameters

    @classmethod
    def _normalize_base_url(cls, base_url: str) -> str:
        """Strip a trailing agent websocket path, query and fragment from a URL.

        For example ``ws://host/v1/agent_ws/abc?x=1`` becomes ``ws://host``.
        """
        parsed = urlsplit(base_url)
        path = cls._WS_PATH_SUFFIX.sub("", parsed.path.rstrip("/"))
        return urlunsplit((parsed.scheme, parsed.netloc, path, "", ""))

    @staticmethod
    def _build_ws_url(base_url: str, chat_id: int | str) -> str:
        """Build the legacy websocket URL for ``chat_id`` under ``base_url``.

        ``chat_id`` is percent-encoded so string ids containing reserved
        characters (``&``, ``#``, spaces, ...) cannot corrupt the query string;
        integer and plain alphanumeric ids are emitted unchanged.
        """
        thread_id = quote(str(chat_id), safe="")
        return f"{base_url.rstrip('/')}/agent_ws/?thread_id={thread_id}"

    def for_chat(self, chat_id: int | str) -> "AgentApiWrapper":
        """Return a new client for the same agent, base URL and kwargs, bound to ``chat_id``."""
        return type(self)(
            agent_id=self.id,
            base_url=self._base_url,
            chat_id=chat_id,
            **self._init_kwargs,
        )

    @staticmethod
    def _event_kind(event: object) -> str:
        """Return an UPPER_SNAKE_CASE name describing ``event``'s kind.

        Uses ``event.type`` (or its ``.value`` for enum-like types), falling
        back to the class name. ``-`` becomes ``_``; names without ``_`` are
        split at lower→upper CamelCase boundaries (``MsgEventEnd`` →
        ``MSG_EVENT_END``).
        """
        raw_kind = getattr(event, "type", None)
        if hasattr(raw_kind, "value"):
            raw_kind = raw_kind.value
        if raw_kind is None:
            raw_kind = type(event).__name__
        kind = str(raw_kind).replace("-", "_")
        if "_" in kind:
            return kind.upper()
        normalized = []
        for index, char in enumerate(kind):
            if index and char.isupper() and not kind[index - 1].isupper():
                normalized.append("_")
            normalized.append(char)
        return "".join(normalized).upper()

    @classmethod
    def _is_kind(cls, event: object, *needles: str) -> bool:
        """Return True if any of ``needles`` is a substring of the event kind."""
        kind = cls._event_kind(event)
        return any(needle in kind for needle in needles)

    @classmethod
    def _is_text_event(cls, event: object) -> bool:
        """Return True for events carrying ``text`` or whose kind contains TEXT_CHUNK."""
        return hasattr(event, "text") or cls._is_kind(event, "TEXT_CHUNK")

    @classmethod
    def _is_end_event(cls, event: object) -> bool:
        """Return True when the event kind is ``END`` or ends with ``_END``."""
        kind = cls._event_kind(event)
        return kind == "END" or kind.endswith("_END")

    @classmethod
    def _is_send_file_event(cls, event: object) -> bool:
        """Return True when the event kind contains ``SEND_FILE``."""
        return "SEND_FILE" in cls._event_kind(event)

    @staticmethod
    def _error_from_event(event: object) -> AgentException:
        """Build an ``AgentException`` from an error event.

        Falls back to ``"UNKNOWN_ERROR"`` / ``str(event)`` when the event lacks
        ``code`` / ``details`` so a waiting request is still unblocked.
        """
        return AgentException(
            getattr(event, "code", "UNKNOWN_ERROR"),
            getattr(event, "details", str(event)),
        )

    async def _publish_event(self, event: object, *, queue_event: object | None = None) -> None:
        """Pass ``event`` to the callback, then enqueue ``queue_event`` (or ``event``).

        Each step is skipped when no callback / no active request queue is set.
        """
        if self.callback:
            self.callback(event)
        if self._current_queue:
            await self._current_queue.put(queue_event if queue_event is not None else event)

    async def _publish_error(self, event: object) -> None:
        """Pass an error event to the callback and enqueue it as an ``AgentException``."""
        if self.callback:
            self.callback(event)
        if self._current_queue:
            await self._current_queue.put(self._error_from_event(event))

    async def _route_server_message(self, outgoing_msg: object) -> bool:
        """Dispatch one parsed server message; return True to stop listening.

        Text chunks go to the active request queue, else to the callback, else
        are logged and dropped (the callback is not called when a queue is
        active). End events record ``tokens_used`` and are published. Error
        events are logged and published as errors. A graceful disconnect is
        published and stops the loop. Anything else is published unchanged.
        """
        if self._is_text_event(outgoing_msg):
            if self._current_queue:
                await self._current_queue.put(outgoing_msg)
            elif self.callback:
                self.callback(outgoing_msg)
            else:
                logger.warning("[%s] AgentEvent without active request", self.id)
            return False
        if self._is_end_event(outgoing_msg):
            # Kind matching is name-based, so an "*_END" event is not guaranteed
            # to carry tokens_used; only record it when the field exists.
            if hasattr(outgoing_msg, "tokens_used"):
                self.last_tokens_used = outgoing_msg.tokens_used
            await self._publish_event(outgoing_msg)
            return False
        if self._is_kind(outgoing_msg, "ERROR"):
            logger.error("[%s] Agent error: %s", self.id, self._error_from_event(outgoing_msg))
            await self._publish_error(outgoing_msg)
            return False
        if self._is_kind(outgoing_msg, "GRACEFUL_DISCONNECT"):
            await self._publish_event(outgoing_msg)
            logger.info("[%s] Gracefully disconnecting", self.id)
            return True
        await self._publish_event(outgoing_msg)
        return False

    async def _handle_text_frame(self, data: str) -> bool:
        """Parse and dispatch one websocket TEXT frame; return True to stop listening.

        Any failure while parsing or dispatching is logged and reported to the
        active request queue as a ``PARSE_ERROR`` so the loop keeps running.
        """
        try:
            outgoing_msg = ServerMessage.validate_json(data)
            return await self._route_server_message(outgoing_msg)
        except Exception as exc:
            logger.error("[%s] Failed to deserialize message: %s", self.id, exc)
            if self._current_queue:
                await self._current_queue.put(
                    AgentException("PARSE_ERROR", f"Validation failed: {exc}")
                )
            return False

    async def _listen(self):
        """Read websocket messages until close, error, or graceful disconnect.

        ``_cleanup()`` always runs when the loop ends, however it ends.
        """
        try:
            async for msg in self._ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    if await self._handle_text_frame(msg.data):
                        break
                elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
                    logger.error("[%s] WebSocket closed/error: %s", self.id, msg.type)
                    break
        except asyncio.CancelledError:
            # NOTE(review): cancellation is swallowed so the task completes
            # normally after cleanup; asyncio generally recommends re-raising.
            pass
        except Exception as exc:
            logger.exception("[%s] Error in listen loop: %s", self.id, exc)
        finally:
            await self._cleanup()