refactor: shrink agent api wrapper to thin adapter
This commit is contained in:
parent
4d917ac794
commit
569824ead1
2 changed files with 47 additions and 557 deletions
|
|
@ -1,7 +1,6 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
from lambda_agent_api.server import MsgEventEnd, MsgEventTextChunk
|
||||
|
||||
import sdk.agent_api_wrapper as agent_api_wrapper_module
|
||||
from core.protocol import SettingsAction
|
||||
|
|
@ -142,233 +141,61 @@ class TextChunkEvent:
|
|||
self.type = "AGENT_EVENT_TEXT_CHUNK"
|
||||
self.text = text
|
||||
|
||||
|
||||
class ToolCallChunkEvent:
|
||||
def __init__(self, payload: str) -> None:
|
||||
self.type = "AGENT_EVENT_TOOL_CALL_CHUNK"
|
||||
self.payload = payload
|
||||
|
||||
|
||||
class ToolResultEvent:
|
||||
def __init__(self, payload: str) -> None:
|
||||
self.type = "AGENT_EVENT_TOOL_RESULT"
|
||||
self.payload = payload
|
||||
|
||||
|
||||
class CustomUpdateEvent:
|
||||
def __init__(self, payload: str) -> None:
|
||||
self.type = "AGENT_EVENT_CUSTOM_UPDATE"
|
||||
self.payload = payload
|
||||
|
||||
|
||||
class EndEvent:
|
||||
def __init__(self, tokens_used: int) -> None:
|
||||
self.type = "AGENT_EVENT_END"
|
||||
self.tokens_used = tokens_used
|
||||
|
||||
|
||||
class ErrorEvent:
|
||||
def __init__(self, code: str, details: str) -> None:
|
||||
self.type = "ERROR"
|
||||
self.code = code
|
||||
self.details = details
|
||||
|
||||
|
||||
class GracefulDisconnectEvent:
|
||||
def __init__(self) -> None:
|
||||
self.type = "GRACEFUL_DISCONNECT"
|
||||
|
||||
|
||||
class FakeWSMessage:
|
||||
def __init__(self, data: str) -> None:
|
||||
self.type = agent_api_wrapper_module.aiohttp.WSMsgType.TEXT
|
||||
self.data = data
|
||||
|
||||
|
||||
class FakeWebSocket:
|
||||
def __init__(self, messages: list[FakeWSMessage]) -> None:
|
||||
self._messages = list(messages)
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if not self._messages:
|
||||
raise StopAsyncIteration
|
||||
return self._messages.pop(0)
|
||||
|
||||
|
||||
class QueueFeedingWebSocket:
|
||||
def __init__(self, owner, queued_events: list[object]) -> None:
|
||||
self.owner = owner
|
||||
self.queued_events = list(queued_events)
|
||||
self.sent_payloads: list[str] = []
|
||||
|
||||
async def send_str(self, payload: str) -> None:
|
||||
self.sent_payloads.append(payload)
|
||||
for event in self.queued_events:
|
||||
await self.owner._current_queue.put(event)
|
||||
|
||||
|
||||
class SilentWebSocket:
|
||||
def __init__(self) -> None:
|
||||
self.sent_payloads: list[str] = []
|
||||
|
||||
async def send_str(self, payload: str) -> None:
|
||||
self.sent_payloads.append(payload)
|
||||
|
||||
|
||||
class MessageResponseWithAttachments(MessageResponse):
|
||||
attachments: list[Attachment] = []
|
||||
|
||||
|
||||
def test_agent_api_wrapper_uses_modern_constructor_when_available(monkeypatch):
|
||||
calls: list[dict[str, object]] = []
|
||||
def test_agent_api_wrapper_normalizes_base_url_and_uses_modern_constructor(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
def fake_init(self, agent_id, base_url, chat_id, **kwargs):
|
||||
calls.append(
|
||||
{
|
||||
"agent_id": agent_id,
|
||||
"base_url": base_url,
|
||||
"chat_id": chat_id,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
)
|
||||
self.id = agent_id
|
||||
self.url = base_url
|
||||
self.callback = kwargs.get("callback")
|
||||
self.on_disconnect = kwargs.get("on_disconnect")
|
||||
def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
|
||||
captured["agent_id"] = agent_id
|
||||
captured["base_url"] = base_url
|
||||
captured["chat_id"] = chat_id
|
||||
|
||||
monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)
|
||||
|
||||
wrapper = AgentApiWrapper(
|
||||
agent_id="agent-1",
|
||||
base_url="https://agent.example.com/v1/agent_ws",
|
||||
chat_id="chat-1",
|
||||
callback="cb",
|
||||
on_disconnect="disconnect",
|
||||
)
|
||||
child = wrapper.for_chat("chat-2")
|
||||
|
||||
assert calls == [
|
||||
{
|
||||
"agent_id": "agent-1",
|
||||
"base_url": "https://agent.example.com",
|
||||
"chat_id": "chat-1",
|
||||
"kwargs": {"callback": "cb", "on_disconnect": "disconnect"},
|
||||
},
|
||||
{
|
||||
"agent_id": "agent-1",
|
||||
"base_url": "https://agent.example.com",
|
||||
"chat_id": "chat-2",
|
||||
"kwargs": {"callback": "cb", "on_disconnect": "disconnect"},
|
||||
},
|
||||
]
|
||||
assert wrapper._base_url == "https://agent.example.com"
|
||||
assert wrapper.chat_id == "chat-1"
|
||||
assert wrapper.last_tokens_used == 0
|
||||
assert child.chat_id == "chat-2"
|
||||
|
||||
|
||||
def test_agent_api_wrapper_falls_back_to_legacy_url_constructor(monkeypatch):
|
||||
calls: list[dict[str, object]] = []
|
||||
|
||||
def fake_init(self, agent_id, url, callback=None, on_disconnect=None):
|
||||
calls.append(
|
||||
{
|
||||
"agent_id": agent_id,
|
||||
"url": url,
|
||||
"callback": callback,
|
||||
"on_disconnect": on_disconnect,
|
||||
}
|
||||
)
|
||||
self.id = agent_id
|
||||
self.url = url
|
||||
self.callback = callback
|
||||
self.on_disconnect = on_disconnect
|
||||
|
||||
monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)
|
||||
|
||||
wrapper = AgentApiWrapper(
|
||||
agent_id="agent-2",
|
||||
url="https://agent.example.com/agent_ws/",
|
||||
chat_id="chat-9",
|
||||
callback="cb",
|
||||
base_url="ws://platform-agent:8000/v1/agent_ws/",
|
||||
chat_id="41",
|
||||
)
|
||||
|
||||
assert calls == [
|
||||
{
|
||||
"agent_id": "agent-2",
|
||||
"url": "https://agent.example.com/agent_ws/?thread_id=chat-9",
|
||||
"callback": "cb",
|
||||
"on_disconnect": None,
|
||||
}
|
||||
]
|
||||
assert wrapper._base_url == "https://agent.example.com"
|
||||
assert wrapper.chat_id == "chat-9"
|
||||
assert wrapper.last_tokens_used == 0
|
||||
assert wrapper.chat_id == "41"
|
||||
assert wrapper._base_url == "ws://platform-agent:8000"
|
||||
assert captured == {
|
||||
"agent_id": "agent-1",
|
||||
"base_url": "ws://platform-agent:8000",
|
||||
"chat_id": "41",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_api_wrapper_recovers_late_text_after_first_end(monkeypatch):
|
||||
def test_agent_api_wrapper_for_chat_reuses_normalized_base_url(monkeypatch):
|
||||
init_calls = []
|
||||
|
||||
def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
|
||||
self.id = agent_id
|
||||
self.chat_id = chat_id
|
||||
self.url = base_url
|
||||
self.callback = kwargs.get("callback")
|
||||
self.on_disconnect = kwargs.get("on_disconnect")
|
||||
init_calls.append((agent_id, base_url, chat_id))
|
||||
|
||||
monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)
|
||||
|
||||
wrapper = AgentApiWrapper(
|
||||
root = AgentApiWrapper(
|
||||
agent_id="agent-1",
|
||||
base_url="https://agent.example.com/v1/agent_ws",
|
||||
chat_id="chat-1",
|
||||
)
|
||||
wrapper._connected = True
|
||||
wrapper._request_lock = asyncio.Lock()
|
||||
wrapper._current_queue = None
|
||||
wrapper._ws = QueueFeedingWebSocket(
|
||||
wrapper,
|
||||
[
|
||||
MsgEventTextChunk(text="Иллюстра"),
|
||||
MsgEventEnd(tokens_used=5),
|
||||
MsgEventTextChunk(text="ция"),
|
||||
MsgEventEnd(tokens_used=5),
|
||||
],
|
||||
base_url="http://platform-agent:8000/v1/agent_ws/",
|
||||
chat_id="1",
|
||||
)
|
||||
|
||||
chunks = []
|
||||
async for chunk in wrapper.send_message("hello"):
|
||||
chunks.append(chunk)
|
||||
child = root.for_chat("99")
|
||||
|
||||
assert [chunk.text for chunk in chunks] == ["Иллюстра", "ция"]
|
||||
assert wrapper.last_tokens_used == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_api_wrapper_times_out_on_idle_stream(monkeypatch):
|
||||
def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
|
||||
self.id = agent_id
|
||||
self.url = base_url
|
||||
self.callback = kwargs.get("callback")
|
||||
self.on_disconnect = kwargs.get("on_disconnect")
|
||||
|
||||
monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)
|
||||
monkeypatch.setattr(agent_api_wrapper_module, "_STREAM_IDLE_TIMEOUT_MS", 10)
|
||||
|
||||
wrapper = AgentApiWrapper(
|
||||
agent_id="agent-1",
|
||||
base_url="https://agent.example.com/v1/agent_ws",
|
||||
chat_id="chat-1",
|
||||
)
|
||||
wrapper._connected = True
|
||||
wrapper._request_lock = asyncio.Lock()
|
||||
wrapper._current_queue = None
|
||||
wrapper._ws = SilentWebSocket()
|
||||
|
||||
with pytest.raises(agent_api_wrapper_module.AgentException, match="Timed out waiting"):
|
||||
async for _ in wrapper.send_message("hello"):
|
||||
pass
|
||||
assert child is not root
|
||||
assert child.chat_id == "99"
|
||||
assert child._base_url == "http://platform-agent:8000"
|
||||
assert init_calls == [
|
||||
("agent-1", "http://platform-agent:8000", "1"),
|
||||
("agent-1", "http://platform-agent:8000", "99"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -703,87 +530,3 @@ async def test_real_platform_client_settings_are_local():
|
|||
assert isinstance(settings, UserSettings)
|
||||
assert settings.skills["browser"] is True
|
||||
assert settings.skills["web-search"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_api_wrapper_transparently_surfaces_modern_events(monkeypatch):
|
||||
callback_events: list[object] = []
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
event_map = {
|
||||
"text": TextChunkEvent("he"),
|
||||
"tool_call": ToolCallChunkEvent("call"),
|
||||
"tool_result": ToolResultEvent("result"),
|
||||
"custom_update": CustomUpdateEvent("update"),
|
||||
"send_file": SendFileEvent(
|
||||
workspace_path="/workspace/report.pdf",
|
||||
mime_type="application/pdf",
|
||||
filename="report.pdf",
|
||||
size=123,
|
||||
),
|
||||
"end": EndEvent(tokens_used=11),
|
||||
"error": ErrorEvent(code="BOOM", details="bad things"),
|
||||
"disconnect": GracefulDisconnectEvent(),
|
||||
}
|
||||
|
||||
def fake_validate_json(data: str):
|
||||
return event_map[data]
|
||||
|
||||
monkeypatch.setattr(
|
||||
agent_api_wrapper_module,
|
||||
"ServerMessage",
|
||||
type("FakeServerMessage", (), {"validate_json": staticmethod(fake_validate_json)}),
|
||||
)
|
||||
|
||||
async def fake_cleanup(self):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(agent_api_wrapper_module.AgentApiWrapper, "_cleanup", fake_cleanup)
|
||||
monkeypatch.setattr(
|
||||
agent_api_wrapper_module.AgentApi,
|
||||
"__init__",
|
||||
lambda self, agent_id, base_url=None, chat_id=0, **kwargs: (
|
||||
setattr(self, "id", agent_id)
|
||||
or setattr(self, "callback", kwargs.get("callback"))
|
||||
or setattr(self, "on_disconnect", kwargs.get("on_disconnect"))
|
||||
or setattr(self, "_current_queue", None)
|
||||
),
|
||||
)
|
||||
|
||||
wrapper = AgentApiWrapper(
|
||||
agent_id="agent-1",
|
||||
base_url="https://agent.example.com/v1/agent_ws",
|
||||
chat_id="chat-1",
|
||||
callback=callback_events.append,
|
||||
)
|
||||
wrapper._current_queue = queue
|
||||
wrapper._ws = FakeWebSocket(
|
||||
[
|
||||
FakeWSMessage("text"),
|
||||
FakeWSMessage("tool_call"),
|
||||
FakeWSMessage("tool_result"),
|
||||
FakeWSMessage("custom_update"),
|
||||
FakeWSMessage("send_file"),
|
||||
FakeWSMessage("end"),
|
||||
FakeWSMessage("error"),
|
||||
FakeWSMessage("disconnect"),
|
||||
]
|
||||
)
|
||||
|
||||
await wrapper._listen()
|
||||
|
||||
queue_events = []
|
||||
while not queue.empty():
|
||||
queue_events.append(await queue.get())
|
||||
|
||||
assert queue_events[0].text == "he"
|
||||
assert any(isinstance(event, SendFileEvent) for event in queue_events)
|
||||
assert any(isinstance(event, EndEvent) for event in queue_events)
|
||||
assert any(isinstance(event, GracefulDisconnectEvent) for event in queue_events)
|
||||
assert callback_events[0].payload == "call"
|
||||
assert callback_events[1].payload == "result"
|
||||
assert callback_events[2].payload == "update"
|
||||
assert any(isinstance(event, SendFileEvent) for event in callback_events)
|
||||
assert any(isinstance(event, EndEvent) for event in callback_events)
|
||||
assert any(isinstance(event, ErrorEvent) for event in callback_events)
|
||||
assert any(isinstance(event, GracefulDisconnectEvent) for event in callback_events)
|
||||
assert wrapper.last_tokens_used == 11
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue