# surfaces/tests/platform/test_real.py
#
# 532 lines
# 16 KiB
# Python

import asyncio
import pytest
import sdk.agent_api_wrapper as agent_api_wrapper_module
from core.protocol import SettingsAction
from sdk.agent_api_wrapper import AgentApiWrapper
from sdk.interface import Attachment, MessageChunk, MessageResponse, PlatformError, UserSettings
from sdk.prototype_state import PrototypeStateStore
from sdk.real import RealPlatformClient
class FakeChunk:
    """Minimal stand-in for a streamed agent chunk; it carries only ``text``."""

    def __init__(self, text: str) -> None:
        # The fakes in this module wrap each yielded piece in one of these.
        self.text = text
class FakeChatAgentApi:
    """In-memory chat API double that echoes each prompt back in two chunks.

    Every prompt is recorded in ``calls`` and ``connect``/``close`` only bump
    counters, so tests can assert on how the client drives the API.
    """

    def __init__(self, chat_id: str) -> None:
        self.chat_id = chat_id
        self.calls: list[str] = []
        self.connect_calls = 0
        self.close_calls = 0
        self.last_tokens_used = 0

    async def connect(self) -> None:
        """Count a connect request; nothing is actually opened."""
        self.connect_calls += 1

    async def close(self) -> None:
        """Count a close request; nothing is actually torn down."""
        self.close_calls += 1

    async def send_message(self, text: str):
        """Yield ``text`` split at its midpoint, then record 3 tokens used."""
        self.calls.append(text)
        split_at = len(text) // 2
        for piece in (text[:split_at], text[split_at:]):
            yield FakeChunk(piece)
        # Only reached once the consumer has drained both chunks.
        self.last_tokens_used = 3
class FakeAgentApiFactory:
def __init__(self) -> None:
self.created_chat_ids: list[str] = []
self.instances: dict[str, FakeChatAgentApi] = {}
def for_chat(self, chat_id: str) -> FakeChatAgentApi:
chat_api = FakeChatAgentApi(chat_id)
self.created_chat_ids.append(chat_id)
self.instances[chat_id] = chat_api
return chat_api
class LegacyAgentApi:
    """Agent API double without ``for_chat``/``connect``/``close``.

    It only exposes ``send_message``, which echoes the prompt in two chunks
    split after the second character.
    """

    def __init__(self) -> None:
        self.calls: list[str] = []
        self.last_tokens_used = 0

    async def send_message(self, text: str):
        """Yield ``text[:2]`` and ``text[2:]``, then record 7 tokens used."""
        self.calls.append(text)
        head, tail = text[:2], text[2:]
        yield FakeChunk(head)
        yield FakeChunk(tail)
        # Only reached once the consumer has drained both chunks.
        self.last_tokens_used = 7
class BlockingChatAgentApi:
def __init__(self, chat_id: str) -> None:
self.chat_id = chat_id
self.calls: list[str] = []
self.connect_calls = 0
self.close_calls = 0
self.last_tokens_used = 0
self.active_calls = 0
self.max_active_calls = 0
self.started = asyncio.Event()
self.release = asyncio.Event()
async def connect(self) -> None:
self.connect_calls += 1
async def close(self) -> None:
self.close_calls += 1
async def send_message(self, text: str):
self.calls.append(text)
self.active_calls += 1
self.max_active_calls = max(self.max_active_calls, self.active_calls)
self.started.set()
await self.release.wait()
self.active_calls -= 1
yield FakeChunk(text)
self.last_tokens_used = len(text)
class AttachmentTrackingChatAgentApi:
def __init__(self, chat_id: str) -> None:
self.chat_id = chat_id
self.calls: list[tuple[str, list[str] | None]] = []
self.connect_calls = 0
self.close_calls = 0
self.last_tokens_used = 0
async def connect(self) -> None:
self.connect_calls += 1
async def close(self) -> None:
self.close_calls += 1
async def send_message(self, text: str, attachments: list[str] | None = None):
self.calls.append((text, attachments))
yield FakeChunk(text)
self.last_tokens_used = 5
class FlakyChatAgentApi:
def __init__(self, chat_id: str) -> None:
self.chat_id = chat_id
self.connect_calls = 0
self.close_calls = 0
async def connect(self) -> None:
self.connect_calls += 1
async def close(self) -> None:
self.close_calls += 1
async def send_message(self, text: str, attachments: list[str] | None = None):
raise ConnectionError("Connection closed")
yield
class SendFileEvent:
    """Fake agent event describing a file; ``type`` is ``"AGENT_EVENT_SEND_FILE"``."""

    def __init__(self, *, workspace_path: str, mime_type: str, filename: str, size: int) -> None:
        # File description supplied by the test.
        self.workspace_path = workspace_path
        self.mime_type = mime_type
        self.filename = filename
        self.size = size
        # Fixed tag identifying this event as a send-file event.
        self.type = "AGENT_EVENT_SEND_FILE"
class TextChunkEvent:
    """Fake agent event carrying text; ``type`` is ``"AGENT_EVENT_TEXT_CHUNK"``."""

    def __init__(self, text: str) -> None:
        self.text = text
        # Fixed tag identifying this event as a text chunk.
        self.type = "AGENT_EVENT_TEXT_CHUNK"
class MessageResponseWithAttachments(MessageResponse):
    """``MessageResponse`` variant that additionally exposes ``attachments``.

    One test in this module patches it over ``sdk.real.MessageResponse`` so
    the returned result can be checked for attachments.
    """

    # NOTE(review): a mutable ``[]`` default is only safe if the base class
    # copies field defaults per instance (as pydantic models do). Confirm
    # against the definition of ``MessageResponse``.
    attachments: list[Attachment] = []
def test_agent_api_wrapper_normalizes_base_url_and_uses_modern_constructor(monkeypatch):
    """The wrapper passes a path-less base URL and the chat id to ``AgentApi.__init__``."""
    captured = {}

    def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
        # Record what the wrapper forwards to the underlying constructor.
        captured.update(agent_id=agent_id, base_url=base_url, chat_id=chat_id)

    monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)

    wrapper = AgentApiWrapper(
        agent_id="agent-1",
        base_url="ws://platform-agent:8000/v1/agent_ws/",
        chat_id="41",
    )

    normalized_url = "ws://platform-agent:8000"
    assert wrapper.chat_id == "41"
    assert wrapper._base_url == normalized_url
    assert captured == {
        "agent_id": "agent-1",
        "base_url": normalized_url,
        "chat_id": "41",
    }
def test_agent_api_wrapper_for_chat_reuses_normalized_base_url(monkeypatch):
    """``for_chat`` builds a new wrapper that keeps the already-normalized base URL."""
    constructor_log = []

    def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
        # Mirror the attributes this fake sets, then log the call.
        self.id = agent_id
        self.chat_id = chat_id
        self.url = base_url
        constructor_log.append((agent_id, base_url, chat_id))

    monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)

    root = AgentApiWrapper(
        agent_id="agent-1",
        base_url="http://platform-agent:8000/v1/agent_ws/",
        chat_id="1",
    )
    child = root.for_chat("99")

    normalized_url = "http://platform-agent:8000"
    assert child is not root
    assert child.chat_id == "99"
    assert child._base_url == normalized_url
    assert constructor_log == [
        ("agent-1", normalized_url, "1"),
        ("agent-1", normalized_url, "99"),
    ]
@pytest.mark.asyncio
async def test_real_platform_client_get_or_create_user_uses_local_state():
    """A second lookup of the same user returns the stored record, not a new one."""
    client = RealPlatformClient(
        agent_api=FakeAgentApiFactory(),
        prototype_state=PrototypeStateStore(),
    )

    created = await client.get_or_create_user("u1", "matrix", "Alice")
    fetched = await client.get_or_create_user("u1", "matrix")

    # First call creates the user.
    assert created.user_id == "usr-matrix-u1"
    assert created.is_new is True
    # Second call omits the display name, yet the original one is returned.
    assert fetched.user_id == created.user_id
    assert fetched.is_new is False
    assert fetched.display_name == "Alice"
@pytest.mark.asyncio
async def test_real_platform_client_send_message_uses_chat_bound_client():
    """``send_message`` goes through a ``for_chat`` client and records tokens used."""
    factory = FakeAgentApiFactory()
    state = PrototypeStateStore()
    client = RealPlatformClient(
        agent_api=factory,
        prototype_state=state,
        platform="matrix",
    )

    result = await client.send_message("@alice:example.org", "chat-7", "hello")

    expected = MessageResponse(
        message_id="@alice:example.org",
        response="hello",
        tokens_used=3,
        finished=True,
    )
    assert result == expected
    assert factory.created_chat_ids == ["chat-7"]
    chat_api = factory.instances["chat-7"]
    assert chat_api.chat_id == "chat-7"
    assert chat_api.calls == ["hello"]
    assert chat_api.connect_calls == 1
    assert await state.get_last_tokens_used_for_context("chat-7") == 3
@pytest.mark.asyncio
async def test_real_platform_client_forwards_attachments_to_chat_api():
    """Attachments reach the chat API as a list of their workspace paths."""
    chat_api = AttachmentTrackingChatAgentApi("chat-7")
    client = RealPlatformClient(
        agent_api=chat_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )
    report = Attachment(
        workspace_path="surfaces/matrix/alice/room/inbox/report.pdf",
        mime_type="application/pdf",
        filename="report.pdf",
        size=123,
    )

    result = await client.send_message(
        "@alice:example.org",
        "chat-7",
        "hello",
        attachments=[report],
    )

    assert chat_api.calls == [("hello", ["surfaces/matrix/alice/room/inbox/report.pdf"])]
    assert result.response == "hello"
    assert result.tokens_used == 5
@pytest.mark.asyncio
async def test_real_platform_client_preserves_send_file_events_in_sync_result(monkeypatch):
    """A send-file event emitted between text chunks shows up in ``result.attachments``."""

    class FileEventAgentApi(AttachmentTrackingChatAgentApi):
        """Streams ``"he"``, one send-file event, then ``"llo"``."""

        async def send_message(self, text: str, attachments: list[str] | None = None):
            self.calls.append((text, attachments))
            yield TextChunkEvent("he")
            yield SendFileEvent(
                workspace_path="/workspace/report.pdf",
                mime_type="application/pdf",
                filename="report.pdf",
                size=123,
            )
            yield TextChunkEvent("llo")
            self.last_tokens_used = 9

    client = RealPlatformClient(
        agent_api=AttachmentTrackingChatAgentApi("chat-7"),
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )
    # Swap in a response type with an ``attachments`` field.
    monkeypatch.setattr("sdk.real.MessageResponse", MessageResponseWithAttachments)
    # Replace the API after construction with the file-emitting variant.
    client._agent_api = FileEventAgentApi("chat-7")

    result = await client.send_message("@alice:example.org", "chat-7", "hello")

    expected_attachment = Attachment(
        url="/workspace/report.pdf",
        mime_type="application/pdf",
        filename="report.pdf",
        size=123,
        workspace_path="report.pdf",
    )
    assert result.response == "hello"
    assert result.tokens_used == 9
    assert result.attachments == [expected_attachment]
@pytest.mark.asyncio
async def test_real_platform_client_works_with_legacy_agent_api_without_for_chat():
    """An agent API lacking ``for_chat`` is still usable for ``send_message``."""
    legacy = LegacyAgentApi()
    client = RealPlatformClient(
        agent_api=legacy,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    result = await client.send_message("@alice:example.org", "chat-legacy", "hello")

    expected = MessageResponse(
        message_id="@alice:example.org",
        response="hello",
        tokens_used=7,
        finished=True,
    )
    assert result == expected
    assert legacy.calls == ["hello"]
@pytest.mark.asyncio
async def test_real_platform_client_reuses_cached_chat_client():
    """Two messages to one chat share a single connected chat client."""
    factory = FakeAgentApiFactory()
    client = RealPlatformClient(
        agent_api=factory,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    for prompt in ("hello", "again"):
        await client.send_message("@alice:example.org", "chat-1", prompt)

    assert factory.created_chat_ids == ["chat-1"]
    chat_api = factory.instances["chat-1"]
    assert chat_api.calls == ["hello", "again"]
    # Connected once and never closed between the two sends.
    assert chat_api.connect_calls == 1
    assert chat_api.close_calls == 0
@pytest.mark.asyncio
async def test_real_platform_client_wraps_connection_closed_as_platform_error():
    """A ``ConnectionError`` from the chat API surfaces as a ``PlatformError``.

    The failed chat client must also be closed and dropped from the cache.
    """
    factory = FakeAgentApiFactory()
    factory.instances["chat-1"] = FlakyChatAgentApi("chat-1")

    def for_chat(chat_id):
        # Hand back the pre-registered flaky API (or register a new one).
        return factory.instances.setdefault(chat_id, FlakyChatAgentApi(chat_id))

    factory.for_chat = for_chat
    client = RealPlatformClient(
        agent_api=factory,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    with pytest.raises(PlatformError, match="Connection closed") as exc_info:
        await client.send_message("@alice:example.org", "chat-1", "hello")

    assert exc_info.value.code == "PLATFORM_CONNECTION_ERROR"
    assert "chat-1" not in client._chat_apis
    assert factory.instances["chat-1"].close_calls == 1
@pytest.mark.asyncio
async def test_real_platform_client_reconnects_after_closed_chat_api():
    """After a connection failure, the next send requests a fresh chat client."""
    factory = FakeAgentApiFactory()
    flaky = FlakyChatAgentApi("chat-1")
    healthy = AttachmentTrackingChatAgentApi("chat-1")
    # First request gets the failing API, the second gets the working one.
    queue = iter([flaky, healthy])

    def for_chat(chat_id: str):
        next_api = next(queue)
        factory.created_chat_ids.append(chat_id)
        factory.instances[chat_id] = next_api
        return next_api

    factory.for_chat = for_chat
    client = RealPlatformClient(
        agent_api=factory,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    with pytest.raises(PlatformError, match="Connection closed"):
        await client.send_message("@alice:example.org", "chat-1", "hello")
    retry = await client.send_message("@alice:example.org", "chat-1", "again")

    assert retry.response == "again"
    assert factory.created_chat_ids == ["chat-1", "chat-1"]
    assert healthy.calls == [("again", None)]
@pytest.mark.asyncio
async def test_real_platform_client_creates_chat_client_atomically_for_concurrent_requests():
    """Concurrent sends to one chat create and connect only one chat client."""
    factory = FakeAgentApiFactory()
    client = RealPlatformClient(
        agent_api=factory,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    first, second = await asyncio.gather(
        client.send_message("@alice:example.org", "chat-1", "hello"),
        client.send_message("@alice:example.org", "chat-1", "again"),
    )

    assert [first.response, second.response] == ["hello", "again"]
    assert factory.created_chat_ids == ["chat-1"]
    chat_api = factory.instances["chat-1"]
    assert chat_api.connect_calls == 1
    assert chat_api.calls == ["hello", "again"]
@pytest.mark.asyncio
async def test_real_platform_client_creates_distinct_clients_per_chat():
    """Messages to different chats are routed through separate chat clients."""
    factory = FakeAgentApiFactory()
    client = RealPlatformClient(
        agent_api=factory,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    await client.send_message("@alice:example.org", "chat-1", "hello")
    await client.send_message("@alice:example.org", "chat-2", "world")

    assert factory.created_chat_ids == ["chat-1", "chat-2"]
    first_api = factory.instances["chat-1"]
    second_api = factory.instances["chat-2"]
    assert first_api is not second_api
    assert first_api.calls == ["hello"]
    assert second_api.calls == ["world"]
@pytest.mark.asyncio
async def test_real_platform_client_serializes_same_chat_streams_across_send_paths():
    """``stream_message`` and ``send_message`` on one chat never overlap.

    The streaming call is parked inside the chat API while a second
    ``send_message`` is started; the second call must not reach the API
    until the first one is released.
    """
    factory = FakeAgentApiFactory()
    blocking = BlockingChatAgentApi("chat-1")
    factory.instances["chat-1"] = blocking
    factory.for_chat = lambda chat_id: factory.instances.setdefault(
        chat_id, BlockingChatAgentApi(chat_id)
    )
    client = RealPlatformClient(
        agent_api=factory,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    async def consume_stream():
        return [
            chunk
            async for chunk in client.stream_message("@alice:example.org", "chat-1", "hello")
        ]

    stream_task = asyncio.create_task(consume_stream())
    await asyncio.wait_for(blocking.started.wait(), timeout=1)
    send_task = asyncio.create_task(client.send_message("@alice:example.org", "chat-1", "again"))
    # Yield to the loop once so the second send gets a chance to start.
    await asyncio.sleep(0)

    # While the stream is parked, only the first prompt has reached the API.
    assert blocking.calls == ["hello"]
    assert blocking.max_active_calls == 1

    blocking.release.set()
    stream_chunks = await stream_task
    send_result = await send_task

    assert [chunk.delta for chunk in stream_chunks] == ["hello", ""]
    assert send_result.response == "again"
    assert blocking.calls == ["hello", "again"]
    assert blocking.max_active_calls == 1
@pytest.mark.asyncio
async def test_real_platform_client_stream_message_emits_final_tokens_chunk():
    """Streaming yields each text delta, then an empty finished chunk with the token count."""
    factory = FakeAgentApiFactory()
    client = RealPlatformClient(
        agent_api=factory,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    received = [
        chunk
        async for chunk in client.stream_message("@alice:example.org", "chat-1", "hello")
    ]

    # (delta, finished, tokens_used) for every expected chunk, in order.
    expected_fields = [
        ("he", False, 0),
        ("llo", False, 0),
        ("", True, 3),
    ]
    assert received == [
        MessageChunk(
            message_id="@alice:example.org",
            delta=delta,
            finished=finished,
            tokens_used=tokens,
        )
        for delta, finished, tokens in expected_fields
    ]
    assert factory.created_chat_ids == ["chat-1"]
    assert factory.instances["chat-1"].calls == ["hello"]
@pytest.mark.asyncio
async def test_real_platform_client_settings_are_local():
    """Toggling a skill via ``update_settings`` is visible in ``get_settings``."""
    client = RealPlatformClient(
        agent_api=FakeAgentApiFactory(),
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )
    user_id = "usr-matrix-u1"
    enable_browser = SettingsAction(
        action="toggle_skill",
        payload={"skill": "browser", "enabled": True},
    )

    await client.update_settings(user_id, enable_browser)
    settings = await client.get_settings(user_id)

    assert isinstance(settings, UserSettings)
    assert settings.skills["browser"] is True
    # "web-search" was never toggled in this test, yet it reads as enabled.
    assert settings.skills["web-search"] is True