feat: support shared-workspace file flow for matrix
This commit is contained in:
parent
323a6d3144
commit
6422c7db58
18 changed files with 871 additions and 80 deletions
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
from core.protocol import SettingsAction
|
||||
import sdk.agent_api_wrapper as agent_api_wrapper_module
|
||||
from sdk.agent_api_wrapper import AgentApiWrapper
|
||||
from sdk.interface import MessageChunk, MessageResponse, UserSettings
|
||||
from sdk.interface import Attachment, MessageChunk, MessageResponse, UserSettings
|
||||
from sdk.prototype_state import PrototypeStateStore
|
||||
from sdk.real import RealPlatformClient
|
||||
|
||||
|
|
@ -90,6 +90,100 @@ class BlockingChatAgentApi:
|
|||
self.last_tokens_used = len(text)
|
||||
|
||||
|
||||
class AttachmentTrackingChatAgentApi:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
self.calls: list[tuple[str, list[str] | None]] = []
|
||||
self.connect_calls = 0
|
||||
self.close_calls = 0
|
||||
self.last_tokens_used = 0
|
||||
|
||||
async def connect(self) -> None:
|
||||
self.connect_calls += 1
|
||||
|
||||
async def close(self) -> None:
|
||||
self.close_calls += 1
|
||||
|
||||
async def send_message(self, text: str, attachments: list[str] | None = None):
|
||||
self.calls.append((text, attachments))
|
||||
yield FakeChunk(text)
|
||||
self.last_tokens_used = 5
|
||||
|
||||
|
||||
class SendFileEvent:
|
||||
def __init__(self, *, workspace_path: str, mime_type: str, filename: str, size: int) -> None:
|
||||
self.type = "AGENT_EVENT_SEND_FILE"
|
||||
self.workspace_path = workspace_path
|
||||
self.mime_type = mime_type
|
||||
self.filename = filename
|
||||
self.size = size
|
||||
|
||||
|
||||
class TextChunkEvent:
|
||||
def __init__(self, text: str) -> None:
|
||||
self.type = "AGENT_EVENT_TEXT_CHUNK"
|
||||
self.text = text
|
||||
|
||||
|
||||
class ToolCallChunkEvent:
|
||||
def __init__(self, payload: str) -> None:
|
||||
self.type = "AGENT_EVENT_TOOL_CALL_CHUNK"
|
||||
self.payload = payload
|
||||
|
||||
|
||||
class ToolResultEvent:
|
||||
def __init__(self, payload: str) -> None:
|
||||
self.type = "AGENT_EVENT_TOOL_RESULT"
|
||||
self.payload = payload
|
||||
|
||||
|
||||
class CustomUpdateEvent:
|
||||
def __init__(self, payload: str) -> None:
|
||||
self.type = "AGENT_EVENT_CUSTOM_UPDATE"
|
||||
self.payload = payload
|
||||
|
||||
|
||||
class EndEvent:
|
||||
def __init__(self, tokens_used: int) -> None:
|
||||
self.type = "AGENT_EVENT_END"
|
||||
self.tokens_used = tokens_used
|
||||
|
||||
|
||||
class ErrorEvent:
|
||||
def __init__(self, code: str, details: str) -> None:
|
||||
self.type = "ERROR"
|
||||
self.code = code
|
||||
self.details = details
|
||||
|
||||
|
||||
class GracefulDisconnectEvent:
|
||||
def __init__(self) -> None:
|
||||
self.type = "GRACEFUL_DISCONNECT"
|
||||
|
||||
|
||||
class FakeWSMessage:
|
||||
def __init__(self, data: str) -> None:
|
||||
self.type = agent_api_wrapper_module.aiohttp.WSMsgType.TEXT
|
||||
self.data = data
|
||||
|
||||
|
||||
class FakeWebSocket:
|
||||
def __init__(self, messages: list[FakeWSMessage]) -> None:
|
||||
self._messages = list(messages)
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if not self._messages:
|
||||
raise StopAsyncIteration
|
||||
return self._messages.pop(0)
|
||||
|
||||
|
||||
class MessageResponseWithAttachments(MessageResponse):
|
||||
attachments: list[Attachment] = []
|
||||
|
||||
|
||||
def test_agent_api_wrapper_uses_modern_constructor_when_available(monkeypatch):
|
||||
calls: list[dict[str, object]] = []
|
||||
|
||||
|
|
@ -219,6 +313,76 @@ async def test_real_platform_client_send_message_uses_chat_bound_client():
|
|||
assert await prototype_state.get_last_tokens_used_for_context("chat-7") == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_forwards_attachments_to_chat_api():
|
||||
agent_api = AttachmentTrackingChatAgentApi("chat-7")
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
attachment = Attachment(
|
||||
workspace_path="surfaces/matrix/alice/room/inbox/report.pdf",
|
||||
mime_type="application/pdf",
|
||||
filename="report.pdf",
|
||||
size=123,
|
||||
)
|
||||
|
||||
result = await client.send_message(
|
||||
"@alice:example.org",
|
||||
"chat-7",
|
||||
"hello",
|
||||
attachments=[attachment],
|
||||
)
|
||||
|
||||
assert agent_api.calls == [("hello", ["surfaces/matrix/alice/room/inbox/report.pdf"])]
|
||||
assert result.response == "hello"
|
||||
assert result.tokens_used == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_preserves_send_file_events_in_sync_result(monkeypatch):
|
||||
agent_api = AttachmentTrackingChatAgentApi("chat-7")
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
|
||||
class FileEventAgentApi(AttachmentTrackingChatAgentApi):
|
||||
async def send_message(self, text: str, attachments: list[str] | None = None):
|
||||
self.calls.append((text, attachments))
|
||||
yield TextChunkEvent("he")
|
||||
yield SendFileEvent(
|
||||
workspace_path="/workspace/report.pdf",
|
||||
mime_type="application/pdf",
|
||||
filename="report.pdf",
|
||||
size=123,
|
||||
)
|
||||
yield TextChunkEvent("llo")
|
||||
self.last_tokens_used = 9
|
||||
|
||||
monkeypatch.setattr(
|
||||
"sdk.real.MessageResponse",
|
||||
MessageResponseWithAttachments,
|
||||
)
|
||||
client._agent_api = FileEventAgentApi("chat-7")
|
||||
|
||||
result = await client.send_message("@alice:example.org", "chat-7", "hello")
|
||||
|
||||
assert result.response == "hello"
|
||||
assert result.tokens_used == 9
|
||||
assert result.attachments == [
|
||||
Attachment(
|
||||
url="/workspace/report.pdf",
|
||||
mime_type="application/pdf",
|
||||
filename="report.pdf",
|
||||
size=123,
|
||||
workspace_path="report.pdf",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_works_with_legacy_agent_api_without_for_chat():
|
||||
legacy_api = LegacyAgentApi()
|
||||
|
|
@ -385,3 +549,85 @@ async def test_real_platform_client_settings_are_local():
|
|||
assert isinstance(settings, UserSettings)
|
||||
assert settings.skills["browser"] is True
|
||||
assert settings.skills["web-search"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_api_wrapper_transparently_surfaces_modern_events(monkeypatch):
|
||||
callback_events: list[object] = []
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
event_map = {
|
||||
"text": TextChunkEvent("he"),
|
||||
"tool_call": ToolCallChunkEvent("call"),
|
||||
"tool_result": ToolResultEvent("result"),
|
||||
"custom_update": CustomUpdateEvent("update"),
|
||||
"send_file": SendFileEvent(
|
||||
workspace_path="/workspace/report.pdf",
|
||||
mime_type="application/pdf",
|
||||
filename="report.pdf",
|
||||
size=123,
|
||||
),
|
||||
"end": EndEvent(tokens_used=11),
|
||||
"error": ErrorEvent(code="BOOM", details="bad things"),
|
||||
"disconnect": GracefulDisconnectEvent(),
|
||||
}
|
||||
|
||||
def fake_validate_json(data: str):
|
||||
return event_map[data]
|
||||
|
||||
monkeypatch.setattr(
|
||||
agent_api_wrapper_module,
|
||||
"ServerMessage",
|
||||
type("FakeServerMessage", (), {"validate_json": staticmethod(fake_validate_json)}),
|
||||
)
|
||||
|
||||
async def fake_cleanup(self):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(agent_api_wrapper_module.AgentApiWrapper, "_cleanup", fake_cleanup)
|
||||
monkeypatch.setattr(
|
||||
agent_api_wrapper_module.AgentApi,
|
||||
"__init__",
|
||||
lambda self, agent_id, base_url=None, chat_id=0, **kwargs: setattr(self, "id", agent_id)
|
||||
or setattr(self, "callback", kwargs.get("callback"))
|
||||
or setattr(self, "on_disconnect", kwargs.get("on_disconnect"))
|
||||
or setattr(self, "_current_queue", None),
|
||||
)
|
||||
|
||||
wrapper = AgentApiWrapper(
|
||||
agent_id="agent-1",
|
||||
base_url="https://agent.example.com/v1/agent_ws",
|
||||
chat_id="chat-1",
|
||||
callback=callback_events.append,
|
||||
)
|
||||
wrapper._current_queue = queue
|
||||
wrapper._ws = FakeWebSocket(
|
||||
[
|
||||
FakeWSMessage("text"),
|
||||
FakeWSMessage("tool_call"),
|
||||
FakeWSMessage("tool_result"),
|
||||
FakeWSMessage("custom_update"),
|
||||
FakeWSMessage("send_file"),
|
||||
FakeWSMessage("end"),
|
||||
FakeWSMessage("error"),
|
||||
FakeWSMessage("disconnect"),
|
||||
]
|
||||
)
|
||||
|
||||
await wrapper._listen()
|
||||
|
||||
queue_events = []
|
||||
while not queue.empty():
|
||||
queue_events.append(await queue.get())
|
||||
|
||||
assert queue_events[0].text == "he"
|
||||
assert any(isinstance(event, SendFileEvent) for event in queue_events)
|
||||
assert any(isinstance(event, EndEvent) for event in queue_events)
|
||||
assert any(isinstance(event, GracefulDisconnectEvent) for event in queue_events)
|
||||
assert callback_events[0].payload == "call"
|
||||
assert callback_events[1].payload == "result"
|
||||
assert callback_events[2].payload == "update"
|
||||
assert any(isinstance(event, SendFileEvent) for event in callback_events)
|
||||
assert any(isinstance(event, EndEvent) for event in callback_events)
|
||||
assert any(isinstance(event, ErrorEvent) for event in callback_events)
|
||||
assert any(isinstance(event, GracefulDisconnectEvent) for event in callback_events)
|
||||
assert wrapper.last_tokens_used == 11
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue