refactor: shrink agent api wrapper to thin adapter
This commit is contained in:
parent
4d917ac794
commit
569824ead1
2 changed files with 47 additions and 557 deletions
|
|
@ -1,67 +1,43 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlsplit, urlunsplit
|
||||
|
||||
import aiohttp
|
||||
|
||||
_api_root = Path(__file__).resolve().parents[1] / "external" / "platform-agent_api"
|
||||
if str(_api_root) not in sys.path:
|
||||
sys.path.insert(0, str(_api_root))
|
||||
|
||||
from lambda_agent_api.agent_api import AgentApi, AgentBusyException, AgentException # noqa: E402
|
||||
from lambda_agent_api.client import EClientMessage, MsgUserMessage # noqa: E402
|
||||
from lambda_agent_api.server import AgentEventUnion, MsgEventEnd, ServerMessage # noqa: E402
|
||||
|
||||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Extra stream diagnostics are enabled when SURFACES_AGENT_DEBUG_STREAM is
# "1", "true" or "yes" (compared after stripping whitespace and lowercasing).
_DEBUG_STREAM = (
    os.environ.get("SURFACES_AGENT_DEBUG_STREAM", "").strip().lower()
    in ("1", "true", "yes")
)

# Milliseconds to keep waiting for each late event once END has been seen
# (default 120). Parsed with int(), so a non-numeric value fails at import.
_POST_END_DRAIN_MS = int(
    os.environ.get("SURFACES_AGENT_POST_END_DRAIN_MS", "120")
)

# Milliseconds to wait for the next stream event before giving up
# (default 60000). Parsed with int(), so a non-numeric value fails at import.
_STREAM_IDLE_TIMEOUT_MS = int(
    os.environ.get("SURFACES_AGENT_IDLE_TIMEOUT_MS", "60000")
)
|
||||
from lambda_agent_api.agent_api import AgentApi # noqa: E402
|
||||
|
||||
|
||||
class AgentApiWrapper(AgentApi):
|
||||
"""Capture tokens_used from MsgEventEnd without patching upstream code."""
|
||||
"""Thin construction/factory shim over the pinned upstream AgentApi."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_id: str,
|
||||
base_url: str | None = None,
|
||||
base_url: str,
|
||||
*,
|
||||
chat_id: int | str = 0,
|
||||
url: str | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if base_url is None and url is None:
|
||||
raise TypeError("AgentApiWrapper requires base_url or url")
|
||||
|
||||
self._base_url = self._normalize_base_url(base_url or url or "")
|
||||
self._base_url = self._normalize_base_url(base_url)
|
||||
self._init_kwargs = dict(kwargs)
|
||||
self.chat_id = chat_id
|
||||
if self._supports_modern_constructor():
|
||||
super().__init__(
|
||||
agent_id=agent_id,
|
||||
base_url=self._base_url,
|
||||
chat_id=chat_id,
|
||||
**kwargs,
|
||||
if not self._supports_modern_constructor():
|
||||
raise RuntimeError(
|
||||
"Pinned platform-agent_api is expected to support base_url + chat_id"
|
||||
)
|
||||
else:
|
||||
super().__init__(
|
||||
agent_id=agent_id,
|
||||
url=self._build_ws_url(self._base_url, chat_id),
|
||||
**kwargs,
|
||||
)
|
||||
self.last_tokens_used = 0
|
||||
|
||||
super().__init__(
|
||||
agent_id=agent_id,
|
||||
base_url=self._base_url,
|
||||
chat_id=chat_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_modern_constructor() -> bool:
|
||||
|
|
@ -69,247 +45,18 @@ class AgentApiWrapper(AgentApi):
|
|||
parameters = inspect.signature(AgentApi.__init__).parameters
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
return "base_url" in parameters and "chat_id" in parameters
|
||||
|
||||
@staticmethod
|
||||
def _normalize_base_url(base_url: str) -> str:
|
||||
parsed = urlsplit(base_url)
|
||||
path = re.sub(r"(?:/v1)?/agent_ws(?:/[^/]+)?$", "", parsed.path.rstrip("/"))
|
||||
path = re.sub(r"(?:/v1)?/agent_ws(?:/[^/]+)?/?$", "", parsed.path.rstrip("/"))
|
||||
return urlunsplit((parsed.scheme, parsed.netloc, path, "", ""))
|
||||
|
||||
@staticmethod
|
||||
def _build_ws_url(base_url: str, chat_id: int | str) -> str:
|
||||
return base_url.rstrip("/") + f"/agent_ws/?thread_id={chat_id}"
|
||||
|
||||
def for_chat(self, chat_id: int | str) -> "AgentApiWrapper":
    """Create a sibling wrapper bound to another chat.

    The new object is built with the same concrete class, this
    instance's ``id`` and normalized base URL, and the extra keyword
    arguments saved in ``self._init_kwargs``; only ``chat_id`` differs.

    Args:
        chat_id: Identifier of the chat the new wrapper should target.

    Returns:
        A newly constructed instance of ``type(self)``.
    """
    wrapper_cls = type(self)
    return wrapper_cls(
        agent_id=self.id,
        base_url=self._base_url,
        chat_id=chat_id,
        **self._init_kwargs,
    )
|
||||
|
||||
@staticmethod
|
||||
def _event_kind(event: object) -> str:
|
||||
raw_kind = getattr(event, "type", None)
|
||||
if hasattr(raw_kind, "value"):
|
||||
raw_kind = raw_kind.value
|
||||
if raw_kind is None:
|
||||
raw_kind = event.__class__.__name__
|
||||
|
||||
kind = str(raw_kind).replace("-", "_")
|
||||
if "_" in kind:
|
||||
return kind.upper()
|
||||
|
||||
normalized = []
|
||||
for index, char in enumerate(kind):
|
||||
if index and char.isupper() and not kind[index - 1].isupper():
|
||||
normalized.append("_")
|
||||
normalized.append(char)
|
||||
return "".join(normalized).upper()
|
||||
|
||||
@classmethod
|
||||
def _is_kind(cls, event: object, *needles: str) -> bool:
|
||||
kind = cls._event_kind(event)
|
||||
return any(needle in kind for needle in needles)
|
||||
|
||||
@classmethod
|
||||
def _is_text_event(cls, event: object) -> bool:
|
||||
return hasattr(event, "text") or cls._is_kind(event, "TEXT_CHUNK")
|
||||
|
||||
@classmethod
|
||||
def _is_end_event(cls, event: object) -> bool:
|
||||
kind = cls._event_kind(event)
|
||||
return kind == "END" or kind.endswith("_END")
|
||||
|
||||
@classmethod
|
||||
def _is_send_file_event(cls, event: object) -> bool:
|
||||
return "SEND_FILE" in cls._event_kind(event)
|
||||
|
||||
async def _publish_event(self, event: object, *, queue_event: object | None = None) -> None:
|
||||
if self.callback:
|
||||
self.callback(event)
|
||||
if self._current_queue:
|
||||
await self._current_queue.put(queue_event if queue_event is not None else event)
|
||||
|
||||
async def _publish_error(self, event: object) -> None:
|
||||
if self.callback:
|
||||
self.callback(event)
|
||||
if self._current_queue and hasattr(event, "code") and hasattr(event, "details"):
|
||||
await self._current_queue.put(AgentException(event.code, event.details))
|
||||
|
||||
async def _listen(self):
    """Read frames from ``self._ws`` and route them until the socket ends.

    Each TEXT frame is parsed with ``ServerMessage.validate_json`` and
    routed by kind:

    * text events go to the active request queue; with no active request
      they go to ``self.callback``; with neither they are logged and
      dropped;
    * end events update ``self.last_tokens_used`` (when the event carries
      ``tokens_used``) and are published;
    * error events are logged and forwarded through ``_publish_error``;
    * a graceful-disconnect event is published and stops the loop;
    * every other event is published unchanged.

    Any exception raised while handling a single frame is logged and, if
    a request is active, enqueued as ``AgentException("PARSE_ERROR", ...)``
    so the consumer is not left waiting. ERROR and CLOSED frames stop the
    loop. ``self._cleanup()`` is awaited on every exit path.
    """
    try:
        async for msg in self._ws:
            if msg.type == aiohttp.WSMsgType.TEXT:
                try:
                    outgoing_msg = ServerMessage.validate_json(msg.data)

                    if self._is_text_event(outgoing_msg):
                        if _DEBUG_STREAM:
                            logger.warning(
                                "[%s] text chunk queue=%s text=%r",
                                self.id,
                                self._current_queue is not None,
                                # ``text`` may be missing or None on events
                                # matched by kind only; never fail a frame
                                # because of a debug log.
                                (getattr(outgoing_msg, "text", "") or "")[:80],
                            )
                        if self._current_queue:
                            await self._current_queue.put(outgoing_msg)
                        elif self.callback:
                            self.callback(outgoing_msg)
                        else:
                            logger.warning("[%s] AgentEvent without active request", self.id)

                    elif self._is_end_event(outgoing_msg):
                        # End events are recognised by kind label, so an
                        # end-like event is not guaranteed to carry
                        # ``tokens_used``; keep the previous count then
                        # instead of failing the whole frame.
                        if hasattr(outgoing_msg, "tokens_used"):
                            self.last_tokens_used = outgoing_msg.tokens_used
                        if _DEBUG_STREAM:
                            logger.warning(
                                "[%s] end event queue=%s tokens=%s",
                                self.id,
                                self._current_queue is not None,
                                getattr(outgoing_msg, "tokens_used", None),
                            )
                        await self._publish_event(outgoing_msg)

                    elif self._is_kind(outgoing_msg, "ERROR"):
                        error = AgentException(outgoing_msg.code, outgoing_msg.details)
                        logger.error("[%s] Agent error: %s", self.id, error)
                        await self._publish_error(outgoing_msg)

                    elif self._is_kind(outgoing_msg, "GRACEFUL_DISCONNECT"):
                        await self._publish_event(outgoing_msg)
                        logger.info("[%s] Gracefully disconnecting", self.id)
                        break

                    else:
                        await self._publish_event(outgoing_msg)

                except Exception as exc:
                    # Deliberately broad: one bad frame must not stop the
                    # listener. The failure is surfaced to the waiting
                    # request, if any.
                    logger.error("[%s] Failed to deserialize message: %s", self.id, exc)
                    if self._current_queue:
                        await self._current_queue.put(
                            AgentException("PARSE_ERROR", f"Validation failed: {exc}")
                        )

            elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
                logger.error("[%s] WebSocket closed/error: %s", self.id, msg.type)
                break

    except asyncio.CancelledError:
        # Cancellation ends the listener quietly; cleanup still runs below.
        pass
    except Exception as exc:
        logger.error("[%s] Error in listen loop: %s", self.id, exc)
    finally:
        await self._cleanup()
|
||||
|
||||
async def send_message(
    self, text: str, attachments: list[str] | None = None
) -> AsyncIterator[AgentEventUnion]:
    """Send a user message and yield the agent's events as they arrive.

    Args:
        text: Message body.
        attachments: Optional attachment list; ``None`` is sent as an
            empty list.

    Yields:
        Every queued event up to, but not including, the ``MsgEventEnd``
        marker, followed by any late events recovered by
        ``_drain_post_end_events``.

    Raises:
        AgentException: ``NOT_CONNECTED`` when there is no open
            connection; ``TIMEOUT`` when no event arrives within
            ``_STREAM_IDLE_TIMEOUT_MS``; or whatever exception the
            listener placed on the queue.
        AgentBusyException: When another request holds the request lock.
    """
    if not self._connected or not self._ws:
        raise AgentException(
            code="NOT_CONNECTED", details="Not connected. Call connect() first."
        )

    if self._request_lock.locked():
        raise AgentBusyException("Agent is currently processing another request")

    # Negative configuration values are clamped to zero.
    idle_timeout_ms = max(_STREAM_IDLE_TIMEOUT_MS, 0)

    await self._request_lock.acquire()
    try:
        self._current_queue = asyncio.Queue()

        message = MsgUserMessage(
            type=EClientMessage.USER_MESSAGE,
            text=text,
            attachments=attachments or [],
        )

        await self._ws.send_str(message.model_dump_json())
        logger.debug("[%s] Sent message: %s...", self.id, text[:50])

        while True:
            try:
                chunk = await asyncio.wait_for(
                    self._current_queue.get(),
                    timeout=idle_timeout_ms / 1000,
                )
            # asyncio.TimeoutError is the built-in TimeoutError on Python
            # 3.11+, but a separate class on older versions, where catching
            # only the built-in would let the timeout escape unconverted.
            except asyncio.TimeoutError as exc:
                raise AgentException(
                    "TIMEOUT",
                    (
                        "Timed out waiting for the next agent stream event "
                        f"after {idle_timeout_ms}ms"
                    ),
                ) from exc

            # The listener reports failures by enqueuing exception objects.
            if isinstance(chunk, Exception):
                raise chunk

            if isinstance(chunk, MsgEventEnd):
                self.last_tokens_used = chunk.tokens_used
                # Events can still trail the END marker; hand them to the
                # caller before finishing.
                async for late_chunk in self._drain_post_end_events():
                    yield late_chunk
                break

            yield chunk

    finally:
        if self._current_queue:
            # Detach the queue first so nothing new is routed to it.
            orphan_queue = self._current_queue
            self._current_queue = None

            while not orphan_queue.empty():
                try:
                    orphan_msg = orphan_queue.get_nowait()
                    if isinstance(orphan_msg, Exception):
                        logger.debug(
                            "[%s] Dropped exception from queue during cleanup: %s",
                            self.id,
                            orphan_msg,
                        )
                        continue

                    # Leftover events are passed to the callback rather
                    # than lost, when one is registered.
                    if self.callback:
                        self.callback(orphan_msg)
                    else:
                        logger.debug("[%s] Dropped orphaned message during cleanup", self.id)

                except asyncio.QueueEmpty:
                    break

        if self._request_lock.locked():
            self._request_lock.release()
|
||||
|
||||
async def _drain_post_end_events(self) -> AsyncIterator[AgentEventUnion]:
|
||||
if self._current_queue is None:
|
||||
return
|
||||
|
||||
timeout_s = max(_POST_END_DRAIN_MS, 0) / 1000
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(self._current_queue.get(), timeout=timeout_s)
|
||||
except TimeoutError:
|
||||
break
|
||||
|
||||
if isinstance(chunk, Exception):
|
||||
logger.warning("[%s] dropping post-END exception: %s", self.id, chunk)
|
||||
continue
|
||||
|
||||
if isinstance(chunk, MsgEventEnd):
|
||||
self.last_tokens_used = chunk.tokens_used
|
||||
if _DEBUG_STREAM:
|
||||
logger.warning(
|
||||
"[%s] dropped duplicate END tokens=%s",
|
||||
self.id,
|
||||
chunk.tokens_used,
|
||||
)
|
||||
continue
|
||||
|
||||
if _DEBUG_STREAM and self._is_text_event(chunk):
|
||||
logger.warning(
|
||||
"[%s] recovered post-END text chunk=%r",
|
||||
self.id,
|
||||
getattr(chunk, "text", "")[:80],
|
||||
)
|
||||
|
||||
yield chunk
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue