144 lines
5.2 KiB
Python
144 lines
5.2 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import inspect
|
|
import logging
|
|
import sys
|
|
import re
|
|
from urllib.parse import urlsplit, urlunsplit
|
|
from pathlib import Path
|
|
|
|
import aiohttp
|
|
|
|
# Vendored checkout of the agent API: <parent of this file's directory>/external/platform-agent_api.
_api_root = Path(__file__).resolve().parents[1].joinpath("external", "platform-agent_api")

# Prepend it to the import path (only once) so the ``lambda_agent_api`` imports
# that follow resolve against the vendored copy.
if str(_api_root) not in sys.path:
    sys.path.insert(0, str(_api_root))
|
|
|
|
from lambda_agent_api.agent_api import AgentApi, AgentException
|
|
from lambda_agent_api.server import (
|
|
MsgError,
|
|
MsgEventEnd,
|
|
MsgEventTextChunk,
|
|
MsgGracefulDisconnect,
|
|
ServerMessage,
|
|
)
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AgentApiWrapper(AgentApi):
    """Capture tokens_used from MsgEventEnd without patching upstream code.

    Thin subclass of the vendored ``AgentApi`` that:

    * accepts either ``base_url`` or a full ``url`` and normalizes it to a base URL;
    * adapts to both upstream constructor shapes: a newer one taking
      ``base_url`` + ``chat_id``, and an older one taking a full websocket ``url``;
    * overrides ``_listen`` so the ``tokens_used`` of the most recent
      ``MsgEventEnd`` is stored in ``last_tokens_used``.
    """

    def __init__(
        self,
        agent_id: str,
        base_url: str | None = None,
        *,
        chat_id: int | str = 0,
        url: str | None = None,
        **kwargs,
    ) -> None:
        """Create a client bound to ``agent_id`` and ``chat_id``.

        Args:
            agent_id: Agent identifier forwarded to ``AgentApi``.
            base_url: Service base URL; takes precedence over ``url``.
            chat_id: Chat the connection is bound to.
            url: Alternative to ``base_url``. A trailing
                ``[/v1]/agent_ws[/<segment>]`` path suffix is stripped, so any chat id
                embedded in it is discarded in favour of ``chat_id`` (which defaults
                to 0). NOTE(review): confirm callers never rely on the URL's chat id.
            **kwargs: Forwarded unchanged to ``AgentApi.__init__`` and reused by
                :meth:`for_chat`.

        Raises:
            TypeError: If neither ``base_url`` nor ``url`` is a non-empty string.
        """
        # Reject empty strings too: they previously slipped through and produced
        # a URL with no scheme/host.
        if not (base_url or url):
            raise TypeError("AgentApiWrapper requires base_url or url")

        self._base_url = self._normalize_base_url(base_url or url)
        # Kept so for_chat() can build sibling clients with the same options.
        self._init_kwargs = dict(kwargs)
        self.chat_id = chat_id
        if self._supports_modern_constructor():
            super().__init__(
                agent_id=agent_id,
                base_url=self._base_url,
                chat_id=chat_id,
                **kwargs,
            )
        else:
            # Older upstream signature: it expects the full per-chat websocket URL.
            super().__init__(
                agent_id=agent_id,
                url=self._build_ws_url(self._base_url, chat_id),
                **kwargs,
            )
        self.last_tokens_used = 0

    @staticmethod
    def _supports_modern_constructor() -> bool:
        """Return True if ``AgentApi.__init__`` accepts both ``base_url`` and ``chat_id``.

        Returns False when the signature cannot be introspected.
        """
        try:
            parameters = inspect.signature(AgentApi.__init__).parameters
        except (TypeError, ValueError):
            return False

        return "base_url" in parameters and "chat_id" in parameters

    @staticmethod
    def _normalize_base_url(base_url: str) -> str:
        """Reduce ``base_url`` to scheme + host + base path.

        Drops the query and fragment, trailing slashes, and a trailing
        ``[/v1]/agent_ws[/<segment>]`` suffix. For example,
        ``"ws://h/v1/agent_ws/7/"`` becomes ``"ws://h"``.
        """
        parsed = urlsplit(base_url)
        path = re.sub(r"(?:/v1)?/agent_ws(?:/[^/]+)?$", "", parsed.path.rstrip("/"))
        return urlunsplit((parsed.scheme, parsed.netloc, path, "", ""))

    @staticmethod
    def _build_ws_url(base_url: str, chat_id: int | str) -> str:
        """Return the per-chat websocket URL, e.g. ``"ws://h"``, 7 -> ``"ws://h/v1/agent_ws/7/"``."""
        return base_url.rstrip("/") + f"/v1/agent_ws/{chat_id}/"

    def for_chat(self, chat_id: int | str) -> "AgentApiWrapper":
        """Return a new client of the same type for ``chat_id``.

        The new client reuses this instance's agent id, base URL and extra
        constructor kwargs.
        """
        return type(self)(
            agent_id=self.id,
            base_url=self._base_url,
            chat_id=chat_id,
            **self._init_kwargs,
        )

    async def _listen(self):
        """Consume websocket frames until the connection ends, then clean up.

        Processing:

        * TEXT frames are parsed with ``ServerMessage.validate_json``. If parsing
          fails, an ``AgentException("PARSE_ERROR", ...)`` is queued for the active
          request, if any.
        * ``MsgEventTextChunk`` goes to the active request queue. Without a queue
          it goes to the callback, and failing that a warning is logged.
        * ``MsgEventEnd`` records ``tokens_used`` in ``last_tokens_used``, then
          goes to the queue.
        * ``MsgError`` is passed to the callback, then queued as an ``AgentException``.
        * ``MsgGracefulDisconnect`` is passed to the callback and stops the loop.
        * Any other parsed message is logged and passed to the callback.
        * ERROR/CLOSED frames stop the loop. Other frame types (e.g. binary) are ignored.

        ``self._cleanup()`` is always awaited on exit.
        """

        def notify(message) -> None:
            # Callback failures are logged rather than propagated. This means they
            # are not misreported as PARSE_ERROR, and they cannot stop an
            # AgentException from reaching the waiting request.
            if not self.callback:
                return
            try:
                self.callback(message)
            except Exception:
                logger.exception(
                    "[%s] Callback failed for %s", self.id, type(message).__name__
                )

        try:
            async for msg in self._ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    # Only parsing is guarded here; dispatch errors are handled separately.
                    try:
                        outgoing_msg = ServerMessage.validate_json(msg.data)
                    except Exception as exc:
                        logger.error("[%s] Failed to deserialize message: %s", self.id, exc)
                        if self._current_queue is not None:
                            await self._current_queue.put(
                                AgentException("PARSE_ERROR", f"Validation failed: {exc}")
                            )
                        continue

                    if isinstance(outgoing_msg, MsgEventTextChunk):
                        if self._current_queue is not None:
                            await self._current_queue.put(outgoing_msg)
                        elif self.callback:
                            notify(outgoing_msg)
                        else:
                            logger.warning("[%s] AgentEvent without active request", self.id)

                    elif isinstance(outgoing_msg, MsgEventEnd):
                        # The reason this override exists: remember token usage.
                        self.last_tokens_used = outgoing_msg.tokens_used
                        if self._current_queue is not None:
                            await self._current_queue.put(outgoing_msg)

                    elif isinstance(outgoing_msg, MsgError):
                        notify(outgoing_msg)
                        error = AgentException(outgoing_msg.code, outgoing_msg.details)
                        logger.error("[%s] Agent error: %s", self.id, error)
                        if self._current_queue is not None:
                            await self._current_queue.put(error)

                    elif isinstance(outgoing_msg, MsgGracefulDisconnect):
                        notify(outgoing_msg)
                        logger.info("[%s] Gracefully disconnecting", self.id)
                        break

                    else:
                        logger.warning(
                            "[%s] Unknown message type: %s",
                            self.id,
                            getattr(outgoing_msg, "type", type(outgoing_msg).__name__),
                        )
                        notify(outgoing_msg)

                elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
                    logger.error("[%s] WebSocket closed/error: %s", self.id, msg.type)
                    break

        except asyncio.CancelledError:
            # Cancellation ends the loop quietly; the task finishes without raising.
            # NOTE(review): swallowing CancelledError hides the cancellation from
            # asyncio.timeout/TaskGroup. Kept because callers may rely on it.
            pass
        except Exception as exc:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception("[%s] Error in listen loop: %s", self.id, exc)
        finally:
            await self._cleanup()
|