"""Tests for ``AgentApiWrapper`` and ``RealPlatformClient``.

The fakes below stand in for chat-bound agent API clients, websocket
transports and server event objects, so the wrapper and client logic can be
exercised without a live agent server.
"""

import asyncio
from collections import deque

import pytest
from lambda_agent_api.server import MsgEventEnd, MsgEventTextChunk

import sdk.agent_api_wrapper as agent_api_wrapper_module
from core.protocol import SettingsAction
from sdk.agent_api_wrapper import AgentApiWrapper
from sdk.interface import Attachment, MessageChunk, MessageResponse, PlatformError, UserSettings
from sdk.prototype_state import PrototypeStateStore
from sdk.real import RealPlatformClient


class FakeChunk:
    """Minimal streamed chunk exposing only a ``text`` attribute."""

    def __init__(self, text: str) -> None:
        self.text = text


class FakeChatAgentApi:
    """Chat-bound fake that records calls and splits each reply into two chunks."""

    def __init__(self, chat_id: str) -> None:
        self.chat_id = chat_id
        self.calls: list[str] = []
        self.connect_calls = 0
        self.close_calls = 0
        self.last_tokens_used = 0

    async def connect(self) -> None:
        self.connect_calls += 1

    async def close(self) -> None:
        self.close_calls += 1

    async def send_message(self, text: str):
        self.calls.append(text)
        midpoint = len(text) // 2
        yield FakeChunk(text[:midpoint])
        yield FakeChunk(text[midpoint:])
        # Token usage is only published after the stream has been fully consumed.
        self.last_tokens_used = 3


class FakeAgentApiFactory:
    """Factory fake exposing ``for_chat``; remembers every client it creates."""

    def __init__(self) -> None:
        self.created_chat_ids: list[str] = []
        self.instances: dict[str, FakeChatAgentApi] = {}

    def for_chat(self, chat_id: str) -> FakeChatAgentApi:
        chat_api = FakeChatAgentApi(chat_id)
        self.created_chat_ids.append(chat_id)
        self.instances[chat_id] = chat_api
        return chat_api


class LegacyAgentApi:
    """Agent API fake with no ``for_chat``/``connect``/``close`` methods."""

    def __init__(self) -> None:
        self.calls: list[str] = []
        self.last_tokens_used = 0

    async def send_message(self, text: str):
        self.calls.append(text)
        yield FakeChunk(text[:2])
        yield FakeChunk(text[2:])
        self.last_tokens_used = 7


class BlockingChatAgentApi:
    """Chat fake whose ``send_message`` blocks until ``release`` is set.

    ``active_calls`` / ``max_active_calls`` count how many ``send_message``
    calls were in flight at once, so tests can detect overlapping streams.
    ``started`` is set as soon as any call begins.
    """

    def __init__(self, chat_id: str) -> None:
        self.chat_id = chat_id
        self.calls: list[str] = []
        self.connect_calls = 0
        self.close_calls = 0
        self.last_tokens_used = 0
        self.active_calls = 0
        self.max_active_calls = 0
        self.started = asyncio.Event()
        self.release = asyncio.Event()

    async def connect(self) -> None:
        self.connect_calls += 1

    async def close(self) -> None:
        self.close_calls += 1

    async def send_message(self, text: str):
        self.calls.append(text)
        self.active_calls += 1
        self.max_active_calls = max(self.max_active_calls, self.active_calls)
        self.started.set()
        await self.release.wait()
        self.active_calls -= 1
        yield FakeChunk(text)
        self.last_tokens_used = len(text)


class AttachmentTrackingChatAgentApi:
    """Chat fake whose ``send_message`` accepts and records ``attachments``."""

    def __init__(self, chat_id: str) -> None:
        self.chat_id = chat_id
        self.calls: list[tuple[str, list[str] | None]] = []
        self.connect_calls = 0
        self.close_calls = 0
        self.last_tokens_used = 0

    async def connect(self) -> None:
        self.connect_calls += 1

    async def close(self) -> None:
        self.close_calls += 1

    async def send_message(self, text: str, attachments: list[str] | None = None):
        self.calls.append((text, attachments))
        yield FakeChunk(text)
        self.last_tokens_used = 5


class FlakyChatAgentApi:
    """Chat fake whose stream fails with ``ConnectionError`` on first iteration."""

    def __init__(self, chat_id: str) -> None:
        self.chat_id = chat_id
        self.connect_calls = 0
        self.close_calls = 0

    async def connect(self) -> None:
        self.connect_calls += 1

    async def close(self) -> None:
        self.close_calls += 1

    async def send_message(self, text: str, attachments: list[str] | None = None):
        raise ConnectionError("Connection closed")
        # Unreachable, but the ``yield`` makes this an async generator, so the
        # error is raised when the stream is iterated rather than when called.
        yield


class SendFileEvent:
    """Fake ``AGENT_EVENT_SEND_FILE`` server event."""

    def __init__(self, *, workspace_path: str, mime_type: str, filename: str, size: int) -> None:
        self.type = "AGENT_EVENT_SEND_FILE"
        self.workspace_path = workspace_path
        self.mime_type = mime_type
        self.filename = filename
        self.size = size


class TextChunkEvent:
    """Fake ``AGENT_EVENT_TEXT_CHUNK`` server event."""

    def __init__(self, text: str) -> None:
        self.type = "AGENT_EVENT_TEXT_CHUNK"
        self.text = text


class ToolCallChunkEvent:
    """Fake ``AGENT_EVENT_TOOL_CALL_CHUNK`` server event."""

    def __init__(self, payload: str) -> None:
        self.type = "AGENT_EVENT_TOOL_CALL_CHUNK"
        self.payload = payload


class ToolResultEvent:
    """Fake ``AGENT_EVENT_TOOL_RESULT`` server event."""

    def __init__(self, payload: str) -> None:
        self.type = "AGENT_EVENT_TOOL_RESULT"
        self.payload = payload


class CustomUpdateEvent:
    """Fake ``AGENT_EVENT_CUSTOM_UPDATE`` server event."""

    def __init__(self, payload: str) -> None:
        self.type = "AGENT_EVENT_CUSTOM_UPDATE"
        self.payload = payload


class EndEvent:
    """Fake ``AGENT_EVENT_END`` server event carrying token usage."""

    def __init__(self, tokens_used: int) -> None:
        self.type = "AGENT_EVENT_END"
        self.tokens_used = tokens_used


class ErrorEvent:
    """Fake ``ERROR`` server event."""

    def __init__(self, code: str, details: str) -> None:
        self.type = "ERROR"
        self.code = code
        self.details = details


class GracefulDisconnectEvent:
    """Fake ``GRACEFUL_DISCONNECT`` server event."""

    def __init__(self) -> None:
        self.type = "GRACEFUL_DISCONNECT"


class FakeWSMessage:
    """Text websocket frame whose ``data`` is a key into a fake event map."""

    def __init__(self, data: str) -> None:
        self.type = agent_api_wrapper_module.aiohttp.WSMsgType.TEXT
        self.data = data


class FakeWebSocket:
    """Async-iterable websocket that replays a fixed list of messages once."""

    def __init__(self, messages: list[FakeWSMessage]) -> None:
        # deque gives O(1) pops from the front while replaying in order.
        self._messages = deque(messages)

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self._messages:
            raise StopAsyncIteration
        return self._messages.popleft()


class QueueFeedingWebSocket:
    """Websocket whose ``send_str`` pushes canned events into the owner's queue.

    This simulates the server replying to a request by feeding events straight
    into ``owner._current_queue``.
    """

    def __init__(self, owner, queued_events: list[object]) -> None:
        self.owner = owner
        self.queued_events = list(queued_events)
        self.sent_payloads: list[str] = []

    async def send_str(self, payload: str) -> None:
        self.sent_payloads.append(payload)
        for event in self.queued_events:
            await self.owner._current_queue.put(event)


class SilentWebSocket:
    """Websocket that records sent payloads but never produces any events."""

    def __init__(self) -> None:
        self.sent_payloads: list[str] = []

    async def send_str(self, payload: str) -> None:
        self.sent_payloads.append(payload)


class MessageResponseWithAttachments(MessageResponse):
    """``MessageResponse`` variant with an ``attachments`` field for send-file tests."""

    # NOTE(review): mutable default assumes MessageResponse is a model class that
    # copies field defaults per instance (e.g. pydantic) — confirm.
    attachments: list[Attachment] = []


def _fake_modern_agent_api_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
    """Stand-in for ``AgentApi.__init__`` with the modern ``base_url``/``chat_id`` shape.

    It only records the attributes the tests rely on and opens no connection.
    """
    self.id = agent_id
    self.url = base_url
    self.callback = kwargs.get("callback")
    self.on_disconnect = kwargs.get("on_disconnect")


def test_agent_api_wrapper_uses_modern_constructor_when_available(monkeypatch):
    """A ``base_url``/``chat_id`` constructor is called with the URL path stripped."""
    calls: list[dict[str, object]] = []

    def fake_init(self, agent_id, base_url, chat_id, **kwargs):
        calls.append(
            {
                "agent_id": agent_id,
                "base_url": base_url,
                "chat_id": chat_id,
                "kwargs": kwargs,
            }
        )
        self.id = agent_id
        self.url = base_url
        self.callback = kwargs.get("callback")
        self.on_disconnect = kwargs.get("on_disconnect")

    monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)

    wrapper = AgentApiWrapper(
        agent_id="agent-1",
        base_url="https://agent.example.com/v1/agent_ws",
        chat_id="chat-1",
        callback="cb",
        on_disconnect="disconnect",
    )
    child = wrapper.for_chat("chat-2")

    assert calls == [
        {
            "agent_id": "agent-1",
            "base_url": "https://agent.example.com",
            "chat_id": "chat-1",
            "kwargs": {"callback": "cb", "on_disconnect": "disconnect"},
        },
        {
            "agent_id": "agent-1",
            "base_url": "https://agent.example.com",
            "chat_id": "chat-2",
            "kwargs": {"callback": "cb", "on_disconnect": "disconnect"},
        },
    ]
    assert wrapper._base_url == "https://agent.example.com"
    assert wrapper.chat_id == "chat-1"
    assert wrapper.last_tokens_used == 0
    assert child.chat_id == "chat-2"


def test_agent_api_wrapper_falls_back_to_legacy_url_constructor(monkeypatch):
    """A legacy ``url``-only constructor receives the chat id as ``thread_id``."""
    calls: list[dict[str, object]] = []

    def fake_init(self, agent_id, url, callback=None, on_disconnect=None):
        calls.append(
            {
                "agent_id": agent_id,
                "url": url,
                "callback": callback,
                "on_disconnect": on_disconnect,
            }
        )
        self.id = agent_id
        self.url = url
        self.callback = callback
        self.on_disconnect = on_disconnect

    monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)

    wrapper = AgentApiWrapper(
        agent_id="agent-2",
        url="https://agent.example.com/agent_ws/",
        chat_id="chat-9",
        callback="cb",
    )

    assert calls == [
        {
            "agent_id": "agent-2",
            "url": "https://agent.example.com/agent_ws/?thread_id=chat-9",
            "callback": "cb",
            "on_disconnect": None,
        }
    ]
    assert wrapper._base_url == "https://agent.example.com"
    assert wrapper.chat_id == "chat-9"
    assert wrapper.last_tokens_used == 0


@pytest.mark.asyncio
async def test_agent_api_wrapper_recovers_late_text_after_first_end(monkeypatch):
    """Text chunks that arrive after an intermediate END event are still yielded."""
    monkeypatch.setattr(
        agent_api_wrapper_module.AgentApi, "__init__", _fake_modern_agent_api_init
    )

    wrapper = AgentApiWrapper(
        agent_id="agent-1",
        base_url="https://agent.example.com/v1/agent_ws",
        chat_id="chat-1",
    )
    wrapper._connected = True
    wrapper._request_lock = asyncio.Lock()
    wrapper._current_queue = None
    wrapper._ws = QueueFeedingWebSocket(
        wrapper,
        [
            MsgEventTextChunk(text="Иллюстра"),
            MsgEventEnd(tokens_used=5),
            MsgEventTextChunk(text="ция"),
            MsgEventEnd(tokens_used=5),
        ],
    )

    chunks = []
    async for chunk in wrapper.send_message("hello"):
        chunks.append(chunk)

    assert [chunk.text for chunk in chunks] == ["Иллюстра", "ция"]
    assert wrapper.last_tokens_used == 5


@pytest.mark.asyncio
async def test_agent_api_wrapper_times_out_on_idle_stream(monkeypatch):
    """A stream that never produces events raises ``AgentException`` after the idle timeout."""
    monkeypatch.setattr(
        agent_api_wrapper_module.AgentApi, "__init__", _fake_modern_agent_api_init
    )
    monkeypatch.setattr(agent_api_wrapper_module, "_STREAM_IDLE_TIMEOUT_MS", 10)

    wrapper = AgentApiWrapper(
        agent_id="agent-1",
        base_url="https://agent.example.com/v1/agent_ws",
        chat_id="chat-1",
    )
    wrapper._connected = True
    wrapper._request_lock = asyncio.Lock()
    wrapper._current_queue = None
    wrapper._ws = SilentWebSocket()

    with pytest.raises(agent_api_wrapper_module.AgentException, match="Timed out waiting"):
        async for _ in wrapper.send_message("hello"):
            pass


@pytest.mark.asyncio
async def test_real_platform_client_get_or_create_user_uses_local_state():
    """The first lookup creates the user; later lookups reuse it and keep the display name."""
    client = RealPlatformClient(
        agent_api=FakeAgentApiFactory(),
        prototype_state=PrototypeStateStore(),
    )

    first = await client.get_or_create_user("u1", "matrix", "Alice")
    second = await client.get_or_create_user("u1", "matrix")

    assert first.user_id == "usr-matrix-u1"
    assert first.is_new is True
    assert second.user_id == first.user_id
    assert second.is_new is False
    assert second.display_name == "Alice"


@pytest.mark.asyncio
async def test_real_platform_client_send_message_uses_chat_bound_client():
    """``send_message`` creates and connects a per-chat client and records token usage."""
    agent_api = FakeAgentApiFactory()
    prototype_state = PrototypeStateStore()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=prototype_state,
        platform="matrix",
    )

    result = await client.send_message("@alice:example.org", "chat-7", "hello")

    assert result == MessageResponse(
        message_id="@alice:example.org",
        response="hello",
        tokens_used=3,
        finished=True,
    )
    assert agent_api.created_chat_ids == ["chat-7"]
    assert agent_api.instances["chat-7"].chat_id == "chat-7"
    assert agent_api.instances["chat-7"].calls == ["hello"]
    assert agent_api.instances["chat-7"].connect_calls == 1
    assert await prototype_state.get_last_tokens_used_for_context("chat-7") == 3


@pytest.mark.asyncio
async def test_real_platform_client_forwards_attachments_to_chat_api():
    """Attachments are forwarded to the agent API as a list of workspace paths."""
    agent_api = AttachmentTrackingChatAgentApi("chat-7")
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )
    attachment = Attachment(
        workspace_path="surfaces/matrix/alice/room/inbox/report.pdf",
        mime_type="application/pdf",
        filename="report.pdf",
        size=123,
    )

    result = await client.send_message(
        "@alice:example.org",
        "chat-7",
        "hello",
        attachments=[attachment],
    )

    assert agent_api.calls == [("hello", ["surfaces/matrix/alice/room/inbox/report.pdf"])]
    assert result.response == "hello"
    assert result.tokens_used == 5


@pytest.mark.asyncio
async def test_real_platform_client_preserves_send_file_events_in_sync_result(monkeypatch):
    """Send-file events in the stream end up as attachments on the synchronous result."""
    agent_api = AttachmentTrackingChatAgentApi("chat-7")
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    class FileEventAgentApi(AttachmentTrackingChatAgentApi):
        async def send_message(self, text: str, attachments: list[str] | None = None):
            self.calls.append((text, attachments))
            yield TextChunkEvent("he")
            yield SendFileEvent(
                workspace_path="/workspace/report.pdf",
                mime_type="application/pdf",
                filename="report.pdf",
                size=123,
            )
            yield TextChunkEvent("llo")
            self.last_tokens_used = 9

    monkeypatch.setattr(
        "sdk.real.MessageResponse",
        MessageResponseWithAttachments,
    )
    client._agent_api = FileEventAgentApi("chat-7")

    result = await client.send_message("@alice:example.org", "chat-7", "hello")

    assert result.response == "hello"
    assert result.tokens_used == 9
    assert result.attachments == [
        Attachment(
            url="/workspace/report.pdf",
            mime_type="application/pdf",
            filename="report.pdf",
            size=123,
            workspace_path="report.pdf",
        )
    ]


@pytest.mark.asyncio
async def test_real_platform_client_works_with_legacy_agent_api_without_for_chat():
    """An agent API without ``for_chat`` is used directly."""
    legacy_api = LegacyAgentApi()
    client = RealPlatformClient(
        agent_api=legacy_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    result = await client.send_message("@alice:example.org", "chat-legacy", "hello")

    assert result == MessageResponse(
        message_id="@alice:example.org",
        response="hello",
        tokens_used=7,
        finished=True,
    )
    assert legacy_api.calls == ["hello"]


@pytest.mark.asyncio
async def test_real_platform_client_reuses_cached_chat_client():
    """Repeated messages to one chat reuse a single connected client."""
    agent_api = FakeAgentApiFactory()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    await client.send_message("@alice:example.org", "chat-1", "hello")
    await client.send_message("@alice:example.org", "chat-1", "again")

    assert agent_api.created_chat_ids == ["chat-1"]
    assert agent_api.instances["chat-1"].calls == ["hello", "again"]
    assert agent_api.instances["chat-1"].connect_calls == 1
    assert agent_api.instances["chat-1"].close_calls == 0


@pytest.mark.asyncio
async def test_real_platform_client_wraps_connection_closed_as_platform_error():
    """A ``ConnectionError`` becomes ``PlatformError``, and the broken client is closed and evicted."""
    agent_api = FakeAgentApiFactory()
    agent_api.instances["chat-1"] = FlakyChatAgentApi("chat-1")
    agent_api.for_chat = lambda chat_id: agent_api.instances.setdefault(
        chat_id, FlakyChatAgentApi(chat_id)
    )
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    with pytest.raises(PlatformError, match="Connection closed") as exc_info:
        await client.send_message("@alice:example.org", "chat-1", "hello")

    assert exc_info.value.code == "PLATFORM_CONNECTION_ERROR"
    assert "chat-1" not in client._chat_apis
    assert agent_api.instances["chat-1"].close_calls == 1


@pytest.mark.asyncio
async def test_real_platform_client_reconnects_after_closed_chat_api():
    """After a connection failure, the next message creates a fresh chat client."""
    agent_api = FakeAgentApiFactory()
    flaky = FlakyChatAgentApi("chat-1")
    healthy = AttachmentTrackingChatAgentApi("chat-1")
    provided = iter([flaky, healthy])

    def for_chat(chat_id: str):
        chat_api = next(provided)
        agent_api.created_chat_ids.append(chat_id)
        agent_api.instances[chat_id] = chat_api
        return chat_api

    agent_api.for_chat = for_chat
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    with pytest.raises(PlatformError, match="Connection closed"):
        await client.send_message("@alice:example.org", "chat-1", "hello")

    result = await client.send_message("@alice:example.org", "chat-1", "again")

    assert result.response == "again"
    assert agent_api.created_chat_ids == ["chat-1", "chat-1"]
    assert healthy.calls == [("again", None)]


@pytest.mark.asyncio
async def test_real_platform_client_creates_chat_client_atomically_for_concurrent_requests():
    """Concurrent first requests to one chat create and connect only one client."""
    agent_api = FakeAgentApiFactory()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    results = await asyncio.gather(
        client.send_message("@alice:example.org", "chat-1", "hello"),
        client.send_message("@alice:example.org", "chat-1", "again"),
    )

    assert [result.response for result in results] == ["hello", "again"]
    assert agent_api.created_chat_ids == ["chat-1"]
    assert agent_api.instances["chat-1"].connect_calls == 1
    assert agent_api.instances["chat-1"].calls == ["hello", "again"]


@pytest.mark.asyncio
async def test_real_platform_client_creates_distinct_clients_per_chat():
    """Different chat ids get separate client instances."""
    agent_api = FakeAgentApiFactory()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    await client.send_message("@alice:example.org", "chat-1", "hello")
    await client.send_message("@alice:example.org", "chat-2", "world")

    assert agent_api.created_chat_ids == ["chat-1", "chat-2"]
    assert agent_api.instances["chat-1"] is not agent_api.instances["chat-2"]
    assert agent_api.instances["chat-1"].calls == ["hello"]
    assert agent_api.instances["chat-2"].calls == ["world"]


@pytest.mark.asyncio
async def test_real_platform_client_serializes_same_chat_streams_across_send_paths():
    """``stream_message`` and ``send_message`` on one chat never run concurrently."""
    agent_api = FakeAgentApiFactory()
    blocking_api = BlockingChatAgentApi("chat-1")
    agent_api.instances["chat-1"] = blocking_api
    agent_api.for_chat = lambda chat_id: agent_api.instances.setdefault(
        chat_id, BlockingChatAgentApi(chat_id)
    )
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    async def consume_stream():
        chunks = []
        async for chunk in client.stream_message("@alice:example.org", "chat-1", "hello"):
            chunks.append(chunk)
        return chunks

    stream_task = asyncio.create_task(consume_stream())
    await asyncio.wait_for(blocking_api.started.wait(), timeout=1)

    send_task = asyncio.create_task(client.send_message("@alice:example.org", "chat-1", "again"))
    # Yield several times so the second request gets every chance to reach the
    # agent API. A single sleep(0) only lets the task start, and the assertions
    # below would then pass even if requests were not serialized.
    for _ in range(10):
        await asyncio.sleep(0)

    assert blocking_api.calls == ["hello"]
    assert blocking_api.max_active_calls == 1

    blocking_api.release.set()
    stream_chunks = await stream_task
    send_result = await send_task

    assert [chunk.delta for chunk in stream_chunks] == ["hello", ""]
    assert send_result.response == "again"
    assert blocking_api.calls == ["hello", "again"]
    assert blocking_api.max_active_calls == 1


@pytest.mark.asyncio
async def test_real_platform_client_stream_message_emits_final_tokens_chunk():
    """Streaming yields every text delta, then an empty finished chunk carrying token usage."""
    agent_api = FakeAgentApiFactory()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    chunks = []
    async for chunk in client.stream_message("@alice:example.org", "chat-1", "hello"):
        chunks.append(chunk)

    assert chunks == [
        MessageChunk(
            message_id="@alice:example.org",
            delta="he",
            finished=False,
            tokens_used=0,
        ),
        MessageChunk(
            message_id="@alice:example.org",
            delta="llo",
            finished=False,
            tokens_used=0,
        ),
        MessageChunk(
            message_id="@alice:example.org",
            delta="",
            finished=True,
            tokens_used=3,
        ),
    ]
    assert agent_api.created_chat_ids == ["chat-1"]
    assert agent_api.instances["chat-1"].calls == ["hello"]


@pytest.mark.asyncio
async def test_real_platform_client_settings_are_local():
    """Settings updates are applied and read back through the client."""
    client = RealPlatformClient(
        agent_api=FakeAgentApiFactory(),
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    await client.update_settings(
        "usr-matrix-u1",
        SettingsAction(action="toggle_skill", payload={"skill": "browser", "enabled": True}),
    )
    settings = await client.get_settings("usr-matrix-u1")

    assert isinstance(settings, UserSettings)
    assert settings.skills["browser"] is True
    assert settings.skills["web-search"] is True


@pytest.mark.asyncio
async def test_agent_api_wrapper_transparently_surfaces_modern_events(monkeypatch):
    """``_listen`` routes each modern event type to the request queue and/or the callback."""
    callback_events: list[object] = []
    queue: asyncio.Queue = asyncio.Queue()
    event_map = {
        "text": TextChunkEvent("he"),
        "tool_call": ToolCallChunkEvent("call"),
        "tool_result": ToolResultEvent("result"),
        "custom_update": CustomUpdateEvent("update"),
        "send_file": SendFileEvent(
            workspace_path="/workspace/report.pdf",
            mime_type="application/pdf",
            filename="report.pdf",
            size=123,
        ),
        "end": EndEvent(tokens_used=11),
        "error": ErrorEvent(code="BOOM", details="bad things"),
        "disconnect": GracefulDisconnectEvent(),
    }

    def fake_validate_json(data: str):
        # Each fake websocket frame's data is a key into event_map.
        return event_map[data]

    monkeypatch.setattr(
        agent_api_wrapper_module,
        "ServerMessage",
        type("FakeServerMessage", (), {"validate_json": staticmethod(fake_validate_json)}),
    )

    async def fake_cleanup(self):
        return None

    def fake_agent_api_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
        self.id = agent_id
        self.callback = kwargs.get("callback")
        self.on_disconnect = kwargs.get("on_disconnect")
        self._current_queue = None

    monkeypatch.setattr(agent_api_wrapper_module.AgentApiWrapper, "_cleanup", fake_cleanup)
    monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_agent_api_init)

    wrapper = AgentApiWrapper(
        agent_id="agent-1",
        base_url="https://agent.example.com/v1/agent_ws",
        chat_id="chat-1",
        callback=callback_events.append,
    )
    wrapper._current_queue = queue
    wrapper._ws = FakeWebSocket(
        [
            FakeWSMessage("text"),
            FakeWSMessage("tool_call"),
            FakeWSMessage("tool_result"),
            FakeWSMessage("custom_update"),
            FakeWSMessage("send_file"),
            FakeWSMessage("end"),
            FakeWSMessage("error"),
            FakeWSMessage("disconnect"),
        ]
    )

    await wrapper._listen()

    queue_events = []
    while not queue.empty():
        queue_events.append(queue.get_nowait())

    assert queue_events[0].text == "he"
    assert any(isinstance(event, SendFileEvent) for event in queue_events)
    assert any(isinstance(event, EndEvent) for event in queue_events)
    assert any(isinstance(event, GracefulDisconnectEvent) for event in queue_events)
    assert callback_events[0].payload == "call"
    assert callback_events[1].payload == "result"
    assert callback_events[2].payload == "update"
    assert any(isinstance(event, SendFileEvent) for event in callback_events)
    assert any(isinstance(event, EndEvent) for event in callback_events)
    assert any(isinstance(event, ErrorEvent) for event in callback_events)
    assert any(isinstance(event, GracefulDisconnectEvent) for event in callback_events)
    assert wrapper.last_tokens_used == 11