"""Tests for AgentApiWrapper construction and RealPlatformClient's per-chat agent clients."""

import asyncio

import pytest

from core.protocol import SettingsAction
import sdk.agent_api_wrapper as agent_api_wrapper_module
from sdk.agent_api_wrapper import AgentApiWrapper
from sdk.interface import MessageChunk, MessageResponse, UserSettings
from sdk.prototype_state import PrototypeStateStore
from sdk.real import RealPlatformClient


class FakeChunk:
    """Minimal stand-in for a streamed agent chunk; only carries ``text``."""

    def __init__(self, text: str) -> None:
        self.text = text


class FakeChatAgentApi:
    """Per-chat fake agent client.

    Records every message sent, counts connect/close calls, streams each message
    back as two halves, and reports 3 tokens once the stream is exhausted.
    """

    def __init__(self, chat_id: str) -> None:
        self.chat_id = chat_id
        self.calls: list[str] = []
        self.connect_calls = 0
        self.close_calls = 0
        self.last_tokens_used = 0

    async def connect(self) -> None:
        self.connect_calls += 1

    async def close(self) -> None:
        self.close_calls += 1

    async def send_message(self, text: str):
        self.calls.append(text)
        midpoint = len(text) // 2
        yield FakeChunk(text[:midpoint])
        yield FakeChunk(text[midpoint:])
        # Only set after both chunks were consumed, like a real token report.
        self.last_tokens_used = 3


class FakeAgentApiFactory:
    """Fake root agent API whose ``for_chat`` builds a fresh FakeChatAgentApi.

    Every ``for_chat`` call is recorded in ``created_chat_ids`` so tests can
    assert how many per-chat clients the platform client created.
    """

    def __init__(self) -> None:
        self.created_chat_ids: list[str] = []
        self.instances: dict[str, FakeChatAgentApi] = {}

    def for_chat(self, chat_id: str) -> FakeChatAgentApi:
        chat_api = FakeChatAgentApi(chat_id)
        self.created_chat_ids.append(chat_id)
        self.instances[chat_id] = chat_api
        return chat_api


class LegacyAgentApi:
    """Fake agent API exposing only ``send_message`` (no for_chat/connect/close).

    Streams the first two characters, then the rest, and reports 7 tokens.
    """

    def __init__(self) -> None:
        self.calls: list[str] = []
        self.last_tokens_used = 0

    async def send_message(self, text: str):
        self.calls.append(text)
        yield FakeChunk(text[:2])
        yield FakeChunk(text[2:])
        self.last_tokens_used = 7


class BlockingChatAgentApi:
    """Per-chat fake whose stream blocks until ``release`` is set.

    Tracks how many ``send_message`` streams are alive at once
    (``max_active_calls``) so tests can detect overlapping streams on one chat.
    ``started`` is set as soon as any stream begins.
    """

    def __init__(self, chat_id: str) -> None:
        self.chat_id = chat_id
        self.calls: list[str] = []
        self.connect_calls = 0
        self.close_calls = 0
        self.last_tokens_used = 0
        self.active_calls = 0
        self.max_active_calls = 0
        self.started = asyncio.Event()
        self.release = asyncio.Event()

    async def connect(self) -> None:
        self.connect_calls += 1

    async def close(self) -> None:
        self.close_calls += 1

    async def send_message(self, text: str):
        self.calls.append(text)
        self.active_calls += 1
        self.max_active_calls = max(self.max_active_calls, self.active_calls)
        self.started.set()
        try:
            await self.release.wait()
            yield FakeChunk(text)
            self.last_tokens_used = len(text)
        finally:
            # Count the stream as active until it is exhausted or closed, so a
            # second call that starts while this chunk is still being consumed
            # is detected as an overlap.
            self.active_calls -= 1


def test_agent_api_wrapper_uses_modern_constructor_when_available(monkeypatch):
    """The wrapper passes a path-less base_url and chat_id to the modern AgentApi.__init__."""
    calls: list[dict[str, object]] = []

    def fake_init(self, agent_id, base_url, chat_id, **kwargs):
        calls.append(
            {
                "agent_id": agent_id,
                "base_url": base_url,
                "chat_id": chat_id,
                "kwargs": kwargs,
            }
        )
        self.id = agent_id
        self.url = base_url
        self.callback = kwargs.get("callback")
        self.on_disconnect = kwargs.get("on_disconnect")

    monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)

    wrapper = AgentApiWrapper(
        agent_id="agent-1",
        base_url="https://agent.example.com/v1/agent_ws",
        chat_id="chat-1",
        callback="cb",
        on_disconnect="disconnect",
    )
    child = wrapper.for_chat("chat-2")

    # Both the root wrapper and the chat-bound child see the stripped base URL
    # and forward the same callbacks.
    assert calls == [
        {
            "agent_id": "agent-1",
            "base_url": "https://agent.example.com",
            "chat_id": "chat-1",
            "kwargs": {"callback": "cb", "on_disconnect": "disconnect"},
        },
        {
            "agent_id": "agent-1",
            "base_url": "https://agent.example.com",
            "chat_id": "chat-2",
            "kwargs": {"callback": "cb", "on_disconnect": "disconnect"},
        },
    ]
    assert wrapper._base_url == "https://agent.example.com"
    assert wrapper.chat_id == "chat-1"
    assert wrapper.last_tokens_used == 0
    assert child.chat_id == "chat-2"


def test_agent_api_wrapper_falls_back_to_legacy_url_constructor(monkeypatch):
    """With a legacy ``(agent_id, url, ...)`` AgentApi, the full url is passed through unchanged."""
    calls: list[dict[str, object]] = []

    def fake_init(self, agent_id, url, callback=None, on_disconnect=None):
        calls.append(
            {
                "agent_id": agent_id,
                "url": url,
                "callback": callback,
                "on_disconnect": on_disconnect,
            }
        )
        self.id = agent_id
        self.url = url
        self.callback = callback
        self.on_disconnect = on_disconnect

    monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)

    wrapper = AgentApiWrapper(
        agent_id="agent-2",
        url="https://agent.example.com/v1/agent_ws/chat-9/",
        chat_id="chat-9",
        callback="cb",
    )

    assert calls == [
        {
            "agent_id": "agent-2",
            "url": "https://agent.example.com/v1/agent_ws/chat-9/",
            "callback": "cb",
            "on_disconnect": None,
        }
    ]
    # The wrapper still derives the bare base URL for its own bookkeeping.
    assert wrapper._base_url == "https://agent.example.com"
    assert wrapper.chat_id == "chat-9"
    assert wrapper.last_tokens_used == 0


@pytest.mark.asyncio
async def test_real_platform_client_get_or_create_user_uses_local_state():
    client = RealPlatformClient(
        agent_api=FakeAgentApiFactory(),
        prototype_state=PrototypeStateStore(),
    )

    first = await client.get_or_create_user("u1", "matrix", "Alice")
    second = await client.get_or_create_user("u1", "matrix")

    assert first.user_id == "usr-matrix-u1"
    assert first.is_new is True
    assert second.user_id == first.user_id
    assert second.is_new is False
    # The display name from the first call is kept when omitted later.
    assert second.display_name == "Alice"


@pytest.mark.asyncio
async def test_real_platform_client_send_message_uses_chat_bound_client():
    agent_api = FakeAgentApiFactory()
    prototype_state = PrototypeStateStore()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=prototype_state,
        platform="matrix",
    )

    result = await client.send_message("@alice:example.org", "chat-7", "hello")

    assert result == MessageResponse(
        message_id="@alice:example.org",
        response="hello",
        tokens_used=3,
        finished=True,
    )
    assert agent_api.created_chat_ids == ["chat-7"]
    chat_api = agent_api.instances["chat-7"]
    assert chat_api.chat_id == "chat-7"
    assert chat_api.calls == ["hello"]
    assert chat_api.connect_calls == 1
    assert await prototype_state.get_last_tokens_used_for_context("chat-7") == 3


@pytest.mark.asyncio
async def test_real_platform_client_works_with_legacy_agent_api_without_for_chat():
    legacy_api = LegacyAgentApi()
    client = RealPlatformClient(
        agent_api=legacy_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    result = await client.send_message("@alice:example.org", "chat-legacy", "hello")

    assert result == MessageResponse(
        message_id="@alice:example.org",
        response="hello",
        tokens_used=7,
        finished=True,
    )
    assert legacy_api.calls == ["hello"]


@pytest.mark.asyncio
async def test_real_platform_client_reuses_cached_chat_client():
    agent_api = FakeAgentApiFactory()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    await client.send_message("@alice:example.org", "chat-1", "hello")
    await client.send_message("@alice:example.org", "chat-1", "again")

    assert agent_api.created_chat_ids == ["chat-1"]
    assert agent_api.instances["chat-1"].calls == ["hello", "again"]
    assert agent_api.instances["chat-1"].connect_calls == 1


@pytest.mark.asyncio
async def test_real_platform_client_creates_chat_client_atomically_for_concurrent_requests():
    agent_api = FakeAgentApiFactory()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    results = await asyncio.gather(
        client.send_message("@alice:example.org", "chat-1", "hello"),
        client.send_message("@alice:example.org", "chat-1", "again"),
    )

    assert [result.response for result in results] == ["hello", "again"]
    # Two concurrent first messages must still produce exactly one chat client.
    assert agent_api.created_chat_ids == ["chat-1"]
    assert agent_api.instances["chat-1"].connect_calls == 1
    assert agent_api.instances["chat-1"].calls == ["hello", "again"]


@pytest.mark.asyncio
async def test_real_platform_client_creates_distinct_clients_per_chat():
    agent_api = FakeAgentApiFactory()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    await client.send_message("@alice:example.org", "chat-1", "hello")
    await client.send_message("@alice:example.org", "chat-2", "world")

    assert agent_api.created_chat_ids == ["chat-1", "chat-2"]
    assert agent_api.instances["chat-1"] is not agent_api.instances["chat-2"]
    assert agent_api.instances["chat-1"].calls == ["hello"]
    assert agent_api.instances["chat-2"].calls == ["world"]


@pytest.mark.asyncio
async def test_real_platform_client_serializes_same_chat_streams_across_send_paths():
    """A send_message on a chat must wait for an in-flight stream_message on that chat."""
    agent_api = FakeAgentApiFactory()
    blocking_api = BlockingChatAgentApi("chat-1")
    agent_api.instances["chat-1"] = blocking_api

    def for_chat(chat_id: str) -> BlockingChatAgentApi:
        # Get-or-create: only build a new fake on a miss (dict.setdefault would
        # construct a throwaway instance on every call).
        chat_api = agent_api.instances.get(chat_id)
        if chat_api is None:
            chat_api = agent_api.instances[chat_id] = BlockingChatAgentApi(chat_id)
        return chat_api

    agent_api.for_chat = for_chat
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    async def consume_stream():
        chunks = []
        async for chunk in client.stream_message("@alice:example.org", "chat-1", "hello"):
            chunks.append(chunk)
        return chunks

    stream_task = asyncio.create_task(consume_stream())
    await asyncio.wait_for(blocking_api.started.wait(), timeout=1)
    send_task = asyncio.create_task(client.send_message("@alice:example.org", "chat-1", "again"))

    # Give the competing send several scheduler turns: if it were not serialized
    # behind the in-flight stream it would reach the fake and record "again".
    for _ in range(5):
        await asyncio.sleep(0)

    assert blocking_api.calls == ["hello"]
    assert blocking_api.max_active_calls == 1

    blocking_api.release.set()
    stream_chunks = await stream_task
    send_result = await send_task

    assert [chunk.delta for chunk in stream_chunks] == ["hello", ""]
    assert send_result.response == "again"
    assert blocking_api.calls == ["hello", "again"]
    assert blocking_api.max_active_calls == 1


@pytest.mark.asyncio
async def test_real_platform_client_stream_message_emits_final_tokens_chunk():
    agent_api = FakeAgentApiFactory()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    chunks = []
    async for chunk in client.stream_message("@alice:example.org", "chat-1", "hello"):
        chunks.append(chunk)

    # Content chunks carry tokens_used=0; a trailing empty chunk reports the total.
    assert chunks == [
        MessageChunk(
            message_id="@alice:example.org",
            delta="he",
            finished=False,
            tokens_used=0,
        ),
        MessageChunk(
            message_id="@alice:example.org",
            delta="llo",
            finished=False,
            tokens_used=0,
        ),
        MessageChunk(
            message_id="@alice:example.org",
            delta="",
            finished=True,
            tokens_used=3,
        ),
    ]
    assert agent_api.created_chat_ids == ["chat-1"]
    assert agent_api.instances["chat-1"].calls == ["hello"]


@pytest.mark.asyncio
async def test_real_platform_client_settings_are_local():
    client = RealPlatformClient(
        agent_api=FakeAgentApiFactory(),
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    await client.update_settings(
        "usr-matrix-u1",
        SettingsAction(action="toggle_skill", payload={"skill": "browser", "enabled": True}),
    )
    settings = await client.get_settings("usr-matrix-u1")

    assert isinstance(settings, UserSettings)
    assert settings.skills["browser"] is True
    # A skill that was never toggled is still reported as enabled.
    assert settings.skills["web-search"] is True