refactor: shrink agent api wrapper to thin adapter

This commit is contained in:
Mikhail Putilovskij 2026-04-22 00:22:20 +03:00
parent 4d917ac794
commit 569824ead1
2 changed files with 47 additions and 557 deletions

View file

@ -1,67 +1,43 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import inspect import inspect
import logging
import os
import re import re
import sys import sys
from collections.abc import AsyncIterator
from pathlib import Path from pathlib import Path
from urllib.parse import urlsplit, urlunsplit from urllib.parse import urlsplit, urlunsplit
import aiohttp
_api_root = Path(__file__).resolve().parents[1] / "external" / "platform-agent_api" _api_root = Path(__file__).resolve().parents[1] / "external" / "platform-agent_api"
if str(_api_root) not in sys.path: if str(_api_root) not in sys.path:
sys.path.insert(0, str(_api_root)) sys.path.insert(0, str(_api_root))
from lambda_agent_api.agent_api import AgentApi, AgentBusyException, AgentException # noqa: E402 from lambda_agent_api.agent_api import AgentApi # noqa: E402
from lambda_agent_api.client import EClientMessage, MsgUserMessage # noqa: E402
from lambda_agent_api.server import AgentEventUnion, MsgEventEnd, ServerMessage # noqa: E402
logger = logging.getLogger(__name__)
_DEBUG_STREAM = os.environ.get("SURFACES_AGENT_DEBUG_STREAM", "").strip().lower() in {
"1",
"true",
"yes",
}
_POST_END_DRAIN_MS = int(os.environ.get("SURFACES_AGENT_POST_END_DRAIN_MS", "120"))
_STREAM_IDLE_TIMEOUT_MS = int(os.environ.get("SURFACES_AGENT_IDLE_TIMEOUT_MS", "60000"))
class AgentApiWrapper(AgentApi): class AgentApiWrapper(AgentApi):
"""Capture tokens_used from MsgEventEnd without patching upstream code.""" """Thin construction/factory shim over the pinned upstream AgentApi."""
def __init__( def __init__(
self, self,
agent_id: str, agent_id: str,
base_url: str | None = None, base_url: str,
*, *,
chat_id: int | str = 0, chat_id: int | str = 0,
url: str | None = None,
**kwargs, **kwargs,
) -> None: ) -> None:
if base_url is None and url is None: self._base_url = self._normalize_base_url(base_url)
raise TypeError("AgentApiWrapper requires base_url or url")
self._base_url = self._normalize_base_url(base_url or url or "")
self._init_kwargs = dict(kwargs) self._init_kwargs = dict(kwargs)
self.chat_id = chat_id self.chat_id = chat_id
if self._supports_modern_constructor(): if not self._supports_modern_constructor():
super().__init__( raise RuntimeError(
agent_id=agent_id, "Pinned platform-agent_api is expected to support base_url + chat_id"
base_url=self._base_url,
chat_id=chat_id,
**kwargs,
) )
else:
super().__init__( super().__init__(
agent_id=agent_id, agent_id=agent_id,
url=self._build_ws_url(self._base_url, chat_id), base_url=self._base_url,
**kwargs, chat_id=chat_id,
) **kwargs,
self.last_tokens_used = 0 )
@staticmethod @staticmethod
def _supports_modern_constructor() -> bool: def _supports_modern_constructor() -> bool:
@ -69,247 +45,18 @@ class AgentApiWrapper(AgentApi):
parameters = inspect.signature(AgentApi.__init__).parameters parameters = inspect.signature(AgentApi.__init__).parameters
except (TypeError, ValueError): except (TypeError, ValueError):
return False return False
return "base_url" in parameters and "chat_id" in parameters return "base_url" in parameters and "chat_id" in parameters
@staticmethod @staticmethod
def _normalize_base_url(base_url: str) -> str: def _normalize_base_url(base_url: str) -> str:
parsed = urlsplit(base_url) parsed = urlsplit(base_url)
path = re.sub(r"(?:/v1)?/agent_ws(?:/[^/]+)?$", "", parsed.path.rstrip("/")) path = re.sub(r"(?:/v1)?/agent_ws(?:/[^/]+)?/?$", "", parsed.path.rstrip("/"))
return urlunsplit((parsed.scheme, parsed.netloc, path, "", "")) return urlunsplit((parsed.scheme, parsed.netloc, path, "", ""))
@staticmethod def for_chat(self, chat_id: int | str) -> "AgentApiWrapper":
def _build_ws_url(base_url: str, chat_id: int | str) -> str:
return base_url.rstrip("/") + f"/agent_ws/?thread_id={chat_id}"
def for_chat(self, chat_id: int | str) -> AgentApiWrapper:
return type(self)( return type(self)(
agent_id=self.id, agent_id=self.id,
base_url=self._base_url, base_url=self._base_url,
chat_id=chat_id, chat_id=chat_id,
**self._init_kwargs, **self._init_kwargs,
) )
@staticmethod
def _event_kind(event: object) -> str:
    """Return an upper-case, underscore-separated kind label for *event*.

    The label comes from ``event.type`` (unwrapped through ``.value`` when
    that attribute exists); when there is no usable ``type`` the event's
    class name is used instead. Dashes become underscores. A label that
    already contains an underscore is only upper-cased; otherwise an
    underscore is inserted before every upper-case character that follows
    a non-upper-case one (``MsgEventEnd`` -> ``MSG_EVENT_END``).
    """
    tag = getattr(event, "type", None)
    tag = getattr(tag, "value", tag)
    if tag is None:
        tag = event.__class__.__name__

    label = str(tag).replace("-", "_")
    if "_" not in label:
        # Pair every character with its predecessor ("" for the first one)
        # so word boundaries can be detected without index arithmetic.
        neighbours = zip([""] + list(label), label)
        label = "".join(
            "_" + cur if prev and cur.isupper() and not prev.isupper() else cur
            for prev, cur in neighbours
        )
    return label.upper()
@classmethod
def _is_kind(cls, event: object, *needles: str) -> bool:
    """Return True when any of *needles* is a substring of the event's kind label."""
    label = cls._event_kind(event)
    for needle in needles:
        if needle in label:
            return True
    return False
@classmethod
def _is_text_event(cls, event: object) -> bool:
    """Return True for events that expose ``text`` or whose kind contains ``TEXT_CHUNK``."""
    if hasattr(event, "text"):
        return True
    return cls._is_kind(event, "TEXT_CHUNK")
@classmethod
def _is_end_event(cls, event: object) -> bool:
    """Return True when the event's kind is ``END`` or ends with ``_END``."""
    # Both accepted spellings share one property: the text after the last
    # underscore (or the whole label when there is none) is exactly "END".
    _, _, last_segment = cls._event_kind(event).rpartition("_")
    return last_segment == "END"
@classmethod
def _is_send_file_event(cls, event: object) -> bool:
    """Return True when the event's kind label contains ``SEND_FILE``."""
    return cls._is_kind(event, "SEND_FILE")
async def _publish_event(self, event: object, *, queue_event: object | None = None) -> None:
    """Hand *event* to the callback and enqueue it for the active request.

    The callback (when set) always receives *event*. The request queue
    (when set) receives *queue_event* if one was given, otherwise *event*.
    """
    if self.callback:
        self.callback(event)

    queue = self._current_queue
    if not queue:
        return

    payload = event if queue_event is None else queue_event
    await queue.put(payload)
async def _publish_error(self, event: object) -> None:
    """Report an error event to the callback and to the active request.

    The callback (when set) receives the raw *event*. When a request queue
    is active and the event carries both ``code`` and ``details``, an
    ``AgentException`` built from them is enqueued instead of the event.
    """
    if self.callback:
        self.callback(event)

    queue = self._current_queue
    if not queue:
        return
    if not (hasattr(event, "code") and hasattr(event, "details")):
        return
    await queue.put(AgentException(event.code, event.details))
async def _listen(self):
    """Read WebSocket frames and route each parsed server message.

    Runs until the socket iterator is exhausted, an ERROR/CLOSED frame
    arrives, a graceful-disconnect event is received, or the task is
    cancelled. ``self._cleanup()`` is always awaited on the way out.

    Routing of TEXT frames (after ``ServerMessage.validate_json``):
      * text events go to the active request queue if there is one,
        otherwise to the callback, otherwise they are only logged;
      * end events update ``last_tokens_used`` and are published;
      * error events are logged and published via ``_publish_error``;
      * graceful-disconnect events are published and stop the loop;
      * everything else is published unchanged.

    Any exception raised while handling a single frame is logged and, when
    a request is active, surfaced to it as a ``PARSE_ERROR``
    ``AgentException``; the loop then continues with the next frame.
    """
    try:
        async for msg in self._ws:
            if msg.type == aiohttp.WSMsgType.TEXT:
                try:
                    outgoing_msg = ServerMessage.validate_json(msg.data)
                    if self._is_text_event(outgoing_msg):
                        if _DEBUG_STREAM:
                            logger.warning(
                                "[%s] text chunk queue=%s text=%r",
                                self.id,
                                self._current_queue is not None,
                                getattr(outgoing_msg, "text", "")[:80],
                            )
                        if self._current_queue:
                            await self._current_queue.put(outgoing_msg)
                        elif self.callback:
                            self.callback(outgoing_msg)
                        else:
                            logger.warning("[%s] AgentEvent without active request", self.id)
                    elif self._is_end_event(outgoing_msg):
                        # _is_end_event matches every kind ending in "_END",
                        # so the event is not guaranteed to carry a token
                        # count. Only overwrite the counter when it does;
                        # an unguarded read would raise AttributeError and
                        # be misreported below as a deserialization failure.
                        if hasattr(outgoing_msg, "tokens_used"):
                            self.last_tokens_used = outgoing_msg.tokens_used
                        if _DEBUG_STREAM:
                            logger.warning(
                                "[%s] end event queue=%s tokens=%s",
                                self.id,
                                self._current_queue is not None,
                                getattr(outgoing_msg, "tokens_used", None),
                            )
                        await self._publish_event(outgoing_msg)
                    elif self._is_kind(outgoing_msg, "ERROR"):
                        error = AgentException(outgoing_msg.code, outgoing_msg.details)
                        logger.error("[%s] Agent error: %s", self.id, error)
                        await self._publish_error(outgoing_msg)
                    elif self._is_kind(outgoing_msg, "GRACEFUL_DISCONNECT"):
                        await self._publish_event(outgoing_msg)
                        logger.info("[%s] Gracefully disconnecting", self.id)
                        break
                    else:
                        await self._publish_event(outgoing_msg)
                except Exception as exc:
                    logger.error("[%s] Failed to deserialize message: %s", self.id, exc)
                    if self._current_queue:
                        await self._current_queue.put(
                            AgentException("PARSE_ERROR", f"Validation failed: {exc}")
                        )
            elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
                logger.error("[%s] WebSocket closed/error: %s", self.id, msg.type)
                break
    except asyncio.CancelledError:
        # Cancellation is treated as a normal shutdown of the listener.
        pass
    except Exception as exc:
        logger.error("[%s] Error in listen loop: %s", self.id, exc)
    finally:
        await self._cleanup()
async def send_message(
    self, text: str, attachments: list[str] | None = None
) -> AsyncIterator[AgentEventUnion]:
    """Send one user message and yield the events queued for it.

    Args:
        text: Message text to send.
        attachments: Optional attachment list; ``None`` is sent as ``[]``.

    Yields:
        Every item taken from the request queue up to the first
        ``MsgEventEnd``, followed by whatever ``_drain_post_end_events``
        recovers after it. The END event itself is not yielded.

    Raises:
        AgentException: ``NOT_CONNECTED`` when there is no live socket,
            ``TIMEOUT`` when no event arrives within
            ``_STREAM_IDLE_TIMEOUT_MS``; exception objects found on the
            queue are re-raised as-is.
        AgentBusyException: when another request holds the request lock.
    """
    if not self._connected or not self._ws:
        raise AgentException(
            code="NOT_CONNECTED", details="Not connected. Call connect() first."
        )
    if self._request_lock.locked():
        raise AgentBusyException("Agent is currently processing another request")

    await self._request_lock.acquire()
    try:
        self._current_queue = asyncio.Queue()
        message = MsgUserMessage(
            type=EClientMessage.USER_MESSAGE,
            text=text,
            attachments=attachments or [],
        )
        await self._ws.send_str(message.model_dump_json())
        logger.debug("[%s] Sent message: %s...", self.id, text[:50])

        idle_timeout_ms = max(_STREAM_IDLE_TIMEOUT_MS, 0)
        while True:
            try:
                chunk = await asyncio.wait_for(
                    self._current_queue.get(),
                    timeout=idle_timeout_ms / 1000,
                )
            # asyncio.TimeoutError is the class wait_for() raises on every
            # Python version; the builtin TimeoutError only aliases it from
            # 3.11 onwards, so catching the builtin misses it on older runtimes.
            except asyncio.TimeoutError as exc:
                raise AgentException(
                    "TIMEOUT",
                    (
                        "Timed out waiting for the next agent stream event "
                        f"after {idle_timeout_ms}ms"
                    ),
                ) from exc
            if isinstance(chunk, Exception):
                raise chunk
            if isinstance(chunk, MsgEventEnd):
                self.last_tokens_used = chunk.tokens_used
                async for late_chunk in self._drain_post_end_events():
                    yield late_chunk
                break
            yield chunk
    finally:
        if self._current_queue:
            # Detach the queue first so nothing new is routed to it, then
            # hand leftover events to the callback rather than losing them.
            orphan_queue = self._current_queue
            self._current_queue = None
            while not orphan_queue.empty():
                try:
                    orphan_msg = orphan_queue.get_nowait()
                    if isinstance(orphan_msg, Exception):
                        logger.debug(
                            "[%s] Dropped exception from queue during cleanup: %s",
                            self.id,
                            orphan_msg,
                        )
                        continue
                    if self.callback:
                        self.callback(orphan_msg)
                    else:
                        logger.debug("[%s] Dropped orphaned message during cleanup", self.id)
                except asyncio.QueueEmpty:
                    break
        if self._request_lock.locked():
            self._request_lock.release()
async def _drain_post_end_events(self) -> AsyncIterator[AgentEventUnion]:
    """Yield events that still reach the request queue after an END event.

    Keeps reading the current queue until no item arrives within
    ``_POST_END_DRAIN_MS`` milliseconds. Queued exceptions are logged and
    dropped; further ``MsgEventEnd`` items update ``last_tokens_used`` and
    are dropped; everything else is yielded. Yields nothing when there is
    no active queue.
    """
    if self._current_queue is None:
        return
    timeout_s = max(_POST_END_DRAIN_MS, 0) / 1000
    while True:
        try:
            chunk = await asyncio.wait_for(self._current_queue.get(), timeout=timeout_s)
        # Catch asyncio.TimeoutError rather than the builtin: they are the
        # same class only on Python >= 3.11. On older versions the builtin
        # would not match and the drain timeout would escape to the caller.
        except asyncio.TimeoutError:
            break
        if isinstance(chunk, Exception):
            logger.warning("[%s] dropping post-END exception: %s", self.id, chunk)
            continue
        if isinstance(chunk, MsgEventEnd):
            self.last_tokens_used = chunk.tokens_used
            if _DEBUG_STREAM:
                logger.warning(
                    "[%s] dropped duplicate END tokens=%s",
                    self.id,
                    chunk.tokens_used,
                )
            continue
        if _DEBUG_STREAM and self._is_text_event(chunk):
            logger.warning(
                "[%s] recovered post-END text chunk=%r",
                self.id,
                getattr(chunk, "text", "")[:80],
            )
        yield chunk

View file

@ -1,7 +1,6 @@
import asyncio import asyncio
import pytest import pytest
from lambda_agent_api.server import MsgEventEnd, MsgEventTextChunk
import sdk.agent_api_wrapper as agent_api_wrapper_module import sdk.agent_api_wrapper as agent_api_wrapper_module
from core.protocol import SettingsAction from core.protocol import SettingsAction
@ -142,233 +141,61 @@ class TextChunkEvent:
self.type = "AGENT_EVENT_TEXT_CHUNK" self.type = "AGENT_EVENT_TEXT_CHUNK"
self.text = text self.text = text
class ToolCallChunkEvent:
    """Test double for a tool-call chunk event carrying an opaque payload."""

    def __init__(self, payload: str) -> None:
        self.type, self.payload = "AGENT_EVENT_TOOL_CALL_CHUNK", payload
class ToolResultEvent:
    """Test double for a tool-result event carrying an opaque payload."""

    def __init__(self, payload: str) -> None:
        self.type, self.payload = "AGENT_EVENT_TOOL_RESULT", payload
class CustomUpdateEvent:
    """Test double for a custom-update event carrying an opaque payload."""

    def __init__(self, payload: str) -> None:
        self.type, self.payload = "AGENT_EVENT_CUSTOM_UPDATE", payload
class EndEvent:
    """Test double for an end-of-stream event reporting a token count."""

    def __init__(self, tokens_used: int) -> None:
        self.type, self.tokens_used = "AGENT_EVENT_END", tokens_used
class ErrorEvent:
    """Test double for an error event with a code and human-readable details."""

    def __init__(self, code: str, details: str) -> None:
        self.type = "ERROR"
        self.code, self.details = code, details
class GracefulDisconnectEvent:
    """Test double for a graceful-disconnect notification (no payload)."""

    def __init__(self) -> None:
        setattr(self, "type", "GRACEFUL_DISCONNECT")
class FakeWSMessage:
    """Stand-in for a WebSocket frame: always a TEXT frame with the given data."""

    def __init__(self, data: str) -> None:
        text_frame = agent_api_wrapper_module.aiohttp.WSMsgType.TEXT
        self.type = text_frame
        self.data = data
class FakeWebSocket:
    """Async-iterable stand-in for a WebSocket that replays canned messages in order."""

    def __init__(self, messages: list[FakeWSMessage]) -> None:
        # Copy so that consuming the iterator never mutates the caller's list.
        self._messages = [*messages]

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return self._messages.pop(0)
        except IndexError:
            raise StopAsyncIteration from None
class QueueFeedingWebSocket:
    """Fake socket whose ``send_str`` records the payload and feeds the owner's queue.

    Each ``send_str`` call appends the payload to ``sent_payloads`` and then
    puts every event from ``queued_events`` onto ``owner._current_queue``.
    """

    def __init__(self, owner, queued_events: list[object]) -> None:
        self.owner = owner
        self.queued_events = [*queued_events]
        self.sent_payloads: list[str] = []

    async def send_str(self, payload: str) -> None:
        self.sent_payloads.append(payload)
        pending = iter(self.queued_events)
        for queued in pending:
            await self.owner._current_queue.put(queued)
class SilentWebSocket:
    """Fake socket that records outgoing payloads and never produces a reply."""

    def __init__(self) -> None:
        self.sent_payloads: list[str] = []

    async def send_str(self, payload: str) -> None:
        self.sent_payloads += [payload]
class MessageResponseWithAttachments(MessageResponse): class MessageResponseWithAttachments(MessageResponse):
attachments: list[Attachment] = [] attachments: list[Attachment] = []
def test_agent_api_wrapper_uses_modern_constructor_when_available(monkeypatch): def test_agent_api_wrapper_normalizes_base_url_and_uses_modern_constructor(monkeypatch):
calls: list[dict[str, object]] = [] captured = {}
def fake_init(self, agent_id, base_url, chat_id, **kwargs): def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
calls.append( captured["agent_id"] = agent_id
{ captured["base_url"] = base_url
"agent_id": agent_id, captured["chat_id"] = chat_id
"base_url": base_url,
"chat_id": chat_id,
"kwargs": kwargs,
}
)
self.id = agent_id
self.url = base_url
self.callback = kwargs.get("callback")
self.on_disconnect = kwargs.get("on_disconnect")
monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init) monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)
wrapper = AgentApiWrapper( wrapper = AgentApiWrapper(
agent_id="agent-1", agent_id="agent-1",
base_url="https://agent.example.com/v1/agent_ws", base_url="ws://platform-agent:8000/v1/agent_ws/",
chat_id="chat-1", chat_id="41",
callback="cb",
on_disconnect="disconnect",
)
child = wrapper.for_chat("chat-2")
assert calls == [
{
"agent_id": "agent-1",
"base_url": "https://agent.example.com",
"chat_id": "chat-1",
"kwargs": {"callback": "cb", "on_disconnect": "disconnect"},
},
{
"agent_id": "agent-1",
"base_url": "https://agent.example.com",
"chat_id": "chat-2",
"kwargs": {"callback": "cb", "on_disconnect": "disconnect"},
},
]
assert wrapper._base_url == "https://agent.example.com"
assert wrapper.chat_id == "chat-1"
assert wrapper.last_tokens_used == 0
assert child.chat_id == "chat-2"
def test_agent_api_wrapper_falls_back_to_legacy_url_constructor(monkeypatch):
calls: list[dict[str, object]] = []
def fake_init(self, agent_id, url, callback=None, on_disconnect=None):
calls.append(
{
"agent_id": agent_id,
"url": url,
"callback": callback,
"on_disconnect": on_disconnect,
}
)
self.id = agent_id
self.url = url
self.callback = callback
self.on_disconnect = on_disconnect
monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)
wrapper = AgentApiWrapper(
agent_id="agent-2",
url="https://agent.example.com/agent_ws/",
chat_id="chat-9",
callback="cb",
) )
assert calls == [ assert wrapper.chat_id == "41"
{ assert wrapper._base_url == "ws://platform-agent:8000"
"agent_id": "agent-2", assert captured == {
"url": "https://agent.example.com/agent_ws/?thread_id=chat-9", "agent_id": "agent-1",
"callback": "cb", "base_url": "ws://platform-agent:8000",
"on_disconnect": None, "chat_id": "41",
} }
]
assert wrapper._base_url == "https://agent.example.com"
assert wrapper.chat_id == "chat-9"
assert wrapper.last_tokens_used == 0
@pytest.mark.asyncio def test_agent_api_wrapper_for_chat_reuses_normalized_base_url(monkeypatch):
async def test_agent_api_wrapper_recovers_late_text_after_first_end(monkeypatch): init_calls = []
def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs): def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
self.id = agent_id self.id = agent_id
self.chat_id = chat_id
self.url = base_url self.url = base_url
self.callback = kwargs.get("callback") init_calls.append((agent_id, base_url, chat_id))
self.on_disconnect = kwargs.get("on_disconnect")
monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init) monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)
wrapper = AgentApiWrapper( root = AgentApiWrapper(
agent_id="agent-1", agent_id="agent-1",
base_url="https://agent.example.com/v1/agent_ws", base_url="http://platform-agent:8000/v1/agent_ws/",
chat_id="chat-1", chat_id="1",
)
wrapper._connected = True
wrapper._request_lock = asyncio.Lock()
wrapper._current_queue = None
wrapper._ws = QueueFeedingWebSocket(
wrapper,
[
MsgEventTextChunk(text="Иллюстра"),
MsgEventEnd(tokens_used=5),
MsgEventTextChunk(text="ция"),
MsgEventEnd(tokens_used=5),
],
) )
chunks = [] child = root.for_chat("99")
async for chunk in wrapper.send_message("hello"):
chunks.append(chunk)
assert [chunk.text for chunk in chunks] == ["Иллюстра", "ция"] assert child is not root
assert wrapper.last_tokens_used == 5 assert child.chat_id == "99"
assert child._base_url == "http://platform-agent:8000"
assert init_calls == [
@pytest.mark.asyncio ("agent-1", "http://platform-agent:8000", "1"),
async def test_agent_api_wrapper_times_out_on_idle_stream(monkeypatch): ("agent-1", "http://platform-agent:8000", "99"),
def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs): ]
self.id = agent_id
self.url = base_url
self.callback = kwargs.get("callback")
self.on_disconnect = kwargs.get("on_disconnect")
monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)
monkeypatch.setattr(agent_api_wrapper_module, "_STREAM_IDLE_TIMEOUT_MS", 10)
wrapper = AgentApiWrapper(
agent_id="agent-1",
base_url="https://agent.example.com/v1/agent_ws",
chat_id="chat-1",
)
wrapper._connected = True
wrapper._request_lock = asyncio.Lock()
wrapper._current_queue = None
wrapper._ws = SilentWebSocket()
with pytest.raises(agent_api_wrapper_module.AgentException, match="Timed out waiting"):
async for _ in wrapper.send_message("hello"):
pass
@pytest.mark.asyncio @pytest.mark.asyncio
@ -703,87 +530,3 @@ async def test_real_platform_client_settings_are_local():
assert isinstance(settings, UserSettings) assert isinstance(settings, UserSettings)
assert settings.skills["browser"] is True assert settings.skills["browser"] is True
assert settings.skills["web-search"] is True assert settings.skills["web-search"] is True
@pytest.mark.asyncio
async def test_agent_api_wrapper_transparently_surfaces_modern_events(monkeypatch):
    """Feed one frame of every event kind through ``_listen`` and check routing.

    With a request queue active, the text chunk must reach the queue only,
    the other events must reach both queue and callback (errors: callback
    only, as an ``ErrorEvent``), and the END event must update
    ``last_tokens_used``.
    """
    callback_events: list[object] = []
    queue: asyncio.Queue = asyncio.Queue()

    # Frame payloads are just keys into this table; insertion order is the
    # order in which the frames are replayed below.
    event_map = {
        "text": TextChunkEvent("he"),
        "tool_call": ToolCallChunkEvent("call"),
        "tool_result": ToolResultEvent("result"),
        "custom_update": CustomUpdateEvent("update"),
        "send_file": SendFileEvent(
            workspace_path="/workspace/report.pdf",
            mime_type="application/pdf",
            filename="report.pdf",
            size=123,
        ),
        "end": EndEvent(tokens_used=11),
        "error": ErrorEvent(code="BOOM", details="bad things"),
        "disconnect": GracefulDisconnectEvent(),
    }

    class FakeServerMessage:
        @staticmethod
        def validate_json(data: str):
            return event_map[data]

    async def fake_cleanup(self):
        return None

    def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
        self.id = agent_id
        self.callback = kwargs.get("callback")
        self.on_disconnect = kwargs.get("on_disconnect")
        self._current_queue = None

    monkeypatch.setattr(agent_api_wrapper_module, "ServerMessage", FakeServerMessage)
    monkeypatch.setattr(agent_api_wrapper_module.AgentApiWrapper, "_cleanup", fake_cleanup)
    monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)

    wrapper = AgentApiWrapper(
        agent_id="agent-1",
        base_url="https://agent.example.com/v1/agent_ws",
        chat_id="chat-1",
        callback=callback_events.append,
    )
    wrapper._current_queue = queue
    wrapper._ws = FakeWebSocket([FakeWSMessage(key) for key in event_map])

    await wrapper._listen()

    queue_events = [queue.get_nowait() for _ in range(queue.qsize())]

    def saw(events, event_type) -> bool:
        return any(isinstance(event, event_type) for event in events)

    assert queue_events[0].text == "he"
    assert saw(queue_events, SendFileEvent)
    assert saw(queue_events, EndEvent)
    assert saw(queue_events, GracefulDisconnectEvent)

    assert callback_events[0].payload == "call"
    assert callback_events[1].payload == "result"
    assert callback_events[2].payload == "update"
    assert saw(callback_events, SendFileEvent)
    assert saw(callback_events, EndEvent)
    assert saw(callback_events, ErrorEvent)
    assert saw(callback_events, GracefulDisconnectEvent)

    assert wrapper.last_tokens_used == 11