diff --git a/sdk/agent_api_wrapper.py b/sdk/agent_api_wrapper.py index f29f820..fa69816 100644 --- a/sdk/agent_api_wrapper.py +++ b/sdk/agent_api_wrapper.py @@ -1,67 +1,43 @@ from __future__ import annotations -import asyncio import inspect -import logging -import os import re import sys -from collections.abc import AsyncIterator from pathlib import Path from urllib.parse import urlsplit, urlunsplit -import aiohttp - _api_root = Path(__file__).resolve().parents[1] / "external" / "platform-agent_api" if str(_api_root) not in sys.path: sys.path.insert(0, str(_api_root)) -from lambda_agent_api.agent_api import AgentApi, AgentBusyException, AgentException # noqa: E402 -from lambda_agent_api.client import EClientMessage, MsgUserMessage # noqa: E402 -from lambda_agent_api.server import AgentEventUnion, MsgEventEnd, ServerMessage # noqa: E402 - -logger = logging.getLogger(__name__) -_DEBUG_STREAM = os.environ.get("SURFACES_AGENT_DEBUG_STREAM", "").strip().lower() in { - "1", - "true", - "yes", -} -_POST_END_DRAIN_MS = int(os.environ.get("SURFACES_AGENT_POST_END_DRAIN_MS", "120")) -_STREAM_IDLE_TIMEOUT_MS = int(os.environ.get("SURFACES_AGENT_IDLE_TIMEOUT_MS", "60000")) +from lambda_agent_api.agent_api import AgentApi # noqa: E402 class AgentApiWrapper(AgentApi): - """Capture tokens_used from MsgEventEnd without patching upstream code.""" + """Thin construction/factory shim over the pinned upstream AgentApi.""" def __init__( self, agent_id: str, - base_url: str | None = None, + base_url: str, *, chat_id: int | str = 0, - url: str | None = None, **kwargs, ) -> None: - if base_url is None and url is None: - raise TypeError("AgentApiWrapper requires base_url or url") - - self._base_url = self._normalize_base_url(base_url or url or "") + self._base_url = self._normalize_base_url(base_url) self._init_kwargs = dict(kwargs) self.chat_id = chat_id - if self._supports_modern_constructor(): - super().__init__( - agent_id=agent_id, - base_url=self._base_url, - chat_id=chat_id, - **kwargs, + if not self._supports_modern_constructor(): + raise RuntimeError( + "Pinned platform-agent_api is expected to support base_url + chat_id" ) - else: - super().__init__( - agent_id=agent_id, - url=self._build_ws_url(self._base_url, chat_id), - **kwargs, - ) - self.last_tokens_used = 0 + + super().__init__( + agent_id=agent_id, + base_url=self._base_url, + chat_id=chat_id, + **kwargs, + ) @staticmethod def _supports_modern_constructor() -> bool: @@ -69,247 +45,18 @@ class AgentApiWrapper(AgentApi): parameters = inspect.signature(AgentApi.__init__).parameters except (TypeError, ValueError): return False - return "base_url" in parameters and "chat_id" in parameters @staticmethod def _normalize_base_url(base_url: str) -> str: parsed = urlsplit(base_url) - path = re.sub(r"(?:/v1)?/agent_ws(?:/[^/]+)?$", "", parsed.path.rstrip("/")) + path = re.sub(r"(?:/v1)?/agent_ws(?:/[^/]+)?/?$", "", parsed.path.rstrip("/")) return urlunsplit((parsed.scheme, parsed.netloc, path, "", "")) - @staticmethod - def _build_ws_url(base_url: str, chat_id: int | str) -> str: - return base_url.rstrip("/") + f"/agent_ws/?thread_id={chat_id}" - - def for_chat(self, chat_id: int | str) -> AgentApiWrapper: + def for_chat(self, chat_id: int | str) -> "AgentApiWrapper": return type(self)( agent_id=self.id, base_url=self._base_url, chat_id=chat_id, **self._init_kwargs, ) - - @staticmethod - def _event_kind(event: object) -> str: - raw_kind = getattr(event, "type", None) - if hasattr(raw_kind, "value"): - raw_kind = raw_kind.value - if raw_kind is None: - raw_kind = event.__class__.__name__ - - kind = str(raw_kind).replace("-", "_") - if "_" in kind: - return kind.upper() - - normalized = [] - for index, char in enumerate(kind): - if index and char.isupper() and not kind[index - 1].isupper(): - normalized.append("_") - normalized.append(char) - return "".join(normalized).upper() - - @classmethod - def _is_kind(cls, event: object, *needles: str) -> bool: - kind = cls._event_kind(event) - return any(needle in kind for needle in needles) - - @classmethod - def _is_text_event(cls, event: object) -> bool: - return hasattr(event, "text") or cls._is_kind(event, "TEXT_CHUNK") - - @classmethod - def _is_end_event(cls, event: object) -> bool: - kind = cls._event_kind(event) - return kind == "END" or kind.endswith("_END") - - @classmethod - def _is_send_file_event(cls, event: object) -> bool: - return "SEND_FILE" in cls._event_kind(event) - - async def _publish_event(self, event: object, *, queue_event: object | None = None) -> None: - if self.callback: - self.callback(event) - if self._current_queue: - await self._current_queue.put(queue_event if queue_event is not None else event) - - async def _publish_error(self, event: object) -> None: - if self.callback: - self.callback(event) - if self._current_queue and hasattr(event, "code") and hasattr(event, "details"): - await self._current_queue.put(AgentException(event.code, event.details)) - - async def _listen(self): - try: - async for msg in self._ws: - if msg.type == aiohttp.WSMsgType.TEXT: - try: - outgoing_msg = ServerMessage.validate_json(msg.data) - - if self._is_text_event(outgoing_msg): - if _DEBUG_STREAM: - logger.warning( - "[%s] text chunk queue=%s text=%r", - self.id, - self._current_queue is not None, - getattr(outgoing_msg, "text", "")[:80], - ) - if self._current_queue: - await self._current_queue.put(outgoing_msg) - elif self.callback: - self.callback(outgoing_msg) - else: - logger.warning("[%s] AgentEvent without active request", self.id) - - elif self._is_end_event(outgoing_msg): - self.last_tokens_used = outgoing_msg.tokens_used - if _DEBUG_STREAM: - logger.warning( - "[%s] end event queue=%s tokens=%s", - self.id, - self._current_queue is not None, - getattr(outgoing_msg, "tokens_used", None), - ) - await self._publish_event(outgoing_msg) - - elif self._is_kind(outgoing_msg, "ERROR"): - error = AgentException(outgoing_msg.code, outgoing_msg.details) - logger.error("[%s] Agent error: %s", self.id, error) - await self._publish_error(outgoing_msg) - - elif self._is_kind(outgoing_msg, "GRACEFUL_DISCONNECT"): - await self._publish_event(outgoing_msg) - logger.info("[%s] Gracefully disconnecting", self.id) - break - - else: - await self._publish_event(outgoing_msg) - - except Exception as exc: - logger.error("[%s] Failed to deserialize message: %s", self.id, exc) - if self._current_queue: - await self._current_queue.put( - AgentException("PARSE_ERROR", f"Validation failed: {exc}") - ) - - elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED): - logger.error("[%s] WebSocket closed/error: %s", self.id, msg.type) - break - - except asyncio.CancelledError: - pass - except Exception as exc: - logger.error("[%s] Error in listen loop: %s", self.id, exc) - finally: - await self._cleanup() - - async def send_message( - self, text: str, attachments: list[str] | None = None - ) -> AsyncIterator[AgentEventUnion]: - if not self._connected or not self._ws: - raise AgentException( - code="NOT_CONNECTED", details="Not connected. Call connect() first." - ) - - if self._request_lock.locked(): - raise AgentBusyException("Agent is currently processing another request") - - await self._request_lock.acquire() - try: - self._current_queue = asyncio.Queue() - - message = MsgUserMessage( - type=EClientMessage.USER_MESSAGE, - text=text, - attachments=attachments or [], - ) - - await self._ws.send_str(message.model_dump_json()) - logger.debug("[%s] Sent message: %s...", self.id, text[:50]) - - while True: - try: - chunk = await asyncio.wait_for( - self._current_queue.get(), - timeout=max(_STREAM_IDLE_TIMEOUT_MS, 0) / 1000, - ) - except TimeoutError as exc: - raise AgentException( - "TIMEOUT", - ( - "Timed out waiting for the next agent stream event " - f"after {max(_STREAM_IDLE_TIMEOUT_MS, 0)}ms" - ), - ) from exc - - if isinstance(chunk, Exception): - raise chunk - - if isinstance(chunk, MsgEventEnd): - self.last_tokens_used = chunk.tokens_used - async for late_chunk in self._drain_post_end_events(): - yield late_chunk - break - - yield chunk - - finally: - if self._current_queue: - orphan_queue = self._current_queue - self._current_queue = None - - while not orphan_queue.empty(): - try: - orphan_msg = orphan_queue.get_nowait() - if isinstance(orphan_msg, Exception): - logger.debug( - "[%s] Dropped exception from queue during cleanup: %s", - self.id, - orphan_msg, - ) - continue - - if self.callback: - self.callback(orphan_msg) - else: - logger.debug("[%s] Dropped orphaned message during cleanup", self.id) - - except asyncio.QueueEmpty: - break - - if self._request_lock.locked(): - self._request_lock.release() - - async def _drain_post_end_events(self) -> AsyncIterator[AgentEventUnion]: - if self._current_queue is None: - return - - timeout_s = max(_POST_END_DRAIN_MS, 0) / 1000 - while True: - try: - chunk = await asyncio.wait_for(self._current_queue.get(), timeout=timeout_s) - except TimeoutError: - break - - if isinstance(chunk, Exception): - logger.warning("[%s] dropping post-END exception: %s", self.id, chunk) - continue - - if isinstance(chunk, MsgEventEnd): - self.last_tokens_used = chunk.tokens_used - if _DEBUG_STREAM: - logger.warning( - "[%s] dropped duplicate END tokens=%s", - self.id, - chunk.tokens_used, - ) - continue - - if _DEBUG_STREAM and self._is_text_event(chunk): - logger.warning( - "[%s] recovered post-END text chunk=%r", - self.id, - getattr(chunk, "text", "")[:80], - ) - - yield chunk diff --git a/tests/platform/test_real.py b/tests/platform/test_real.py index 2291d9d..382b554 100644 --- a/tests/platform/test_real.py +++ b/tests/platform/test_real.py @@ -1,7 +1,6 @@ import asyncio import pytest -from lambda_agent_api.server import MsgEventEnd, MsgEventTextChunk import sdk.agent_api_wrapper as agent_api_wrapper_module from core.protocol import SettingsAction @@ -142,233 +141,61 @@ class TextChunkEvent: self.type = "AGENT_EVENT_TEXT_CHUNK" self.text = text - -class ToolCallChunkEvent: - def __init__(self, payload: str) -> None: - self.type = "AGENT_EVENT_TOOL_CALL_CHUNK" - self.payload = payload - - -class ToolResultEvent: - def __init__(self, payload: str) -> None: - self.type = "AGENT_EVENT_TOOL_RESULT" - self.payload = payload - - -class CustomUpdateEvent: - def __init__(self, payload: str) -> None: - self.type = "AGENT_EVENT_CUSTOM_UPDATE" - self.payload = payload - - -class EndEvent: - def __init__(self, tokens_used: int) -> None: - self.type = "AGENT_EVENT_END" - self.tokens_used = tokens_used - - -class ErrorEvent: - def __init__(self, code: str, details: str) -> None: - self.type = "ERROR" - self.code = code - self.details = details - - -class GracefulDisconnectEvent: - def __init__(self) -> None: - self.type = "GRACEFUL_DISCONNECT" - - -class FakeWSMessage: - def __init__(self, data: str) -> None: - self.type = agent_api_wrapper_module.aiohttp.WSMsgType.TEXT - self.data = data - - -class FakeWebSocket: - def __init__(self, messages: list[FakeWSMessage]) -> None: - self._messages = list(messages) - - def __aiter__(self): - return self - - async def __anext__(self): - if not self._messages: - raise StopAsyncIteration - return self._messages.pop(0) - - -class QueueFeedingWebSocket: - def __init__(self, owner, queued_events: list[object]) -> None: - self.owner = owner - self.queued_events = list(queued_events) - self.sent_payloads: list[str] = [] - - async def send_str(self, payload: str) -> None: - self.sent_payloads.append(payload) - for event in self.queued_events: - await self.owner._current_queue.put(event) - - -class SilentWebSocket: - def __init__(self) -> None: - self.sent_payloads: list[str] = [] - - async def send_str(self, payload: str) -> None: - self.sent_payloads.append(payload) - - class MessageResponseWithAttachments(MessageResponse): attachments: list[Attachment] = [] -def test_agent_api_wrapper_uses_modern_constructor_when_available(monkeypatch): - calls: list[dict[str, object]] = [] +def test_agent_api_wrapper_normalizes_base_url_and_uses_modern_constructor(monkeypatch): + captured = {} - def fake_init(self, agent_id, base_url, chat_id, **kwargs): - calls.append( - { - "agent_id": agent_id, - "base_url": base_url, - "chat_id": chat_id, - "kwargs": kwargs, - } - ) - self.id = agent_id - self.url = base_url - self.callback = kwargs.get("callback") - self.on_disconnect = kwargs.get("on_disconnect") + def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs): + captured["agent_id"] = agent_id + captured["base_url"] = base_url + captured["chat_id"] = chat_id monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init) wrapper = AgentApiWrapper( agent_id="agent-1", - base_url="https://agent.example.com/v1/agent_ws", - chat_id="chat-1", - callback="cb", - on_disconnect="disconnect", - ) - child = wrapper.for_chat("chat-2") - - assert calls == [ - { - "agent_id": "agent-1", - "base_url": "https://agent.example.com", - "chat_id": "chat-1", - "kwargs": {"callback": "cb", "on_disconnect": "disconnect"}, - }, - { - "agent_id": "agent-1", - "base_url": "https://agent.example.com", - "chat_id": "chat-2", - "kwargs": {"callback": "cb", "on_disconnect": "disconnect"}, - }, - ] - assert wrapper._base_url == "https://agent.example.com" - assert wrapper.chat_id == "chat-1" - assert wrapper.last_tokens_used == 0 - assert child.chat_id == "chat-2" - - -def test_agent_api_wrapper_falls_back_to_legacy_url_constructor(monkeypatch): - calls: list[dict[str, object]] = [] - - def fake_init(self, agent_id, url, callback=None, on_disconnect=None): - calls.append( - { - "agent_id": agent_id, - "url": url, - "callback": callback, - "on_disconnect": on_disconnect, - } - ) - self.id = agent_id - self.url = url - self.callback = callback - self.on_disconnect = on_disconnect - - monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init) - - wrapper = AgentApiWrapper( - agent_id="agent-2", - url="https://agent.example.com/agent_ws/", - chat_id="chat-9", - callback="cb", + base_url="ws://platform-agent:8000/v1/agent_ws/", + chat_id="41", ) - assert calls == [ - { - "agent_id": "agent-2", - "url": "https://agent.example.com/agent_ws/?thread_id=chat-9", - "callback": "cb", - "on_disconnect": None, - } - ] - assert wrapper._base_url == "https://agent.example.com" - assert wrapper.chat_id == "chat-9" - assert wrapper.last_tokens_used == 0 + assert wrapper.chat_id == "41" + assert wrapper._base_url == "ws://platform-agent:8000" + assert captured == { + "agent_id": "agent-1", + "base_url": "ws://platform-agent:8000", + "chat_id": "41", + } -@pytest.mark.asyncio -async def test_agent_api_wrapper_recovers_late_text_after_first_end(monkeypatch): +def test_agent_api_wrapper_for_chat_reuses_normalized_base_url(monkeypatch): + init_calls = [] + def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs): self.id = agent_id + self.chat_id = chat_id self.url = base_url - self.callback = kwargs.get("callback") - self.on_disconnect = kwargs.get("on_disconnect") + init_calls.append((agent_id, base_url, chat_id)) monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init) - wrapper = AgentApiWrapper( + root = AgentApiWrapper( agent_id="agent-1", - base_url="https://agent.example.com/v1/agent_ws", - chat_id="chat-1", - ) - wrapper._connected = True - wrapper._request_lock = asyncio.Lock() - wrapper._current_queue = None - wrapper._ws = QueueFeedingWebSocket( - wrapper, - [ - MsgEventTextChunk(text="Иллюстра"), - MsgEventEnd(tokens_used=5), - MsgEventTextChunk(text="ция"), - MsgEventEnd(tokens_used=5), - ], + base_url="http://platform-agent:8000/v1/agent_ws/", + chat_id="1", ) - chunks = [] - async for chunk in wrapper.send_message("hello"): - chunks.append(chunk) + child = root.for_chat("99") - assert [chunk.text for chunk in chunks] == ["Иллюстра", "ция"] - assert wrapper.last_tokens_used == 5 - - -@pytest.mark.asyncio -async def test_agent_api_wrapper_times_out_on_idle_stream(monkeypatch): - def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs): - self.id = agent_id - self.url = base_url - self.callback = kwargs.get("callback") - self.on_disconnect = kwargs.get("on_disconnect") - - monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init) - monkeypatch.setattr(agent_api_wrapper_module, "_STREAM_IDLE_TIMEOUT_MS", 10) - - wrapper = AgentApiWrapper( - agent_id="agent-1", - base_url="https://agent.example.com/v1/agent_ws", - chat_id="chat-1", - ) - wrapper._connected = True - wrapper._request_lock = asyncio.Lock() - wrapper._current_queue = None - wrapper._ws = SilentWebSocket() - - with pytest.raises(agent_api_wrapper_module.AgentException, match="Timed out waiting"): - async for _ in wrapper.send_message("hello"): - pass + assert child is not root + assert child.chat_id == "99" + assert child._base_url == "http://platform-agent:8000" + assert init_calls == [ + ("agent-1", "http://platform-agent:8000", "1"), + ("agent-1", "http://platform-agent:8000", "99"), + ] @pytest.mark.asyncio @@ -703,87 +530,3 @@ async def test_real_platform_client_settings_are_local(): assert isinstance(settings, UserSettings) assert settings.skills["browser"] is True assert settings.skills["web-search"] is True - - -@pytest.mark.asyncio -async def test_agent_api_wrapper_transparently_surfaces_modern_events(monkeypatch): - callback_events: list[object] = [] - queue: asyncio.Queue = asyncio.Queue() - event_map = { - "text": TextChunkEvent("he"), - "tool_call": ToolCallChunkEvent("call"), - "tool_result": ToolResultEvent("result"), - "custom_update": CustomUpdateEvent("update"), - "send_file": SendFileEvent( - workspace_path="/workspace/report.pdf", - mime_type="application/pdf", - filename="report.pdf", - size=123, - ), - "end": EndEvent(tokens_used=11), - "error": ErrorEvent(code="BOOM", details="bad things"), - "disconnect": GracefulDisconnectEvent(), - } - - def fake_validate_json(data: str): - return event_map[data] - - monkeypatch.setattr( - agent_api_wrapper_module, - "ServerMessage", - type("FakeServerMessage", (), {"validate_json": staticmethod(fake_validate_json)}), - ) - - async def fake_cleanup(self): - return None - - monkeypatch.setattr(agent_api_wrapper_module.AgentApiWrapper, "_cleanup", fake_cleanup) - monkeypatch.setattr( - agent_api_wrapper_module.AgentApi, - "__init__", - lambda self, agent_id, base_url=None, chat_id=0, **kwargs: ( - setattr(self, "id", agent_id) - or setattr(self, "callback", kwargs.get("callback")) - or setattr(self, "on_disconnect", kwargs.get("on_disconnect")) - or setattr(self, "_current_queue", None) - ), - ) - - wrapper = AgentApiWrapper( - agent_id="agent-1", - base_url="https://agent.example.com/v1/agent_ws", - chat_id="chat-1", - callback=callback_events.append, - ) - wrapper._current_queue = queue - wrapper._ws = FakeWebSocket( - [ - FakeWSMessage("text"), - FakeWSMessage("tool_call"), - FakeWSMessage("tool_result"), - FakeWSMessage("custom_update"), - FakeWSMessage("send_file"), - FakeWSMessage("end"), - FakeWSMessage("error"), - FakeWSMessage("disconnect"), - ] - ) - - await wrapper._listen() - - queue_events = [] - while not queue.empty(): - queue_events.append(await queue.get()) - - assert queue_events[0].text == "he" - assert any(isinstance(event, SendFileEvent) for event in queue_events) - assert any(isinstance(event, EndEvent) for event in queue_events) - assert any(isinstance(event, GracefulDisconnectEvent) for event in queue_events) - assert callback_events[0].payload == "call" - assert callback_events[1].payload == "result" - assert callback_events[2].payload == "update" - assert any(isinstance(event, SendFileEvent) for event in callback_events) - assert any(isinstance(event, EndEvent) for event in callback_events) - assert any(isinstance(event, ErrorEvent) for event in callback_events) - assert any(isinstance(event, GracefulDisconnectEvent) for event in callback_events) - assert wrapper.last_tokens_used == 11