surfaces/tests/platform/test_real.py

409 lines
14 KiB
Python

import asyncio
import pytest
from pydantic import Field
from core.protocol import SettingsAction
from sdk.interface import Attachment, MessageChunk, MessageResponse, PlatformError, UserSettings
from sdk.prototype_state import PrototypeStateStore
from sdk.real import RealPlatformClient
from sdk.upstream_agent_api import MsgEventSendFile, MsgEventTextChunk
class FakeChatAgentApi:
def __init__(self, agent_id: str, base_url: str, chat_id: str) -> None:
self.agent_id = agent_id
self.base_url = base_url
self.chat_id = str(chat_id)
self.calls: list[str] = []
self.connect_calls = 0
self.close_calls = 0
async def connect(self) -> None:
self.connect_calls += 1
async def close(self) -> None:
self.close_calls += 1
async def send_message(self, text: str, attachments: list[str] | None = None):
self.calls.append(text)
midpoint = len(text) // 2
yield MsgEventTextChunk(text=text[:midpoint])
yield MsgEventTextChunk(text=text[midpoint:])
class FakeAgentApiFactory:
    """Callable factory that builds fake chat APIs and remembers every one it made.

    ``created_calls`` keeps the ``(agent_id, base_url, chat_id)`` arguments of
    each creation in order; ``instances_by_chat`` groups the created objects
    by their string chat key.
    """

    def __init__(self, chat_api_cls=FakeChatAgentApi) -> None:
        self.chat_api_cls = chat_api_cls
        self.created_calls: list[tuple[str, str, str]] = []
        self.instances_by_chat: dict[str, list[FakeChatAgentApi]] = {}

    def __call__(self, agent_id: str, base_url: str, chat_id: str):
        key = str(chat_id)
        instance = self.chat_api_cls(agent_id, base_url, key)
        self.created_calls.append((agent_id, base_url, key))
        bucket = self.instances_by_chat.setdefault(key, [])
        bucket.append(instance)
        return instance

    def latest(self, chat_id: str):
        """Return the most recently created chat API for ``chat_id``."""
        history = self.instances_by_chat[str(chat_id)]
        return history[-1]
class BlockingTracker:
    """Shared state used to coordinate and observe blocking fake sends.

    ``started`` and ``release`` are asyncio events used as gates.
    ``active_calls`` and ``max_active_calls`` are plain integer counters,
    both starting at zero.
    """

    def __init__(self) -> None:
        self.started = asyncio.Event()
        self.release = asyncio.Event()
        self.active_calls = self.max_active_calls = 0
class BlockingChatAgentApi(FakeChatAgentApi):
    """Fake chat API whose ``send_message`` blocks until the tracker releases it.

    Instances that share one ``BlockingTracker`` let a test observe how many
    sends are waiting at the same time.
    """

    def __init__(
        self,
        agent_id: str,
        base_url: str,
        chat_id: str,
        *,
        tracker: BlockingTracker,
    ) -> None:
        super().__init__(agent_id, base_url, chat_id)
        self._tracker = tracker

    async def send_message(self, text: str, attachments: list[str] | None = None):
        """Record ``text``, wait for ``tracker.release``, then echo ``text`` as one chunk.

        While waiting, the send is counted in ``tracker.active_calls``, and the
        peak is kept in ``tracker.max_active_calls``. ``tracker.started`` is
        set as soon as any send begins waiting.
        """
        self.calls.append(text)
        tracker = self._tracker
        tracker.active_calls += 1
        tracker.max_active_calls = max(tracker.max_active_calls, tracker.active_calls)
        tracker.started.set()
        try:
            await tracker.release.wait()
        finally:
            # Decrement even if the waiting task is cancelled (e.g. a test
            # times out) so the in-flight count never stays inflated.
            tracker.active_calls -= 1
        yield MsgEventTextChunk(text=text)
class BlockingAgentApiFactory(FakeAgentApiFactory):
    """Factory whose chat APIs are ``BlockingChatAgentApi`` instances sharing ``self.tracker``.

    Creation bookkeeping (``created_calls`` and ``instances_by_chat``) is
    inherited from ``FakeAgentApiFactory.__call__`` rather than duplicated.
    Only the constructor used for each chat API is customised.
    """

    def __init__(self) -> None:
        def build_chat_api(agent_id: str, base_url: str, chat_id: str) -> BlockingChatAgentApi:
            # self.tracker is read at call time, so every chat API uses the
            # factory's current tracker.
            return BlockingChatAgentApi(agent_id, base_url, chat_id, tracker=self.tracker)

        super().__init__(chat_api_cls=build_chat_api)
        self.tracker = BlockingTracker()
class AttachmentTrackingChatAgentApi(FakeChatAgentApi):
    """Fake chat API that records ``(text, attachments)`` pairs and echoes text as a single chunk."""

    def __init__(self, agent_id: str, base_url: str, chat_id: str) -> None:
        super().__init__(agent_id, base_url, chat_id)
        # Replaces the parent's list[str] so attachment paths are kept with each text.
        self.calls: list[tuple[str, list[str] | None]] = []

    async def send_message(self, text: str, attachments: list[str] | None = None):
        record = (text, attachments)
        self.calls.append(record)
        yield MsgEventTextChunk(text=text)
class FlakyChatAgentApi(FakeChatAgentApi):
    """Fake chat API whose send always fails with ``ConnectionError("Connection closed")``."""

    async def send_message(self, text: str, attachments: list[str] | None = None):
        # The unreachable ``yield`` makes this an async generator, so the error
        # surfaces when the caller starts iterating, not when it calls the method.
        if False:
            yield
        raise ConnectionError("Connection closed")
class ReuseSensitiveChatAgentApi(FakeChatAgentApi):
    """Fake chat API whose reply to ``"second"`` depends on whether the instance is fresh.

    ``"first"`` always yields ``"tool ok"``. ``"second"`` yields ``"Missing"``
    only when it is the first message sent on this instance. A reused
    instance yields nothing for it, so reusing a connection would produce an
    empty reply. Any other text yields nothing.
    """

    def __init__(self, agent_id: str, base_url: str, chat_id: str) -> None:
        super().__init__(agent_id, base_url, chat_id)
        self._send_calls = 0

    async def send_message(self, text: str, attachments: list[str] | None = None):
        self.calls.append(text)
        self._send_calls += 1
        first_use = self._send_calls == 1
        if text == "first":
            yield MsgEventTextChunk(text="tool ok")
        elif text == "second" and first_use:
            yield MsgEventTextChunk(text="Missing")
class MessageResponseWithAttachments(MessageResponse):
    """``MessageResponse`` extended with an ``attachments`` list.

    The send-file test monkeypatches ``sdk.real.MessageResponse`` with this
    class so the synchronous ``send_message`` result can carry attachments.
    """
    # Each instance gets its own empty list unless attachments are supplied.
    attachments: list[Attachment] = Field(default_factory=list)
def make_real_platform_client(
    agent_api_cls,
    *,
    prototype_state: PrototypeStateStore | None = None,
) -> RealPlatformClient:
    """Build a ``RealPlatformClient`` wired to a fake agent API for tests.

    Args:
        agent_api_cls: Factory for per-chat agent APIs. The fakes in this
            module accept ``(agent_id, base_url, chat_id)``.
        prototype_state: State store to share with the test. A fresh
            ``PrototypeStateStore`` is created only when this is omitted.

    Returns:
        A client with agent id ``"matrix-bot"``, base URL
        ``"http://platform-agent:8000"`` and platform ``"matrix"``.
    """
    # Check ``is None`` rather than using ``or``: a store the caller passed in
    # must be used even if it happens to be falsy (e.g. an empty container).
    if prototype_state is None:
        prototype_state = PrototypeStateStore()
    return RealPlatformClient(
        agent_id="matrix-bot",
        agent_base_url="http://platform-agent:8000",
        agent_api_cls=agent_api_cls,
        prototype_state=prototype_state,
        platform="matrix",
    )
@pytest.mark.asyncio
async def test_real_platform_client_get_or_create_user_uses_local_state():
    """A user is created once in local state and reused on later lookups."""
    client = make_real_platform_client(FakeAgentApiFactory())

    created = await client.get_or_create_user("u1", "matrix", "Alice")
    fetched = await client.get_or_create_user("u1", "matrix")

    assert created.user_id == "usr-matrix-u1"
    assert created.is_new is True
    # The second lookup passes no display name; the stored one is kept.
    assert fetched.user_id == created.user_id
    assert fetched.is_new is False
    assert fetched.display_name == "Alice"
@pytest.mark.asyncio
async def test_real_platform_client_send_message_uses_direct_agent_api_per_chat():
    """send_message uses one per-chat agent connection, collects the reply, and closes it."""
    factory = FakeAgentApiFactory()
    state = PrototypeStateStore()
    client = make_real_platform_client(factory, prototype_state=state)

    result = await client.send_message("@alice:example.org", "chat-7", "hello")

    expected = MessageResponse(
        message_id="@alice:example.org",
        response="hello",
        tokens_used=0,
        finished=True,
    )
    assert result == expected
    assert factory.created_calls == [("matrix-bot", "http://platform-agent:8000", "chat-7")]
    chat_api = factory.latest("chat-7")
    assert chat_api.chat_id == "chat-7"
    assert chat_api.calls == ["hello"]
    assert chat_api.connect_calls == 1
    assert chat_api.close_calls == 1
    assert await state.get_last_tokens_used_for_context("chat-7") == 0
@pytest.mark.asyncio
async def test_real_platform_client_forwards_attachments_to_chat_api():
    """Attachments reach the chat API as their workspace paths."""
    factory = FakeAgentApiFactory(chat_api_cls=AttachmentTrackingChatAgentApi)
    client = make_real_platform_client(factory)
    inbox_path = "surfaces/matrix/alice/room/inbox/report.pdf"
    report = Attachment(
        workspace_path=inbox_path,
        mime_type="application/pdf",
        filename="report.pdf",
        size=123,
    )

    result = await client.send_message(
        "@alice:example.org",
        "chat-7",
        "hello",
        attachments=[report],
    )

    assert factory.latest("chat-7").calls == [("hello", [inbox_path])]
    assert result.response == "hello"
    assert result.tokens_used == 0
@pytest.mark.asyncio
async def test_real_platform_client_preserves_send_file_events_in_sync_result(monkeypatch):
    """Send-file events interleaved with text appear as attachments on send_message's result."""

    class SendFileAgentApi(AttachmentTrackingChatAgentApi):
        async def send_message(self, text: str, attachments: list[str] | None = None):
            self.calls.append((text, attachments))
            # The text reply is split around a file event.
            yield MsgEventTextChunk(text="he")
            yield MsgEventSendFile(path="report.pdf")
            yield MsgEventTextChunk(text="llo")

    factory = FakeAgentApiFactory(chat_api_cls=SendFileAgentApi)
    client = make_real_platform_client(factory)
    # Swap in a response model that has an ``attachments`` field.
    monkeypatch.setattr("sdk.real.MessageResponse", MessageResponseWithAttachments)

    result = await client.send_message("@alice:example.org", "chat-7", "hello")

    assert result.response == "hello"
    assert result.tokens_used == 0
    expected_attachment = Attachment(
        url="report.pdf",
        mime_type="application/octet-stream",
        filename="report.pdf",
        size=None,
        workspace_path="report.pdf",
    )
    assert result.attachments == [expected_attachment]
@pytest.mark.asyncio
async def test_real_platform_client_uses_fresh_agent_connection_per_request():
    """Each message on the same chat opens, uses, and closes its own agent connection."""
    factory = FakeAgentApiFactory()
    client = make_real_platform_client(factory)

    for text in ("hello", "again"):
        await client.send_message("@alice:example.org", "chat-1", text)

    creation = ("matrix-bot", "http://platform-agent:8000", "chat-1")
    assert factory.created_calls == [creation, creation]
    instances = factory.instances_by_chat["chat-1"]
    assert [api.calls for api in instances] == [["hello"], ["again"]]
    assert all(api.connect_calls == 1 for api in instances)
    assert all(api.close_calls == 1 for api in instances)
@pytest.mark.asyncio
async def test_real_platform_client_avoids_reuse_sensitive_second_message_loss():
    """A second message on the same chat gets a fresh connection, so its reply is not lost."""
    factory = FakeAgentApiFactory(chat_api_cls=ReuseSensitiveChatAgentApi)
    client = make_real_platform_client(factory)

    responses = []
    for text in ("first", "second"):
        reply = await client.send_message("@alice:example.org", "chat-1", text)
        responses.append(reply.response)

    assert responses == ["tool ok", "Missing"]
    assert len(factory.instances_by_chat["chat-1"]) == 2
@pytest.mark.asyncio
async def test_real_platform_client_wraps_connection_closed_as_platform_error():
    """A ConnectionError from the agent surfaces as a PlatformError, and the connection is still closed."""
    factory = FakeAgentApiFactory(chat_api_cls=FlakyChatAgentApi)
    client = make_real_platform_client(factory)

    with pytest.raises(PlatformError, match="Connection closed") as excinfo:
        await client.send_message("@alice:example.org", "chat-1", "hello")

    error = excinfo.value
    assert error.code == "PLATFORM_CONNECTION_ERROR"
    assert factory.latest("chat-1").close_calls == 1
@pytest.mark.asyncio
async def test_real_platform_client_uses_fresh_connection_after_failure():
    """After a failed send, the next message on the same chat uses a new connection."""

    class FailsOnHelloAgentApi(FakeChatAgentApi):
        async def send_message(self, text: str, attachments: list[str] | None = None):
            if text == "hello":
                raise ConnectionError("Connection closed")
            self.calls.append(text)
            yield MsgEventTextChunk(text=text)

    factory = FakeAgentApiFactory(chat_api_cls=FailsOnHelloAgentApi)
    client = make_real_platform_client(factory)

    with pytest.raises(PlatformError, match="Connection closed"):
        await client.send_message("@alice:example.org", "chat-1", "hello")
    retry = await client.send_message("@alice:example.org", "chat-1", "again")

    assert retry.response == "again"
    creation = ("matrix-bot", "http://platform-agent:8000", "chat-1")
    assert factory.created_calls == [creation, creation]
    assert factory.latest("chat-1").calls == ["again"]
@pytest.mark.asyncio
async def test_real_platform_client_serializes_same_chat_streams_across_send_paths():
    """A send_message on a chat waits for an in-flight stream_message on the same chat.

    The streamed request blocks inside the fake agent until
    ``tracker.release`` is set. While it is blocked, a concurrent
    send_message for the same chat must neither open a second agent
    connection nor finish.
    """
    agent_api = BlockingAgentApiFactory()
    client = make_real_platform_client(agent_api)
    scheduler_turns = 10

    async def consume_stream():
        chunks = []
        async for chunk in client.stream_message("@alice:example.org", "chat-1", "hello"):
            chunks.append(chunk)
        return chunks

    stream_task = asyncio.create_task(consume_stream())
    await asyncio.wait_for(agent_api.tracker.started.wait(), timeout=1)
    send_task = asyncio.create_task(client.send_message("@alice:example.org", "chat-1", "again"))
    # Give the concurrent send several event-loop turns. A single sleep(0)
    # only runs it up to its first suspension point, which may come before
    # it would open a connection, making the checks below pass vacuously.
    for _ in range(scheduler_turns):
        await asyncio.sleep(0)

    assert not send_task.done()
    assert len(agent_api.instances_by_chat["chat-1"]) == 1
    assert agent_api.instances_by_chat["chat-1"][0].calls == ["hello"]
    assert agent_api.tracker.max_active_calls == 1

    agent_api.tracker.release.set()
    stream_chunks = await stream_task
    send_result = await send_task

    assert [chunk.delta for chunk in stream_chunks] == ["hello", ""]
    assert send_result.response == "again"
    assert [instance.calls for instance in agent_api.instances_by_chat["chat-1"]] == [
        ["hello"],
        ["again"],
    ]
    assert agent_api.tracker.max_active_calls == 1
@pytest.mark.asyncio
async def test_real_platform_client_creates_distinct_connections_per_chat():
    """Messages to different chats go through separate agent connections."""
    factory = FakeAgentApiFactory()
    client = make_real_platform_client(factory)
    messages = {"chat-1": "hello", "chat-2": "world"}

    for chat, text in messages.items():
        await client.send_message("@alice:example.org", chat, text)

    assert factory.created_calls == [
        ("matrix-bot", "http://platform-agent:8000", chat) for chat in messages
    ]
    for chat, text in messages.items():
        assert factory.latest(chat).calls == [text]
@pytest.mark.asyncio
async def test_real_platform_client_stream_message_emits_final_tokens_chunk():
    """stream_message relays each agent chunk, then emits an empty finished chunk."""
    factory = FakeAgentApiFactory()
    client = make_real_platform_client(factory)

    chunks = [
        chunk async for chunk in client.stream_message("@alice:example.org", "chat-1", "hello")
    ]

    # The fake agent splits "hello" into "he" + "llo".
    expected = [
        MessageChunk(
            message_id="@alice:example.org",
            delta=delta,
            finished=finished,
            tokens_used=0,
        )
        for delta, finished in (("he", False), ("llo", False), ("", True))
    ]
    assert chunks == expected
    assert factory.created_calls == [("matrix-bot", "http://platform-agent:8000", "chat-1")]
    chat_api = factory.latest("chat-1")
    assert chat_api.calls == ["hello"]
    assert chat_api.close_calls == 1
@pytest.mark.asyncio
async def test_real_platform_client_settings_are_local():
    """Settings updates are stored and read back through the client."""
    client = make_real_platform_client(FakeAgentApiFactory())
    user_id = "usr-matrix-u1"
    enable_browser = SettingsAction(
        action="toggle_skill",
        payload={"skill": "browser", "enabled": True},
    )

    await client.update_settings(user_id, enable_browser)
    settings = await client.get_settings(user_id)

    assert isinstance(settings, UserSettings)
    skills = settings.skills
    assert skills["browser"] is True
    # "web-search" is expected to be enabled without being toggled here.
    assert skills["web-search"] is True