refactor: shrink agent api wrapper to thin adapter
This commit is contained in:
parent
4d917ac794
commit
569824ead1
2 changed files with 47 additions and 557 deletions
|
|
@ -1,67 +1,43 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from collections.abc import AsyncIterator
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from urllib.parse import urlsplit, urlunsplit
|
from urllib.parse import urlsplit, urlunsplit
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
|
|
||||||
_api_root = Path(__file__).resolve().parents[1] / "external" / "platform-agent_api"
|
_api_root = Path(__file__).resolve().parents[1] / "external" / "platform-agent_api"
|
||||||
if str(_api_root) not in sys.path:
|
if str(_api_root) not in sys.path:
|
||||||
sys.path.insert(0, str(_api_root))
|
sys.path.insert(0, str(_api_root))
|
||||||
|
|
||||||
from lambda_agent_api.agent_api import AgentApi, AgentBusyException, AgentException # noqa: E402
|
from lambda_agent_api.agent_api import AgentApi # noqa: E402
|
||||||
from lambda_agent_api.client import EClientMessage, MsgUserMessage # noqa: E402
|
|
||||||
from lambda_agent_api.server import AgentEventUnion, MsgEventEnd, ServerMessage # noqa: E402
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
_DEBUG_STREAM = os.environ.get("SURFACES_AGENT_DEBUG_STREAM", "").strip().lower() in {
|
|
||||||
"1",
|
|
||||||
"true",
|
|
||||||
"yes",
|
|
||||||
}
|
|
||||||
_POST_END_DRAIN_MS = int(os.environ.get("SURFACES_AGENT_POST_END_DRAIN_MS", "120"))
|
|
||||||
_STREAM_IDLE_TIMEOUT_MS = int(os.environ.get("SURFACES_AGENT_IDLE_TIMEOUT_MS", "60000"))
|
|
||||||
|
|
||||||
|
|
||||||
class AgentApiWrapper(AgentApi):
|
class AgentApiWrapper(AgentApi):
|
||||||
"""Capture tokens_used from MsgEventEnd without patching upstream code."""
|
"""Thin construction/factory shim over the pinned upstream AgentApi."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
base_url: str | None = None,
|
base_url: str,
|
||||||
*,
|
*,
|
||||||
chat_id: int | str = 0,
|
chat_id: int | str = 0,
|
||||||
url: str | None = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
if base_url is None and url is None:
|
self._base_url = self._normalize_base_url(base_url)
|
||||||
raise TypeError("AgentApiWrapper requires base_url or url")
|
|
||||||
|
|
||||||
self._base_url = self._normalize_base_url(base_url or url or "")
|
|
||||||
self._init_kwargs = dict(kwargs)
|
self._init_kwargs = dict(kwargs)
|
||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
if self._supports_modern_constructor():
|
if not self._supports_modern_constructor():
|
||||||
super().__init__(
|
raise RuntimeError(
|
||||||
agent_id=agent_id,
|
"Pinned platform-agent_api is expected to support base_url + chat_id"
|
||||||
base_url=self._base_url,
|
|
||||||
chat_id=chat_id,
|
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
url=self._build_ws_url(self._base_url, chat_id),
|
base_url=self._base_url,
|
||||||
**kwargs,
|
chat_id=chat_id,
|
||||||
)
|
**kwargs,
|
||||||
self.last_tokens_used = 0
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _supports_modern_constructor() -> bool:
|
def _supports_modern_constructor() -> bool:
|
||||||
|
|
@ -69,247 +45,18 @@ class AgentApiWrapper(AgentApi):
|
||||||
parameters = inspect.signature(AgentApi.__init__).parameters
|
parameters = inspect.signature(AgentApi.__init__).parameters
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return "base_url" in parameters and "chat_id" in parameters
|
return "base_url" in parameters and "chat_id" in parameters
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _normalize_base_url(base_url: str) -> str:
|
def _normalize_base_url(base_url: str) -> str:
|
||||||
parsed = urlsplit(base_url)
|
parsed = urlsplit(base_url)
|
||||||
path = re.sub(r"(?:/v1)?/agent_ws(?:/[^/]+)?$", "", parsed.path.rstrip("/"))
|
path = re.sub(r"(?:/v1)?/agent_ws(?:/[^/]+)?/?$", "", parsed.path.rstrip("/"))
|
||||||
return urlunsplit((parsed.scheme, parsed.netloc, path, "", ""))
|
return urlunsplit((parsed.scheme, parsed.netloc, path, "", ""))
|
||||||
|
|
||||||
@staticmethod
|
def for_chat(self, chat_id: int | str) -> "AgentApiWrapper":
|
||||||
def _build_ws_url(base_url: str, chat_id: int | str) -> str:
|
|
||||||
return base_url.rstrip("/") + f"/agent_ws/?thread_id={chat_id}"
|
|
||||||
|
|
||||||
def for_chat(self, chat_id: int | str) -> AgentApiWrapper:
|
|
||||||
return type(self)(
|
return type(self)(
|
||||||
agent_id=self.id,
|
agent_id=self.id,
|
||||||
base_url=self._base_url,
|
base_url=self._base_url,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
**self._init_kwargs,
|
**self._init_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _event_kind(event: object) -> str:
|
|
||||||
raw_kind = getattr(event, "type", None)
|
|
||||||
if hasattr(raw_kind, "value"):
|
|
||||||
raw_kind = raw_kind.value
|
|
||||||
if raw_kind is None:
|
|
||||||
raw_kind = event.__class__.__name__
|
|
||||||
|
|
||||||
kind = str(raw_kind).replace("-", "_")
|
|
||||||
if "_" in kind:
|
|
||||||
return kind.upper()
|
|
||||||
|
|
||||||
normalized = []
|
|
||||||
for index, char in enumerate(kind):
|
|
||||||
if index and char.isupper() and not kind[index - 1].isupper():
|
|
||||||
normalized.append("_")
|
|
||||||
normalized.append(char)
|
|
||||||
return "".join(normalized).upper()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _is_kind(cls, event: object, *needles: str) -> bool:
|
|
||||||
kind = cls._event_kind(event)
|
|
||||||
return any(needle in kind for needle in needles)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _is_text_event(cls, event: object) -> bool:
|
|
||||||
return hasattr(event, "text") or cls._is_kind(event, "TEXT_CHUNK")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _is_end_event(cls, event: object) -> bool:
|
|
||||||
kind = cls._event_kind(event)
|
|
||||||
return kind == "END" or kind.endswith("_END")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _is_send_file_event(cls, event: object) -> bool:
|
|
||||||
return "SEND_FILE" in cls._event_kind(event)
|
|
||||||
|
|
||||||
async def _publish_event(self, event: object, *, queue_event: object | None = None) -> None:
|
|
||||||
if self.callback:
|
|
||||||
self.callback(event)
|
|
||||||
if self._current_queue:
|
|
||||||
await self._current_queue.put(queue_event if queue_event is not None else event)
|
|
||||||
|
|
||||||
async def _publish_error(self, event: object) -> None:
|
|
||||||
if self.callback:
|
|
||||||
self.callback(event)
|
|
||||||
if self._current_queue and hasattr(event, "code") and hasattr(event, "details"):
|
|
||||||
await self._current_queue.put(AgentException(event.code, event.details))
|
|
||||||
|
|
||||||
async def _listen(self):
|
|
||||||
try:
|
|
||||||
async for msg in self._ws:
|
|
||||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
|
||||||
try:
|
|
||||||
outgoing_msg = ServerMessage.validate_json(msg.data)
|
|
||||||
|
|
||||||
if self._is_text_event(outgoing_msg):
|
|
||||||
if _DEBUG_STREAM:
|
|
||||||
logger.warning(
|
|
||||||
"[%s] text chunk queue=%s text=%r",
|
|
||||||
self.id,
|
|
||||||
self._current_queue is not None,
|
|
||||||
getattr(outgoing_msg, "text", "")[:80],
|
|
||||||
)
|
|
||||||
if self._current_queue:
|
|
||||||
await self._current_queue.put(outgoing_msg)
|
|
||||||
elif self.callback:
|
|
||||||
self.callback(outgoing_msg)
|
|
||||||
else:
|
|
||||||
logger.warning("[%s] AgentEvent without active request", self.id)
|
|
||||||
|
|
||||||
elif self._is_end_event(outgoing_msg):
|
|
||||||
self.last_tokens_used = outgoing_msg.tokens_used
|
|
||||||
if _DEBUG_STREAM:
|
|
||||||
logger.warning(
|
|
||||||
"[%s] end event queue=%s tokens=%s",
|
|
||||||
self.id,
|
|
||||||
self._current_queue is not None,
|
|
||||||
getattr(outgoing_msg, "tokens_used", None),
|
|
||||||
)
|
|
||||||
await self._publish_event(outgoing_msg)
|
|
||||||
|
|
||||||
elif self._is_kind(outgoing_msg, "ERROR"):
|
|
||||||
error = AgentException(outgoing_msg.code, outgoing_msg.details)
|
|
||||||
logger.error("[%s] Agent error: %s", self.id, error)
|
|
||||||
await self._publish_error(outgoing_msg)
|
|
||||||
|
|
||||||
elif self._is_kind(outgoing_msg, "GRACEFUL_DISCONNECT"):
|
|
||||||
await self._publish_event(outgoing_msg)
|
|
||||||
logger.info("[%s] Gracefully disconnecting", self.id)
|
|
||||||
break
|
|
||||||
|
|
||||||
else:
|
|
||||||
await self._publish_event(outgoing_msg)
|
|
||||||
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("[%s] Failed to deserialize message: %s", self.id, exc)
|
|
||||||
if self._current_queue:
|
|
||||||
await self._current_queue.put(
|
|
||||||
AgentException("PARSE_ERROR", f"Validation failed: {exc}")
|
|
||||||
)
|
|
||||||
|
|
||||||
elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
|
|
||||||
logger.error("[%s] WebSocket closed/error: %s", self.id, msg.type)
|
|
||||||
break
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("[%s] Error in listen loop: %s", self.id, exc)
|
|
||||||
finally:
|
|
||||||
await self._cleanup()
|
|
||||||
|
|
||||||
async def send_message(
|
|
||||||
self, text: str, attachments: list[str] | None = None
|
|
||||||
) -> AsyncIterator[AgentEventUnion]:
|
|
||||||
if not self._connected or not self._ws:
|
|
||||||
raise AgentException(
|
|
||||||
code="NOT_CONNECTED", details="Not connected. Call connect() first."
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._request_lock.locked():
|
|
||||||
raise AgentBusyException("Agent is currently processing another request")
|
|
||||||
|
|
||||||
await self._request_lock.acquire()
|
|
||||||
try:
|
|
||||||
self._current_queue = asyncio.Queue()
|
|
||||||
|
|
||||||
message = MsgUserMessage(
|
|
||||||
type=EClientMessage.USER_MESSAGE,
|
|
||||||
text=text,
|
|
||||||
attachments=attachments or [],
|
|
||||||
)
|
|
||||||
|
|
||||||
await self._ws.send_str(message.model_dump_json())
|
|
||||||
logger.debug("[%s] Sent message: %s...", self.id, text[:50])
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
chunk = await asyncio.wait_for(
|
|
||||||
self._current_queue.get(),
|
|
||||||
timeout=max(_STREAM_IDLE_TIMEOUT_MS, 0) / 1000,
|
|
||||||
)
|
|
||||||
except TimeoutError as exc:
|
|
||||||
raise AgentException(
|
|
||||||
"TIMEOUT",
|
|
||||||
(
|
|
||||||
"Timed out waiting for the next agent stream event "
|
|
||||||
f"after {max(_STREAM_IDLE_TIMEOUT_MS, 0)}ms"
|
|
||||||
),
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
if isinstance(chunk, Exception):
|
|
||||||
raise chunk
|
|
||||||
|
|
||||||
if isinstance(chunk, MsgEventEnd):
|
|
||||||
self.last_tokens_used = chunk.tokens_used
|
|
||||||
async for late_chunk in self._drain_post_end_events():
|
|
||||||
yield late_chunk
|
|
||||||
break
|
|
||||||
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if self._current_queue:
|
|
||||||
orphan_queue = self._current_queue
|
|
||||||
self._current_queue = None
|
|
||||||
|
|
||||||
while not orphan_queue.empty():
|
|
||||||
try:
|
|
||||||
orphan_msg = orphan_queue.get_nowait()
|
|
||||||
if isinstance(orphan_msg, Exception):
|
|
||||||
logger.debug(
|
|
||||||
"[%s] Dropped exception from queue during cleanup: %s",
|
|
||||||
self.id,
|
|
||||||
orphan_msg,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if self.callback:
|
|
||||||
self.callback(orphan_msg)
|
|
||||||
else:
|
|
||||||
logger.debug("[%s] Dropped orphaned message during cleanup", self.id)
|
|
||||||
|
|
||||||
except asyncio.QueueEmpty:
|
|
||||||
break
|
|
||||||
|
|
||||||
if self._request_lock.locked():
|
|
||||||
self._request_lock.release()
|
|
||||||
|
|
||||||
async def _drain_post_end_events(self) -> AsyncIterator[AgentEventUnion]:
|
|
||||||
if self._current_queue is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
timeout_s = max(_POST_END_DRAIN_MS, 0) / 1000
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
chunk = await asyncio.wait_for(self._current_queue.get(), timeout=timeout_s)
|
|
||||||
except TimeoutError:
|
|
||||||
break
|
|
||||||
|
|
||||||
if isinstance(chunk, Exception):
|
|
||||||
logger.warning("[%s] dropping post-END exception: %s", self.id, chunk)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(chunk, MsgEventEnd):
|
|
||||||
self.last_tokens_used = chunk.tokens_used
|
|
||||||
if _DEBUG_STREAM:
|
|
||||||
logger.warning(
|
|
||||||
"[%s] dropped duplicate END tokens=%s",
|
|
||||||
self.id,
|
|
||||||
chunk.tokens_used,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if _DEBUG_STREAM and self._is_text_event(chunk):
|
|
||||||
logger.warning(
|
|
||||||
"[%s] recovered post-END text chunk=%r",
|
|
||||||
self.id,
|
|
||||||
getattr(chunk, "text", "")[:80],
|
|
||||||
)
|
|
||||||
|
|
||||||
yield chunk
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from lambda_agent_api.server import MsgEventEnd, MsgEventTextChunk
|
|
||||||
|
|
||||||
import sdk.agent_api_wrapper as agent_api_wrapper_module
|
import sdk.agent_api_wrapper as agent_api_wrapper_module
|
||||||
from core.protocol import SettingsAction
|
from core.protocol import SettingsAction
|
||||||
|
|
@ -142,233 +141,61 @@ class TextChunkEvent:
|
||||||
self.type = "AGENT_EVENT_TEXT_CHUNK"
|
self.type = "AGENT_EVENT_TEXT_CHUNK"
|
||||||
self.text = text
|
self.text = text
|
||||||
|
|
||||||
|
|
||||||
class ToolCallChunkEvent:
|
|
||||||
def __init__(self, payload: str) -> None:
|
|
||||||
self.type = "AGENT_EVENT_TOOL_CALL_CHUNK"
|
|
||||||
self.payload = payload
|
|
||||||
|
|
||||||
|
|
||||||
class ToolResultEvent:
|
|
||||||
def __init__(self, payload: str) -> None:
|
|
||||||
self.type = "AGENT_EVENT_TOOL_RESULT"
|
|
||||||
self.payload = payload
|
|
||||||
|
|
||||||
|
|
||||||
class CustomUpdateEvent:
|
|
||||||
def __init__(self, payload: str) -> None:
|
|
||||||
self.type = "AGENT_EVENT_CUSTOM_UPDATE"
|
|
||||||
self.payload = payload
|
|
||||||
|
|
||||||
|
|
||||||
class EndEvent:
|
|
||||||
def __init__(self, tokens_used: int) -> None:
|
|
||||||
self.type = "AGENT_EVENT_END"
|
|
||||||
self.tokens_used = tokens_used
|
|
||||||
|
|
||||||
|
|
||||||
class ErrorEvent:
|
|
||||||
def __init__(self, code: str, details: str) -> None:
|
|
||||||
self.type = "ERROR"
|
|
||||||
self.code = code
|
|
||||||
self.details = details
|
|
||||||
|
|
||||||
|
|
||||||
class GracefulDisconnectEvent:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.type = "GRACEFUL_DISCONNECT"
|
|
||||||
|
|
||||||
|
|
||||||
class FakeWSMessage:
|
|
||||||
def __init__(self, data: str) -> None:
|
|
||||||
self.type = agent_api_wrapper_module.aiohttp.WSMsgType.TEXT
|
|
||||||
self.data = data
|
|
||||||
|
|
||||||
|
|
||||||
class FakeWebSocket:
|
|
||||||
def __init__(self, messages: list[FakeWSMessage]) -> None:
|
|
||||||
self._messages = list(messages)
|
|
||||||
|
|
||||||
def __aiter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __anext__(self):
|
|
||||||
if not self._messages:
|
|
||||||
raise StopAsyncIteration
|
|
||||||
return self._messages.pop(0)
|
|
||||||
|
|
||||||
|
|
||||||
class QueueFeedingWebSocket:
|
|
||||||
def __init__(self, owner, queued_events: list[object]) -> None:
|
|
||||||
self.owner = owner
|
|
||||||
self.queued_events = list(queued_events)
|
|
||||||
self.sent_payloads: list[str] = []
|
|
||||||
|
|
||||||
async def send_str(self, payload: str) -> None:
|
|
||||||
self.sent_payloads.append(payload)
|
|
||||||
for event in self.queued_events:
|
|
||||||
await self.owner._current_queue.put(event)
|
|
||||||
|
|
||||||
|
|
||||||
class SilentWebSocket:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.sent_payloads: list[str] = []
|
|
||||||
|
|
||||||
async def send_str(self, payload: str) -> None:
|
|
||||||
self.sent_payloads.append(payload)
|
|
||||||
|
|
||||||
|
|
||||||
class MessageResponseWithAttachments(MessageResponse):
|
class MessageResponseWithAttachments(MessageResponse):
|
||||||
attachments: list[Attachment] = []
|
attachments: list[Attachment] = []
|
||||||
|
|
||||||
|
|
||||||
def test_agent_api_wrapper_uses_modern_constructor_when_available(monkeypatch):
|
def test_agent_api_wrapper_normalizes_base_url_and_uses_modern_constructor(monkeypatch):
|
||||||
calls: list[dict[str, object]] = []
|
captured = {}
|
||||||
|
|
||||||
def fake_init(self, agent_id, base_url, chat_id, **kwargs):
|
def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
|
||||||
calls.append(
|
captured["agent_id"] = agent_id
|
||||||
{
|
captured["base_url"] = base_url
|
||||||
"agent_id": agent_id,
|
captured["chat_id"] = chat_id
|
||||||
"base_url": base_url,
|
|
||||||
"chat_id": chat_id,
|
|
||||||
"kwargs": kwargs,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
self.id = agent_id
|
|
||||||
self.url = base_url
|
|
||||||
self.callback = kwargs.get("callback")
|
|
||||||
self.on_disconnect = kwargs.get("on_disconnect")
|
|
||||||
|
|
||||||
monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)
|
monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)
|
||||||
|
|
||||||
wrapper = AgentApiWrapper(
|
wrapper = AgentApiWrapper(
|
||||||
agent_id="agent-1",
|
agent_id="agent-1",
|
||||||
base_url="https://agent.example.com/v1/agent_ws",
|
base_url="ws://platform-agent:8000/v1/agent_ws/",
|
||||||
chat_id="chat-1",
|
chat_id="41",
|
||||||
callback="cb",
|
|
||||||
on_disconnect="disconnect",
|
|
||||||
)
|
|
||||||
child = wrapper.for_chat("chat-2")
|
|
||||||
|
|
||||||
assert calls == [
|
|
||||||
{
|
|
||||||
"agent_id": "agent-1",
|
|
||||||
"base_url": "https://agent.example.com",
|
|
||||||
"chat_id": "chat-1",
|
|
||||||
"kwargs": {"callback": "cb", "on_disconnect": "disconnect"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"agent_id": "agent-1",
|
|
||||||
"base_url": "https://agent.example.com",
|
|
||||||
"chat_id": "chat-2",
|
|
||||||
"kwargs": {"callback": "cb", "on_disconnect": "disconnect"},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
assert wrapper._base_url == "https://agent.example.com"
|
|
||||||
assert wrapper.chat_id == "chat-1"
|
|
||||||
assert wrapper.last_tokens_used == 0
|
|
||||||
assert child.chat_id == "chat-2"
|
|
||||||
|
|
||||||
|
|
||||||
def test_agent_api_wrapper_falls_back_to_legacy_url_constructor(monkeypatch):
|
|
||||||
calls: list[dict[str, object]] = []
|
|
||||||
|
|
||||||
def fake_init(self, agent_id, url, callback=None, on_disconnect=None):
|
|
||||||
calls.append(
|
|
||||||
{
|
|
||||||
"agent_id": agent_id,
|
|
||||||
"url": url,
|
|
||||||
"callback": callback,
|
|
||||||
"on_disconnect": on_disconnect,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
self.id = agent_id
|
|
||||||
self.url = url
|
|
||||||
self.callback = callback
|
|
||||||
self.on_disconnect = on_disconnect
|
|
||||||
|
|
||||||
monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)
|
|
||||||
|
|
||||||
wrapper = AgentApiWrapper(
|
|
||||||
agent_id="agent-2",
|
|
||||||
url="https://agent.example.com/agent_ws/",
|
|
||||||
chat_id="chat-9",
|
|
||||||
callback="cb",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert calls == [
|
assert wrapper.chat_id == "41"
|
||||||
{
|
assert wrapper._base_url == "ws://platform-agent:8000"
|
||||||
"agent_id": "agent-2",
|
assert captured == {
|
||||||
"url": "https://agent.example.com/agent_ws/?thread_id=chat-9",
|
"agent_id": "agent-1",
|
||||||
"callback": "cb",
|
"base_url": "ws://platform-agent:8000",
|
||||||
"on_disconnect": None,
|
"chat_id": "41",
|
||||||
}
|
}
|
||||||
]
|
|
||||||
assert wrapper._base_url == "https://agent.example.com"
|
|
||||||
assert wrapper.chat_id == "chat-9"
|
|
||||||
assert wrapper.last_tokens_used == 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
def test_agent_api_wrapper_for_chat_reuses_normalized_base_url(monkeypatch):
|
||||||
async def test_agent_api_wrapper_recovers_late_text_after_first_end(monkeypatch):
|
init_calls = []
|
||||||
|
|
||||||
def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
|
def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
|
||||||
self.id = agent_id
|
self.id = agent_id
|
||||||
|
self.chat_id = chat_id
|
||||||
self.url = base_url
|
self.url = base_url
|
||||||
self.callback = kwargs.get("callback")
|
init_calls.append((agent_id, base_url, chat_id))
|
||||||
self.on_disconnect = kwargs.get("on_disconnect")
|
|
||||||
|
|
||||||
monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)
|
monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)
|
||||||
|
|
||||||
wrapper = AgentApiWrapper(
|
root = AgentApiWrapper(
|
||||||
agent_id="agent-1",
|
agent_id="agent-1",
|
||||||
base_url="https://agent.example.com/v1/agent_ws",
|
base_url="http://platform-agent:8000/v1/agent_ws/",
|
||||||
chat_id="chat-1",
|
chat_id="1",
|
||||||
)
|
|
||||||
wrapper._connected = True
|
|
||||||
wrapper._request_lock = asyncio.Lock()
|
|
||||||
wrapper._current_queue = None
|
|
||||||
wrapper._ws = QueueFeedingWebSocket(
|
|
||||||
wrapper,
|
|
||||||
[
|
|
||||||
MsgEventTextChunk(text="Иллюстра"),
|
|
||||||
MsgEventEnd(tokens_used=5),
|
|
||||||
MsgEventTextChunk(text="ция"),
|
|
||||||
MsgEventEnd(tokens_used=5),
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
chunks = []
|
child = root.for_chat("99")
|
||||||
async for chunk in wrapper.send_message("hello"):
|
|
||||||
chunks.append(chunk)
|
|
||||||
|
|
||||||
assert [chunk.text for chunk in chunks] == ["Иллюстра", "ция"]
|
assert child is not root
|
||||||
assert wrapper.last_tokens_used == 5
|
assert child.chat_id == "99"
|
||||||
|
assert child._base_url == "http://platform-agent:8000"
|
||||||
|
assert init_calls == [
|
||||||
@pytest.mark.asyncio
|
("agent-1", "http://platform-agent:8000", "1"),
|
||||||
async def test_agent_api_wrapper_times_out_on_idle_stream(monkeypatch):
|
("agent-1", "http://platform-agent:8000", "99"),
|
||||||
def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
|
]
|
||||||
self.id = agent_id
|
|
||||||
self.url = base_url
|
|
||||||
self.callback = kwargs.get("callback")
|
|
||||||
self.on_disconnect = kwargs.get("on_disconnect")
|
|
||||||
|
|
||||||
monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)
|
|
||||||
monkeypatch.setattr(agent_api_wrapper_module, "_STREAM_IDLE_TIMEOUT_MS", 10)
|
|
||||||
|
|
||||||
wrapper = AgentApiWrapper(
|
|
||||||
agent_id="agent-1",
|
|
||||||
base_url="https://agent.example.com/v1/agent_ws",
|
|
||||||
chat_id="chat-1",
|
|
||||||
)
|
|
||||||
wrapper._connected = True
|
|
||||||
wrapper._request_lock = asyncio.Lock()
|
|
||||||
wrapper._current_queue = None
|
|
||||||
wrapper._ws = SilentWebSocket()
|
|
||||||
|
|
||||||
with pytest.raises(agent_api_wrapper_module.AgentException, match="Timed out waiting"):
|
|
||||||
async for _ in wrapper.send_message("hello"):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -703,87 +530,3 @@ async def test_real_platform_client_settings_are_local():
|
||||||
assert isinstance(settings, UserSettings)
|
assert isinstance(settings, UserSettings)
|
||||||
assert settings.skills["browser"] is True
|
assert settings.skills["browser"] is True
|
||||||
assert settings.skills["web-search"] is True
|
assert settings.skills["web-search"] is True
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_agent_api_wrapper_transparently_surfaces_modern_events(monkeypatch):
|
|
||||||
callback_events: list[object] = []
|
|
||||||
queue: asyncio.Queue = asyncio.Queue()
|
|
||||||
event_map = {
|
|
||||||
"text": TextChunkEvent("he"),
|
|
||||||
"tool_call": ToolCallChunkEvent("call"),
|
|
||||||
"tool_result": ToolResultEvent("result"),
|
|
||||||
"custom_update": CustomUpdateEvent("update"),
|
|
||||||
"send_file": SendFileEvent(
|
|
||||||
workspace_path="/workspace/report.pdf",
|
|
||||||
mime_type="application/pdf",
|
|
||||||
filename="report.pdf",
|
|
||||||
size=123,
|
|
||||||
),
|
|
||||||
"end": EndEvent(tokens_used=11),
|
|
||||||
"error": ErrorEvent(code="BOOM", details="bad things"),
|
|
||||||
"disconnect": GracefulDisconnectEvent(),
|
|
||||||
}
|
|
||||||
|
|
||||||
def fake_validate_json(data: str):
|
|
||||||
return event_map[data]
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
agent_api_wrapper_module,
|
|
||||||
"ServerMessage",
|
|
||||||
type("FakeServerMessage", (), {"validate_json": staticmethod(fake_validate_json)}),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def fake_cleanup(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
monkeypatch.setattr(agent_api_wrapper_module.AgentApiWrapper, "_cleanup", fake_cleanup)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
agent_api_wrapper_module.AgentApi,
|
|
||||||
"__init__",
|
|
||||||
lambda self, agent_id, base_url=None, chat_id=0, **kwargs: (
|
|
||||||
setattr(self, "id", agent_id)
|
|
||||||
or setattr(self, "callback", kwargs.get("callback"))
|
|
||||||
or setattr(self, "on_disconnect", kwargs.get("on_disconnect"))
|
|
||||||
or setattr(self, "_current_queue", None)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
wrapper = AgentApiWrapper(
|
|
||||||
agent_id="agent-1",
|
|
||||||
base_url="https://agent.example.com/v1/agent_ws",
|
|
||||||
chat_id="chat-1",
|
|
||||||
callback=callback_events.append,
|
|
||||||
)
|
|
||||||
wrapper._current_queue = queue
|
|
||||||
wrapper._ws = FakeWebSocket(
|
|
||||||
[
|
|
||||||
FakeWSMessage("text"),
|
|
||||||
FakeWSMessage("tool_call"),
|
|
||||||
FakeWSMessage("tool_result"),
|
|
||||||
FakeWSMessage("custom_update"),
|
|
||||||
FakeWSMessage("send_file"),
|
|
||||||
FakeWSMessage("end"),
|
|
||||||
FakeWSMessage("error"),
|
|
||||||
FakeWSMessage("disconnect"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
await wrapper._listen()
|
|
||||||
|
|
||||||
queue_events = []
|
|
||||||
while not queue.empty():
|
|
||||||
queue_events.append(await queue.get())
|
|
||||||
|
|
||||||
assert queue_events[0].text == "he"
|
|
||||||
assert any(isinstance(event, SendFileEvent) for event in queue_events)
|
|
||||||
assert any(isinstance(event, EndEvent) for event in queue_events)
|
|
||||||
assert any(isinstance(event, GracefulDisconnectEvent) for event in queue_events)
|
|
||||||
assert callback_events[0].payload == "call"
|
|
||||||
assert callback_events[1].payload == "result"
|
|
||||||
assert callback_events[2].payload == "update"
|
|
||||||
assert any(isinstance(event, SendFileEvent) for event in callback_events)
|
|
||||||
assert any(isinstance(event, EndEvent) for event in callback_events)
|
|
||||||
assert any(isinstance(event, ErrorEvent) for event in callback_events)
|
|
||||||
assert any(isinstance(event, GracefulDisconnectEvent) for event in callback_events)
|
|
||||||
assert wrapper.last_tokens_used == 11
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue