import asyncio

import pytest
from lambda_agent_api.server import MsgEventSendFile, MsgEventTextChunk
from pydantic import Field

import sdk.agent_api_wrapper as agent_api_wrapper_module
from core.protocol import SettingsAction
from sdk.agent_api_wrapper import AgentApiWrapper
from sdk.interface import Attachment, MessageChunk, MessageResponse, PlatformError, UserSettings
from sdk.prototype_state import PrototypeStateStore
from sdk.real import RealPlatformClient


class FakeChatAgentApi:
    """Chat-bound agent API double that echoes the prompt back in two text chunks.

    Records every prompt in ``calls`` and counts ``connect``/``close`` invocations.
    """

    def __init__(self, chat_id: str) -> None:
        self.chat_id = chat_id
        self.calls: list[str] = []
        self.connect_calls = 0
        self.close_calls = 0

    async def connect(self) -> None:
        """Count the connection attempt; no I/O is performed."""
        self.connect_calls += 1

    async def close(self) -> None:
        """Count the close request; no I/O is performed."""
        self.close_calls += 1

    async def send_message(self, text: str, attachments: list[str] | None = None):
        """Record ``text`` and yield it back split at its midpoint.

        ``attachments`` is accepted for signature compatibility and ignored.
        """
        self.calls.append(text)
        half = len(text) // 2
        for piece in (text[:half], text[half:]):
            yield MsgEventTextChunk(text=piece)


class FakeAgentApiFactory:
    """Factory double that hands out a fresh ``FakeChatAgentApi`` per ``for_chat`` call."""

    def __init__(self) -> None:
        self.created_chat_ids: list[str] = []
        self.instances: dict[str, FakeChatAgentApi] = {}

    def for_chat(self, chat_id: str) -> FakeChatAgentApi:
        """Create, register and return a chat API for ``chat_id``.

        The chat id is appended to ``created_chat_ids`` on every call and the
        newest instance replaces any previous entry in ``instances``.
        """
        created = FakeChatAgentApi(chat_id)
        self.created_chat_ids.append(chat_id)
        self.instances[chat_id] = created
        return created


class BlockingChatAgentApi:
    """Chat API double whose ``send_message`` parks until ``release`` is set.

    ``started`` is set once a call has entered ``send_message``;
    ``max_active_calls`` keeps the highest number of calls seen inside the
    blocking section at the same time.
    """

    def __init__(self, chat_id: str) -> None:
        self.chat_id = chat_id
        self.calls: list[str] = []
        self.connect_calls = 0
        self.close_calls = 0
        self.active_calls = 0
        self.max_active_calls = 0
        self.started = asyncio.Event()
        self.release = asyncio.Event()

    async def connect(self) -> None:
        """Count the connection attempt; no I/O is performed."""
        self.connect_calls += 1

    async def close(self) -> None:
        """Count the close request; no I/O is performed."""
        self.close_calls += 1

    async def send_message(self, text: str, attachments: list[str] | None = None):
        """Record ``text``, wait for ``release``, then yield ``text`` as one chunk."""
        self.calls.append(text)
        self.active_calls += 1
        # Track the high-water mark of overlapping calls.
        if self.active_calls > self.max_active_calls:
            self.max_active_calls = self.active_calls
        self.started.set()
        await self.release.wait()
        self.active_calls -= 1
        yield MsgEventTextChunk(text=text)


class AttachmentTrackingChatAgentApi:
    """Chat API double that records ``(text, attachments)`` pairs and echoes the text."""

    def __init__(self, chat_id: str) -> None:
        self.chat_id = chat_id
        self.calls: list[tuple[str, list[str] | None]] = []
        self.connect_calls = 0
        self.close_calls = 0

    async def connect(self) -> None:
        """Count the connection attempt; no I/O is performed."""
        self.connect_calls += 1

    async def close(self) -> None:
        """Count the close request; no I/O is performed."""
        self.close_calls += 1

    async def send_message(self, text: str, attachments: list[str] | None = None):
        """Record the call arguments and yield ``text`` back as a single chunk."""
        self.calls.append((text, attachments))
        yield MsgEventTextChunk(text=text)


class AttachmentTrackingAgentApiFactory:
    """Factory double that builds chat APIs from a configurable class."""

    def __init__(self, chat_api_cls=AttachmentTrackingChatAgentApi) -> None:
        self.chat_api_cls = chat_api_cls
        self.created_chat_ids: list[str] = []
        self.instances: dict[str, AttachmentTrackingChatAgentApi] = {}

    def for_chat(self, chat_id: str) -> AttachmentTrackingChatAgentApi:
        """Instantiate ``chat_api_cls`` for ``chat_id``, register it and return it."""
        created = self.chat_api_cls(chat_id)
        self.created_chat_ids.append(chat_id)
        self.instances[chat_id] = created
        return created


class FlakyChatAgentApi:
    """Chat API double whose ``send_message`` always fails with ``ConnectionError``."""

    def __init__(self, chat_id: str) -> None:
        self.chat_id = chat_id
        self.connect_calls = 0
        self.close_calls = 0

    async def connect(self) -> None:
        """Count the connection attempt; no I/O is performed."""
        self.connect_calls += 1

    async def close(self) -> None:
        """Count the close request; no I/O is performed."""
        self.close_calls += 1

    async def send_message(self, text: str, attachments: list[str] | None = None):
        """Raise ``ConnectionError("Connection closed")`` on first iteration."""
        raise ConnectionError("Connection closed")
        # Unreachable on purpose: the bare ``yield`` makes this an async
        # generator, so the error is raised when the caller iterates it.
        yield


class MessageResponseWithAttachments(MessageResponse):
    """``MessageResponse`` extended with an ``attachments`` list (empty by default)."""

    attachments: list[Attachment] = Field(default_factory=list)


def test_agent_api_wrapper_normalizes_base_url_and_uses_modern_constructor(monkeypatch):
    """The wrapper passes a base URL without the ``/v1/agent_ws/`` suffix to ``AgentApi``."""
    captured = {}

    def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
        # Record what the wrapper forwards to the underlying AgentApi constructor.
        captured.update(agent_id=agent_id, base_url=base_url, chat_id=chat_id)

    monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)

    wrapper = AgentApiWrapper(
        agent_id="agent-1",
        base_url="ws://platform-agent:8000/v1/agent_ws/",
        chat_id="41",
    )

    normalized_url = "ws://platform-agent:8000"
    assert wrapper.chat_id == "41"
    assert wrapper._base_url == normalized_url
    assert captured == {
        "agent_id": "agent-1",
        "base_url": normalized_url,
        "chat_id": "41",
    }
def test_agent_api_wrapper_for_chat_reuses_normalized_base_url(monkeypatch):
    """``for_chat`` returns a new wrapper built with the already-normalized base URL."""
    init_calls = []

    def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
        self.id = agent_id
        self.chat_id = chat_id
        self.url = base_url
        init_calls.append((agent_id, base_url, chat_id))

    monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)

    root = AgentApiWrapper(
        agent_id="agent-1",
        base_url="http://platform-agent:8000/v1/agent_ws/",
        chat_id="1",
    )
    child = root.for_chat("99")

    normalized_url = "http://platform-agent:8000"
    assert child is not root
    assert child.chat_id == "99"
    assert child._base_url == normalized_url
    # One constructor call for the root wrapper, one for the child.
    assert init_calls == [("agent-1", normalized_url, chat) for chat in ("1", "99")]


@pytest.mark.asyncio
async def test_real_platform_client_get_or_create_user_uses_local_state():
    """A second lookup of the same user returns the stored record instead of a new one."""
    client = RealPlatformClient(
        agent_api=FakeAgentApiFactory(),
        prototype_state=PrototypeStateStore(),
    )

    created = await client.get_or_create_user("u1", "matrix", "Alice")
    fetched = await client.get_or_create_user("u1", "matrix")

    assert created.user_id == "usr-matrix-u1"
    assert created.is_new is True
    assert fetched.user_id == created.user_id
    assert fetched.is_new is False
    assert fetched.display_name == "Alice"


@pytest.mark.asyncio
async def test_real_platform_client_send_message_uses_chat_bound_client():
    """``send_message`` goes through a chat API created for the requested chat id."""
    agent_api = FakeAgentApiFactory()
    prototype_state = PrototypeStateStore()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=prototype_state,
        platform="matrix",
    )

    result = await client.send_message("@alice:example.org", "chat-7", "hello")

    expected = MessageResponse(
        message_id="@alice:example.org",
        response="hello",
        tokens_used=0,
        finished=True,
    )
    assert result == expected
    assert agent_api.created_chat_ids == ["chat-7"]
    chat_api = agent_api.instances["chat-7"]
    assert chat_api.chat_id == "chat-7"
    assert chat_api.calls == ["hello"]
    assert chat_api.connect_calls == 1
    assert await prototype_state.get_last_tokens_used_for_context("chat-7") == 0


@pytest.mark.asyncio
async def test_real_platform_client_forwards_attachments_to_chat_api():
    """Attachments reach the chat API as a list of their workspace paths."""
    agent_api = AttachmentTrackingAgentApiFactory()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )
    workspace_path = "surfaces/matrix/alice/room/inbox/report.pdf"
    attachment = Attachment(
        workspace_path=workspace_path,
        mime_type="application/pdf",
        filename="report.pdf",
        size=123,
    )

    result = await client.send_message(
        "@alice:example.org",
        "chat-7",
        "hello",
        attachments=[attachment],
    )

    assert agent_api.instances["chat-7"].calls == [("hello", [workspace_path])]
    assert result.response == "hello"
    assert result.tokens_used == 0


@pytest.mark.asyncio
async def test_real_platform_client_preserves_send_file_events_in_sync_result(monkeypatch):
    """A send-file event between text chunks ends up in the response's attachments."""

    class FileEventAgentApi(AttachmentTrackingChatAgentApi):
        async def send_message(self, text: str, attachments: list[str] | None = None):
            self.calls.append((text, attachments))
            yield MsgEventTextChunk(text="he")
            yield MsgEventSendFile(path="report.pdf")
            yield MsgEventTextChunk(text="llo")

    agent_api = AttachmentTrackingAgentApiFactory(chat_api_cls=FileEventAgentApi)
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )
    # Swap in the response model that has an ``attachments`` field before sending.
    monkeypatch.setattr(
        "sdk.real.MessageResponse",
        MessageResponseWithAttachments,
    )

    result = await client.send_message("@alice:example.org", "chat-7", "hello")

    assert result.response == "hello"
    assert result.tokens_used == 0
    expected_attachment = Attachment(
        url="report.pdf",
        mime_type="application/octet-stream",
        filename="report.pdf",
        size=None,
        workspace_path="report.pdf",
    )
    assert result.attachments == [expected_attachment]


@pytest.mark.asyncio
async def test_real_platform_client_reuses_cached_chat_client():
    """Two messages to the same chat share one chat API and one ``connect`` call."""
    agent_api = FakeAgentApiFactory()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    for text in ("hello", "again"):
        await client.send_message("@alice:example.org", "chat-1", text)

    assert agent_api.created_chat_ids == ["chat-1"]
    chat_api = agent_api.instances["chat-1"]
    assert chat_api.calls == ["hello", "again"]
    assert chat_api.connect_calls == 1
    assert chat_api.close_calls == 0


@pytest.mark.asyncio
async def test_real_platform_client_wraps_connection_closed_as_platform_error():
    """A ``ConnectionError`` from the chat API surfaces as ``PlatformError`` and evicts it."""
    agent_api = FakeAgentApiFactory()
    agent_api.instances["chat-1"] = FlakyChatAgentApi("chat-1")

    def for_chat(chat_id):
        # Hand back the pre-registered flaky instance for known chat ids.
        return agent_api.instances.setdefault(chat_id, FlakyChatAgentApi(chat_id))

    agent_api.for_chat = for_chat
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    with pytest.raises(PlatformError, match="Connection closed") as exc_info:
        await client.send_message("@alice:example.org", "chat-1", "hello")

    assert exc_info.value.code == "PLATFORM_CONNECTION_ERROR"
    assert "chat-1" not in client._chat_apis
    assert agent_api.instances["chat-1"].close_calls == 1


@pytest.mark.asyncio
async def test_real_platform_client_reconnects_after_closed_chat_api():
    """After a failed send, the next send asks the factory for a new chat API."""
    agent_api = FakeAgentApiFactory()
    flaky = FlakyChatAgentApi("chat-1")
    healthy = AttachmentTrackingChatAgentApi("chat-1")
    # First request gets the failing API, the retry gets the working one.
    pending_apis = iter([flaky, healthy])

    def for_chat(chat_id: str):
        chat_api = next(pending_apis)
        agent_api.created_chat_ids.append(chat_id)
        agent_api.instances[chat_id] = chat_api
        return chat_api

    agent_api.for_chat = for_chat
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    with pytest.raises(PlatformError, match="Connection closed"):
        await client.send_message("@alice:example.org", "chat-1", "hello")

    result = await client.send_message("@alice:example.org", "chat-1", "again")

    assert result.response == "again"
    assert agent_api.created_chat_ids == ["chat-1", "chat-1"]
    assert healthy.calls == [("again", None)]


@pytest.mark.asyncio
async def test_real_platform_client_creates_chat_client_atomically_for_concurrent_requests():
    """Concurrent sends to one chat create and connect only a single chat API."""
    agent_api = FakeAgentApiFactory()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    first, second = await asyncio.gather(
        client.send_message("@alice:example.org", "chat-1", "hello"),
        client.send_message("@alice:example.org", "chat-1", "again"),
    )

    assert [first.response, second.response] == ["hello", "again"]
    assert agent_api.created_chat_ids == ["chat-1"]
    chat_api = agent_api.instances["chat-1"]
    assert chat_api.connect_calls == 1
    assert chat_api.calls == ["hello", "again"]


@pytest.mark.asyncio
async def test_real_platform_client_creates_distinct_clients_per_chat():
    """Different chat ids get different chat API instances."""
    agent_api = FakeAgentApiFactory()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    await client.send_message("@alice:example.org", "chat-1", "hello")
    await client.send_message("@alice:example.org", "chat-2", "world")

    assert agent_api.created_chat_ids == ["chat-1", "chat-2"]
    assert agent_api.instances["chat-1"] is not agent_api.instances["chat-2"]
    assert agent_api.instances["chat-1"].calls == ["hello"]
    assert agent_api.instances["chat-2"].calls == ["world"]


@pytest.mark.asyncio
async def test_real_platform_client_serializes_same_chat_streams_across_send_paths():
    """A send on a chat does not enter the chat API while a stream on it is in flight."""
    agent_api = FakeAgentApiFactory()
    agent_api.instances["chat-1"] = BlockingChatAgentApi("chat-1")

    def for_chat(chat_id):
        # Hand back the pre-registered blocking instance for known chat ids.
        return agent_api.instances.setdefault(chat_id, BlockingChatAgentApi(chat_id))

    agent_api.for_chat = for_chat
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    async def consume_stream():
        return [
            chunk
            async for chunk in client.stream_message("@alice:example.org", "chat-1", "hello")
        ]

    stream_task = asyncio.create_task(consume_stream())
    # Wait until the stream is parked inside the blocking send_message.
    await asyncio.wait_for(agent_api.instances["chat-1"].started.wait(), timeout=1)

    send_task = asyncio.create_task(client.send_message("@alice:example.org", "chat-1", "again"))
    # Yield to the event loop once so send_task gets a chance to run.
    await asyncio.sleep(0)

    assert agent_api.instances["chat-1"].calls == ["hello"]
    assert agent_api.instances["chat-1"].max_active_calls == 1

    agent_api.instances["chat-1"].release.set()
    stream_chunks = await stream_task
    send_result = await send_task

    assert [chunk.delta for chunk in stream_chunks] == ["hello", ""]
    assert send_result.response == "again"
    assert agent_api.instances["chat-1"].calls == ["hello", "again"]
    assert agent_api.instances["chat-1"].max_active_calls == 1


@pytest.mark.asyncio
async def test_real_platform_client_stream_message_emits_final_tokens_chunk():
    """Streaming yields the text deltas followed by one empty, finished chunk."""
    agent_api = FakeAgentApiFactory()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    chunks = [
        chunk
        async for chunk in client.stream_message("@alice:example.org", "chat-1", "hello")
    ]

    expected = [
        MessageChunk(
            message_id="@alice:example.org",
            delta=delta,
            finished=finished,
            tokens_used=0,
        )
        for delta, finished in (("he", False), ("llo", False), ("", True))
    ]
    assert chunks == expected
    assert agent_api.created_chat_ids == ["chat-1"]
    assert agent_api.instances["chat-1"].calls == ["hello"]


@pytest.mark.asyncio
async def test_real_platform_client_settings_are_local():
    """Toggling a skill is visible through ``get_settings`` on the same client."""
    client = RealPlatformClient(
        agent_api=FakeAgentApiFactory(),
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )
    toggle_browser = SettingsAction(
        action="toggle_skill",
        payload={"skill": "browser", "enabled": True},
    )

    await client.update_settings("usr-matrix-u1", toggle_browser)
    settings = await client.get_settings("usr-matrix-u1")

    assert isinstance(settings, UserSettings)
    assert settings.skills["browser"] is True
    assert settings.skills["web-search"] is True