789 lines
25 KiB
Python
789 lines
25 KiB
Python
import asyncio
|
|
|
|
import pytest
|
|
from lambda_agent_api.server import MsgEventEnd, MsgEventTextChunk
|
|
|
|
import sdk.agent_api_wrapper as agent_api_wrapper_module
|
|
from core.protocol import SettingsAction
|
|
from sdk.agent_api_wrapper import AgentApiWrapper
|
|
from sdk.interface import Attachment, MessageChunk, MessageResponse, PlatformError, UserSettings
|
|
from sdk.prototype_state import PrototypeStateStore
|
|
from sdk.real import RealPlatformClient
|
|
|
|
|
|
class FakeChunk:
    """Minimal stand-in for a streamed agent chunk; only exposes ``text``."""

    def __init__(self, text: str) -> None:
        # Payload is stored verbatim; consumers only ever read ``.text``.
        self.text = text
|
|
|
|
|
|
class FakeChatAgentApi:
    """Chat-bound fake agent API that echoes each message back in two halves.

    Every sent text is recorded in ``calls``; ``connect``/``close`` invocations
    are counted so tests can assert on the connection lifecycle.
    """

    def __init__(self, chat_id: str) -> None:
        self.chat_id = chat_id
        self.calls: list[str] = []
        self.connect_calls = 0
        self.close_calls = 0
        # Only updated once a send_message stream has been fully consumed.
        self.last_tokens_used = 0

    async def connect(self) -> None:
        self.connect_calls = self.connect_calls + 1

    async def close(self) -> None:
        self.close_calls = self.close_calls + 1

    async def send_message(self, text: str):
        """Record *text*, yield it as two ``FakeChunk`` halves, then report 3 tokens."""
        self.calls.append(text)
        half = len(text) // 2
        for piece in (text[:half], text[half:]):
            yield FakeChunk(piece)
        self.last_tokens_used = 3
|
|
|
|
|
|
class FakeAgentApiFactory:
    """Factory fake whose ``for_chat`` builds and remembers a FakeChatAgentApi per call."""

    def __init__(self) -> None:
        # Every chat id passed to for_chat, in call order (duplicates kept).
        self.created_chat_ids: list[str] = []
        # Most recently created chat API per chat id.
        self.instances: dict[str, FakeChatAgentApi] = {}

    def for_chat(self, chat_id: str) -> FakeChatAgentApi:
        """Create a fresh chat API for *chat_id*, register it, and log the id.

        A repeated call for the same id replaces the stored instance.
        """
        self.instances[chat_id] = chat_api = FakeChatAgentApi(chat_id)
        self.created_chat_ids.append(chat_id)
        return chat_api
|
|
|
|
|
|
class LegacyAgentApi:
    """Legacy-style agent API fake with no ``for_chat``, ``connect`` or ``close``.

    Echoes each message as a two-character head chunk plus the remainder and
    reports 7 tokens once the stream is drained.
    """

    def __init__(self) -> None:
        self.calls: list[str] = []
        self.last_tokens_used = 0

    async def send_message(self, text: str):
        self.calls.append(text)
        head, tail = text[:2], text[2:]
        yield FakeChunk(head)
        yield FakeChunk(tail)
        self.last_tokens_used = 7
|
|
|
|
|
|
class BlockingChatAgentApi:
    """Chat API fake whose stream parks until ``release`` is set.

    Counts how many ``send_message`` streams are open at once
    (``active_calls``) and the peak (``max_active_calls``), so tests can prove
    that requests on one chat are serialized.
    """

    def __init__(self, chat_id: str) -> None:
        self.chat_id = chat_id
        self.calls: list[str] = []
        self.connect_calls = 0
        self.close_calls = 0
        self.last_tokens_used = 0
        self.active_calls = 0
        self.max_active_calls = 0
        # Set as soon as a stream has started and is waiting on ``release``.
        self.started = asyncio.Event()
        # Tests set this to let parked streams continue.
        self.release = asyncio.Event()

    async def connect(self) -> None:
        self.connect_calls += 1

    async def close(self) -> None:
        self.close_calls += 1

    async def send_message(self, text: str):
        """Record *text*, wait for ``release``, then yield *text* as one chunk.

        A stream counts as active until the generator finishes or is closed,
        not just until ``release`` fires. An overlapping stream is therefore
        detected for the whole stream lifetime.
        """
        self.calls.append(text)
        self.active_calls += 1
        self.max_active_calls = max(self.max_active_calls, self.active_calls)
        try:
            self.started.set()
            await self.release.wait()
            yield FakeChunk(text)
            self.last_tokens_used = len(text)
        finally:
            # Runs on exhaustion, aclose() and cancellation alike, so the
            # counter cannot leak if a waiting stream is torn down.
            self.active_calls -= 1
|
|
|
|
|
|
class AttachmentTrackingChatAgentApi:
    """Chat API fake that records ``(text, attachments)`` pairs and echoes the text."""

    def __init__(self, chat_id: str) -> None:
        self.chat_id = chat_id
        # One entry per send_message call: (text, attachment paths or None).
        self.calls: list[tuple[str, list[str] | None]] = []
        self.connect_calls = 0
        self.close_calls = 0
        self.last_tokens_used = 0

    async def connect(self) -> None:
        self.connect_calls = self.connect_calls + 1

    async def close(self) -> None:
        self.close_calls = self.close_calls + 1

    async def send_message(self, text: str, attachments: list[str] | None = None):
        """Record the call, yield *text* as a single chunk, then report 5 tokens."""
        self.calls.append((text, attachments))
        yield FakeChunk(text)
        self.last_tokens_used = 5
|
|
|
|
|
|
class FlakyChatAgentApi:
|
|
def __init__(self, chat_id: str) -> None:
|
|
self.chat_id = chat_id
|
|
self.connect_calls = 0
|
|
self.close_calls = 0
|
|
|
|
async def connect(self) -> None:
|
|
self.connect_calls += 1
|
|
|
|
async def close(self) -> None:
|
|
self.close_calls += 1
|
|
|
|
async def send_message(self, text: str, attachments: list[str] | None = None):
|
|
raise ConnectionError("Connection closed")
|
|
yield
|
|
|
|
|
|
class SendFileEvent:
    """Duck-typed fake for an ``AGENT_EVENT_SEND_FILE`` server event."""

    def __init__(self, *, workspace_path: str, mime_type: str, filename: str, size: int) -> None:
        self.type = "AGENT_EVENT_SEND_FILE"
        # File metadata, stored exactly as given.
        self.workspace_path = workspace_path
        self.mime_type = mime_type
        self.filename = filename
        self.size = size
|
|
|
|
|
|
class TextChunkEvent:
    """Duck-typed fake for an ``AGENT_EVENT_TEXT_CHUNK`` server event."""

    def __init__(self, text: str) -> None:
        self.type = "AGENT_EVENT_TEXT_CHUNK"
        self.text = text  # streamed text fragment
|
|
|
|
|
|
class ToolCallChunkEvent:
    """Duck-typed fake for an ``AGENT_EVENT_TOOL_CALL_CHUNK`` server event."""

    def __init__(self, payload: str) -> None:
        self.type = "AGENT_EVENT_TOOL_CALL_CHUNK"
        self.payload = payload  # opaque marker the tests compare against
|
|
|
|
|
|
class ToolResultEvent:
    """Duck-typed fake for an ``AGENT_EVENT_TOOL_RESULT`` server event."""

    def __init__(self, payload: str) -> None:
        self.type = "AGENT_EVENT_TOOL_RESULT"
        self.payload = payload  # opaque marker the tests compare against
|
|
|
|
|
|
class CustomUpdateEvent:
    """Duck-typed fake for an ``AGENT_EVENT_CUSTOM_UPDATE`` server event."""

    def __init__(self, payload: str) -> None:
        self.type = "AGENT_EVENT_CUSTOM_UPDATE"
        self.payload = payload  # opaque marker the tests compare against
|
|
|
|
|
|
class EndEvent:
    """Duck-typed fake for an ``AGENT_EVENT_END`` server event."""

    def __init__(self, tokens_used: int) -> None:
        self.type = "AGENT_EVENT_END"
        self.tokens_used = tokens_used  # token count reported when the turn ends
|
|
|
|
|
|
class ErrorEvent:
    """Duck-typed fake for an ``ERROR`` server event carrying a code and details."""

    def __init__(self, code: str, details: str) -> None:
        self.type = "ERROR"
        self.code = code
        self.details = details
|
|
|
|
|
|
class GracefulDisconnectEvent:
    """Duck-typed fake for a ``GRACEFUL_DISCONNECT`` server event (no payload)."""

    def __init__(self) -> None:
        self.type = "GRACEFUL_DISCONNECT"
|
|
|
|
|
|
class FakeWSMessage:
    """Websocket frame stand-in: always a TEXT frame carrying ``data``."""

    def __init__(self, data: str) -> None:
        # The frame type comes from aiohttp, reached through the wrapper
        # module's own ``aiohttp`` import.
        self.type = agent_api_wrapper_module.aiohttp.WSMsgType.TEXT
        self.data = data
|
|
|
|
|
|
class FakeWebSocket:
    """Async-iterable websocket fake that replays a fixed list of frames once."""

    def __init__(self, messages: list[FakeWSMessage]) -> None:
        # Snapshot the caller's list so later mutation cannot affect the replay.
        # An iterator consumes frames in O(1) each, whereas ``list.pop(0)``
        # shifts the whole list on every frame.
        self._messages = iter(list(messages))

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next queued frame; raise StopAsyncIteration when exhausted."""
        try:
            return next(self._messages)
        except StopIteration:
            raise StopAsyncIteration from None
|
|
|
|
|
|
class QueueFeedingWebSocket:
    """Websocket fake that answers every send by pushing canned events to its owner.

    Each ``send_str`` records the payload and then puts every event in
    ``queued_events`` onto ``owner._current_queue``.
    """

    def __init__(self, owner, queued_events: list[object]) -> None:
        self.owner = owner
        self.queued_events = list(queued_events)  # private copy of the script
        self.sent_payloads: list[str] = []

    async def send_str(self, payload: str) -> None:
        self.sent_payloads.append(payload)
        for scripted in self.queued_events:
            # The queue is looked up on every put, matching the owner's
            # current attribute at that moment.
            await self.owner._current_queue.put(scripted)
|
|
|
|
|
|
class SilentWebSocket:
    """Websocket fake that records outgoing payloads and never feeds any events back."""

    def __init__(self) -> None:
        self.sent_payloads: list[str] = []

    async def send_str(self, payload: str) -> None:
        self.sent_payloads.append(payload)
|
|
|
|
|
|
class MessageResponseWithAttachments(MessageResponse):
    """MessageResponse subclass that adds an ``attachments`` list field.

    Monkeypatched in as ``sdk.real.MessageResponse`` so that a synchronous
    send can return attachments collected from send-file events.
    """

    # NOTE(review): a literal ``[]`` default is only safe if MessageResponse
    # copies field defaults per instance (e.g. a pydantic model) — confirm.
    attachments: list[Attachment] = []
|
|
|
|
|
|
def test_agent_api_wrapper_uses_modern_constructor_when_available(monkeypatch):
    """AgentApiWrapper uses the ``(agent_id, base_url, chat_id, **kwargs)`` constructor.

    The ``/v1/agent_ws`` path is trimmed from ``base_url``, and ``for_chat``
    builds a wrapper for another chat that receives the same callbacks.
    """
    recorded: list[dict[str, object]] = []

    def fake_init(self, agent_id, base_url, chat_id, **kwargs):
        recorded.append(
            {
                "agent_id": agent_id,
                "base_url": base_url,
                "chat_id": chat_id,
                "kwargs": kwargs,
            }
        )
        self.id = agent_id
        self.url = base_url
        self.callback = kwargs.get("callback")
        self.on_disconnect = kwargs.get("on_disconnect")

    monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)

    wrapper = AgentApiWrapper(
        agent_id="agent-1",
        base_url="https://agent.example.com/v1/agent_ws",
        chat_id="chat-1",
        callback="cb",
        on_disconnect="disconnect",
    )
    child = wrapper.for_chat("chat-2")

    # Both constructions must see the same trimmed URL and callbacks.
    shared_kwargs = {"callback": "cb", "on_disconnect": "disconnect"}
    assert recorded == [
        {
            "agent_id": "agent-1",
            "base_url": "https://agent.example.com",
            "chat_id": expected_chat,
            "kwargs": shared_kwargs,
        }
        for expected_chat in ("chat-1", "chat-2")
    ]
    assert wrapper._base_url == "https://agent.example.com"
    assert wrapper.chat_id == "chat-1"
    assert wrapper.last_tokens_used == 0
    assert child.chat_id == "chat-2"
|
|
|
|
|
|
def test_agent_api_wrapper_falls_back_to_legacy_url_constructor(monkeypatch):
    """With a legacy ``(agent_id, url, callback, on_disconnect)`` AgentApi, the chat id goes in the URL.

    The wrapper passes ``url?thread_id=<chat_id>`` to the legacy constructor
    and still exposes the bare origin as ``_base_url``.
    """
    seen: list[dict[str, object]] = []

    def fake_init(self, agent_id, url, callback=None, on_disconnect=None):
        seen.append(
            dict(
                agent_id=agent_id,
                url=url,
                callback=callback,
                on_disconnect=on_disconnect,
            )
        )
        self.id, self.url = agent_id, url
        self.callback, self.on_disconnect = callback, on_disconnect

    monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)

    wrapper = AgentApiWrapper(
        agent_id="agent-2",
        url="https://agent.example.com/agent_ws/",
        chat_id="chat-9",
        callback="cb",
    )

    expected_call = {
        "agent_id": "agent-2",
        "url": "https://agent.example.com/agent_ws/?thread_id=chat-9",
        "callback": "cb",
        "on_disconnect": None,
    }
    assert seen == [expected_call]
    assert wrapper._base_url == "https://agent.example.com"
    assert wrapper.chat_id == "chat-9"
    assert wrapper.last_tokens_used == 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_agent_api_wrapper_recovers_late_text_after_first_end(monkeypatch):
    """Text arriving after an early END is still streamed.

    The scripted socket replies with a text chunk, END, another text chunk and
    END again. Both text fragments must be yielded, and ``last_tokens_used``
    must reflect the END events' token count.
    """

    def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
        self.id = agent_id
        self.url = base_url
        self.callback = kwargs.get("callback")
        self.on_disconnect = kwargs.get("on_disconnect")

    monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)

    wrapper = AgentApiWrapper(
        agent_id="agent-1",
        base_url="https://agent.example.com/v1/agent_ws",
        chat_id="chat-1",
    )
    # Put the wrapper into a connected state backed by a scripted socket.
    wrapper._connected = True
    wrapper._request_lock = asyncio.Lock()
    wrapper._current_queue = None
    scripted_replies = [
        MsgEventTextChunk(text="Иллюстра"),
        MsgEventEnd(tokens_used=5),
        MsgEventTextChunk(text="ция"),
        MsgEventEnd(tokens_used=5),
    ]
    wrapper._ws = QueueFeedingWebSocket(wrapper, scripted_replies)

    received = [chunk async for chunk in wrapper.send_message("hello")]

    assert [piece.text for piece in received] == ["Иллюстра", "ция"]
    assert wrapper.last_tokens_used == 5
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_agent_api_wrapper_times_out_on_idle_stream(monkeypatch):
    """A request that never receives any event fails with a timeout AgentException."""

    def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
        self.id = agent_id
        self.url = base_url
        self.callback = kwargs.get("callback")
        self.on_disconnect = kwargs.get("on_disconnect")

    monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)
    # Shrink the idle timeout so the test fails fast instead of hanging.
    monkeypatch.setattr(agent_api_wrapper_module, "_STREAM_IDLE_TIMEOUT_MS", 10)

    wrapper = AgentApiWrapper(
        agent_id="agent-1",
        base_url="https://agent.example.com/v1/agent_ws",
        chat_id="chat-1",
    )
    wrapper._connected = True
    wrapper._request_lock = asyncio.Lock()
    wrapper._current_queue = None
    wrapper._ws = SilentWebSocket()  # accepts sends, never replies

    with pytest.raises(agent_api_wrapper_module.AgentException, match="Timed out waiting"):
        async for _chunk in wrapper.send_message("hello"):
            pass
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_real_platform_client_get_or_create_user_uses_local_state():
    """A user is created once in prototype state and returned unchanged on later lookups."""
    client = RealPlatformClient(
        agent_api=FakeAgentApiFactory(),
        prototype_state=PrototypeStateStore(),
    )

    created = await client.get_or_create_user("u1", "matrix", "Alice")
    fetched = await client.get_or_create_user("u1", "matrix")

    assert created.user_id == "usr-matrix-u1"
    assert created.is_new is True
    # The second lookup omits the display name; the stored one must survive.
    assert fetched.user_id == created.user_id
    assert fetched.is_new is False
    assert fetched.display_name == "Alice"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_real_platform_client_send_message_uses_chat_bound_client():
    """send_message routes through a per-chat API created via ``for_chat`` and records tokens."""
    agent_api = FakeAgentApiFactory()
    prototype_state = PrototypeStateStore()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=prototype_state,
        platform="matrix",
    )

    result = await client.send_message("@alice:example.org", "chat-7", "hello")

    expected = MessageResponse(
        message_id="@alice:example.org",
        response="hello",
        tokens_used=3,
        finished=True,
    )
    assert result == expected
    assert agent_api.created_chat_ids == ["chat-7"]
    bound = agent_api.instances["chat-7"]
    assert bound.chat_id == "chat-7"
    assert bound.calls == ["hello"]
    assert bound.connect_calls == 1
    assert await prototype_state.get_last_tokens_used_for_context("chat-7") == 3
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_real_platform_client_forwards_attachments_to_chat_api():
    """Attachments are forwarded to the chat API as their workspace paths."""
    chat_api = AttachmentTrackingChatAgentApi("chat-7")
    client = RealPlatformClient(
        agent_api=chat_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )
    report = Attachment(
        workspace_path="surfaces/matrix/alice/room/inbox/report.pdf",
        mime_type="application/pdf",
        filename="report.pdf",
        size=123,
    )

    result = await client.send_message(
        "@alice:example.org",
        "chat-7",
        "hello",
        attachments=[report],
    )

    expected_calls = [("hello", ["surfaces/matrix/alice/room/inbox/report.pdf"])]
    assert chat_api.calls == expected_calls
    assert result.response == "hello"
    assert result.tokens_used == 5
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_real_platform_client_preserves_send_file_events_in_sync_result(monkeypatch):
    """Send-file events in a stream end up as attachments on the sync MessageResponse."""

    class FileEventAgentApi(AttachmentTrackingChatAgentApi):
        """Streams text around a single send-file event, then reports 9 tokens."""

        async def send_message(self, text: str, attachments: list[str] | None = None):
            self.calls.append((text, attachments))
            yield TextChunkEvent("he")
            yield SendFileEvent(
                workspace_path="/workspace/report.pdf",
                mime_type="application/pdf",
                filename="report.pdf",
                size=123,
            )
            yield TextChunkEvent("llo")
            self.last_tokens_used = 9

    client = RealPlatformClient(
        agent_api=AttachmentTrackingChatAgentApi("chat-7"),
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )
    # Give the response model an ``attachments`` field for this test only.
    monkeypatch.setattr(
        "sdk.real.MessageResponse",
        MessageResponseWithAttachments,
    )
    client._agent_api = FileEventAgentApi("chat-7")

    result = await client.send_message("@alice:example.org", "chat-7", "hello")

    expected_attachment = Attachment(
        url="/workspace/report.pdf",
        mime_type="application/pdf",
        filename="report.pdf",
        size=123,
        workspace_path="report.pdf",
    )
    assert result.response == "hello"
    assert result.tokens_used == 9
    assert result.attachments == [expected_attachment]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_real_platform_client_works_with_legacy_agent_api_without_for_chat():
    """An agent API lacking ``for_chat``/``connect`` is still used directly for sends."""
    legacy_api = LegacyAgentApi()
    client = RealPlatformClient(
        agent_api=legacy_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    result = await client.send_message("@alice:example.org", "chat-legacy", "hello")

    expected = MessageResponse(
        message_id="@alice:example.org",
        response="hello",
        tokens_used=7,
        finished=True,
    )
    assert result == expected
    assert legacy_api.calls == ["hello"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_real_platform_client_reuses_cached_chat_client():
    """Repeated sends to one chat reuse a single connected chat API without closing it."""
    agent_api = FakeAgentApiFactory()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    for text in ("hello", "again"):
        await client.send_message("@alice:example.org", "chat-1", text)

    assert agent_api.created_chat_ids == ["chat-1"]
    cached = agent_api.instances["chat-1"]
    assert cached.calls == ["hello", "again"]
    assert cached.connect_calls == 1
    assert cached.close_calls == 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_real_platform_client_wraps_connection_closed_as_platform_error():
    """A ConnectionError from the chat API becomes a PlatformError and evicts the client."""
    agent_api = FakeAgentApiFactory()
    agent_api.instances["chat-1"] = FlakyChatAgentApi("chat-1")

    def flaky_for_chat(chat_id):
        # The candidate is built eagerly; for a pre-seeded chat id the
        # existing instance is returned instead.
        candidate = FlakyChatAgentApi(chat_id)
        return agent_api.instances.setdefault(chat_id, candidate)

    agent_api.for_chat = flaky_for_chat
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    with pytest.raises(PlatformError, match="Connection closed") as exc_info:
        await client.send_message("@alice:example.org", "chat-1", "hello")

    assert exc_info.value.code == "PLATFORM_CONNECTION_ERROR"
    assert "chat-1" not in client._chat_apis
    assert agent_api.instances["chat-1"].close_calls == 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_real_platform_client_reconnects_after_closed_chat_api():
    """After a connection failure, the next send builds and uses a fresh chat API.

    The first ``for_chat`` call yields a flaky API that raises, and the second
    yields a healthy one. The broken API must be closed, and the replacement
    connected and used for the retry.
    """
    agent_api = FakeAgentApiFactory()
    flaky = FlakyChatAgentApi("chat-1")
    healthy = AttachmentTrackingChatAgentApi("chat-1")
    provided = iter([flaky, healthy])

    def for_chat(chat_id: str):
        chat_api = next(provided)
        agent_api.created_chat_ids.append(chat_id)
        agent_api.instances[chat_id] = chat_api
        return chat_api

    agent_api.for_chat = for_chat
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    with pytest.raises(PlatformError, match="Connection closed"):
        await client.send_message("@alice:example.org", "chat-1", "hello")

    result = await client.send_message("@alice:example.org", "chat-1", "again")

    assert result.response == "again"
    assert agent_api.created_chat_ids == ["chat-1", "chat-1"]
    assert healthy.calls == [("again", None)]
    # Lifecycle checks mirroring the other tests: the failed client is closed
    # exactly once, and its replacement is connected exactly once.
    assert flaky.close_calls == 1
    assert healthy.connect_calls == 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_real_platform_client_creates_chat_client_atomically_for_concurrent_requests():
    """Two concurrent sends to a new chat share one chat API, created and connected once."""
    agent_api = FakeAgentApiFactory()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    replies = await asyncio.gather(
        *(client.send_message("@alice:example.org", "chat-1", text) for text in ("hello", "again"))
    )

    assert [reply.response for reply in replies] == ["hello", "again"]
    assert agent_api.created_chat_ids == ["chat-1"]
    shared = agent_api.instances["chat-1"]
    assert shared.connect_calls == 1
    assert shared.calls == ["hello", "again"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_real_platform_client_creates_distinct_clients_per_chat():
    """Each chat id gets its own chat API instance."""
    agent_api = FakeAgentApiFactory()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    for chat_id, text in (("chat-1", "hello"), ("chat-2", "world")):
        await client.send_message("@alice:example.org", chat_id, text)

    first, second = agent_api.instances["chat-1"], agent_api.instances["chat-2"]
    assert agent_api.created_chat_ids == ["chat-1", "chat-2"]
    assert first is not second
    assert first.calls == ["hello"]
    assert second.calls == ["world"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_real_platform_client_serializes_same_chat_streams_across_send_paths():
    """stream_message and send_message on one chat never overlap on the chat API.

    The first stream parks inside BlockingChatAgentApi until ``release`` is
    set. A concurrent send_message for the same chat must not reach the chat
    API until that stream has finished.
    """
    agent_api = FakeAgentApiFactory()
    agent_api.instances["chat-1"] = BlockingChatAgentApi("chat-1")
    agent_api.for_chat = lambda chat_id: agent_api.instances.setdefault(
        chat_id, BlockingChatAgentApi(chat_id)
    )
    blocking_api = agent_api.instances["chat-1"]
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )

    async def consume_stream():
        return [
            chunk
            async for chunk in client.stream_message("@alice:example.org", "chat-1", "hello")
        ]

    stream_task = asyncio.create_task(consume_stream())
    await asyncio.wait_for(blocking_api.started.wait(), timeout=1)

    send_task = asyncio.create_task(client.send_message("@alice:example.org", "chat-1", "again"))
    # A single sleep(0) may not give send_task enough turns to get past the
    # client's own awaits. Yield repeatedly so that a missing per-chat lock
    # would actually show up in the assertions below.
    for _ in range(10):
        await asyncio.sleep(0)

    assert blocking_api.calls == ["hello"]
    assert blocking_api.max_active_calls == 1

    blocking_api.release.set()
    stream_chunks = await stream_task
    send_result = await send_task

    assert [chunk.delta for chunk in stream_chunks] == ["hello", ""]
    assert send_result.response == "again"
    assert blocking_api.calls == ["hello", "again"]
    assert blocking_api.max_active_calls == 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_real_platform_client_stream_message_emits_final_tokens_chunk():
    """stream_message yields each text delta, then a final empty chunk carrying tokens."""
    agent_api = FakeAgentApiFactory()
    client = RealPlatformClient(
        agent_api=agent_api,
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )
    sender = "@alice:example.org"

    chunks = [chunk async for chunk in client.stream_message(sender, "chat-1", "hello")]

    expected = [
        MessageChunk(
            message_id=sender,
            delta=delta,
            finished=finished,
            tokens_used=tokens,
        )
        for delta, finished, tokens in (
            ("he", False, 0),
            ("llo", False, 0),
            ("", True, 3),
        )
    ]
    assert chunks == expected
    assert agent_api.created_chat_ids == ["chat-1"]
    assert agent_api.instances["chat-1"].calls == ["hello"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_real_platform_client_settings_are_local():
    """Settings updates are applied to and read back from local prototype state."""
    client = RealPlatformClient(
        agent_api=FakeAgentApiFactory(),
        prototype_state=PrototypeStateStore(),
        platform="matrix",
    )
    user_id = "usr-matrix-u1"
    enable_browser = SettingsAction(
        action="toggle_skill",
        payload={"skill": "browser", "enabled": True},
    )

    await client.update_settings(user_id, enable_browser)
    settings = await client.get_settings(user_id)

    assert isinstance(settings, UserSettings)
    assert settings.skills["browser"] is True
    # Untouched skill; expected to remain enabled.
    assert settings.skills["web-search"] is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_agent_api_wrapper_transparently_surfaces_modern_events(monkeypatch):
    """_listen routes each modern server event to the request queue and/or the callback.

    Every frame is decoded through a patched ServerMessage. The test checks
    which fake events land in the queue versus the callback, and that the END
    event's token count is recorded.
    """
    callback_events: list[object] = []
    queue: asyncio.Queue = asyncio.Queue()
    event_map = {
        "text": TextChunkEvent("he"),
        "tool_call": ToolCallChunkEvent("call"),
        "tool_result": ToolResultEvent("result"),
        "custom_update": CustomUpdateEvent("update"),
        "send_file": SendFileEvent(
            workspace_path="/workspace/report.pdf",
            mime_type="application/pdf",
            filename="report.pdf",
            size=123,
        ),
        "end": EndEvent(tokens_used=11),
        "error": ErrorEvent(code="BOOM", details="bad things"),
        "disconnect": GracefulDisconnectEvent(),
    }

    def fake_validate_json(data: str):
        # Frames carry a lookup key into event_map rather than real JSON.
        return event_map[data]

    monkeypatch.setattr(
        agent_api_wrapper_module,
        "ServerMessage",
        type("FakeServerMessage", (), {"validate_json": staticmethod(fake_validate_json)}),
    )

    async def fake_cleanup(self):
        return None

    monkeypatch.setattr(agent_api_wrapper_module.AgentApiWrapper, "_cleanup", fake_cleanup)

    def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
        # Plain attribute setup only, replacing the previous
        # ``setattr(...) or setattr(...)`` lambda chain.
        self.id = agent_id
        self.callback = kwargs.get("callback")
        self.on_disconnect = kwargs.get("on_disconnect")
        self._current_queue = None

    monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)

    wrapper = AgentApiWrapper(
        agent_id="agent-1",
        base_url="https://agent.example.com/v1/agent_ws",
        chat_id="chat-1",
        callback=callback_events.append,
    )
    wrapper._current_queue = queue
    # One frame per event, in the same order as event_map.
    wrapper._ws = FakeWebSocket([FakeWSMessage(key) for key in event_map])

    await wrapper._listen()

    queue_events = [queue.get_nowait() for _ in range(queue.qsize())]

    assert queue_events[0].text == "he"
    assert any(isinstance(event, SendFileEvent) for event in queue_events)
    assert any(isinstance(event, EndEvent) for event in queue_events)
    assert any(isinstance(event, GracefulDisconnectEvent) for event in queue_events)
    assert callback_events[0].payload == "call"
    assert callback_events[1].payload == "result"
    assert callback_events[2].payload == "update"
    assert any(isinstance(event, SendFileEvent) for event in callback_events)
    assert any(isinstance(event, EndEvent) for event in callback_events)
    assert any(isinstance(event, ErrorEvent) for event in callback_events)
    assert any(isinstance(event, GracefulDisconnectEvent) for event in callback_events)
    assert wrapper.last_tokens_used == 11
|