refactor: use thin upstream transport adapter
This commit is contained in:
parent
569824ead1
commit
0c2884c2b1
8 changed files with 420 additions and 255 deletions
|
|
@ -1,6 +1,8 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
from lambda_agent_api.server import MsgEventSendFile, MsgEventTextChunk
|
||||
from pydantic import Field
|
||||
|
||||
import sdk.agent_api_wrapper as agent_api_wrapper_module
|
||||
from core.protocol import SettingsAction
|
||||
|
|
@ -10,18 +12,12 @@ from sdk.prototype_state import PrototypeStateStore
|
|||
from sdk.real import RealPlatformClient
|
||||
|
||||
|
||||
class FakeChunk:
|
||||
def __init__(self, text: str) -> None:
|
||||
self.text = text
|
||||
|
||||
|
||||
class FakeChatAgentApi:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
self.calls: list[str] = []
|
||||
self.connect_calls = 0
|
||||
self.close_calls = 0
|
||||
self.last_tokens_used = 0
|
||||
|
||||
async def connect(self) -> None:
|
||||
self.connect_calls += 1
|
||||
|
|
@ -29,12 +25,11 @@ class FakeChatAgentApi:
|
|||
async def close(self) -> None:
|
||||
self.close_calls += 1
|
||||
|
||||
async def send_message(self, text: str):
|
||||
async def send_message(self, text: str, attachments: list[str] | None = None):
|
||||
self.calls.append(text)
|
||||
midpoint = len(text) // 2
|
||||
yield FakeChunk(text[:midpoint])
|
||||
yield FakeChunk(text[midpoint:])
|
||||
self.last_tokens_used = 3
|
||||
yield MsgEventTextChunk(text=text[:midpoint])
|
||||
yield MsgEventTextChunk(text=text[midpoint:])
|
||||
|
||||
|
||||
class FakeAgentApiFactory:
|
||||
|
|
@ -49,25 +44,12 @@ class FakeAgentApiFactory:
|
|||
return chat_api
|
||||
|
||||
|
||||
class LegacyAgentApi:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[str] = []
|
||||
self.last_tokens_used = 0
|
||||
|
||||
async def send_message(self, text: str):
|
||||
self.calls.append(text)
|
||||
yield FakeChunk(text[:2])
|
||||
yield FakeChunk(text[2:])
|
||||
self.last_tokens_used = 7
|
||||
|
||||
|
||||
class BlockingChatAgentApi:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
self.calls: list[str] = []
|
||||
self.connect_calls = 0
|
||||
self.close_calls = 0
|
||||
self.last_tokens_used = 0
|
||||
self.active_calls = 0
|
||||
self.max_active_calls = 0
|
||||
self.started = asyncio.Event()
|
||||
|
|
@ -79,15 +61,14 @@ class BlockingChatAgentApi:
|
|||
async def close(self) -> None:
|
||||
self.close_calls += 1
|
||||
|
||||
async def send_message(self, text: str):
|
||||
async def send_message(self, text: str, attachments: list[str] | None = None):
|
||||
self.calls.append(text)
|
||||
self.active_calls += 1
|
||||
self.max_active_calls = max(self.max_active_calls, self.active_calls)
|
||||
self.started.set()
|
||||
await self.release.wait()
|
||||
self.active_calls -= 1
|
||||
yield FakeChunk(text)
|
||||
self.last_tokens_used = len(text)
|
||||
yield MsgEventTextChunk(text=text)
|
||||
|
||||
|
||||
class AttachmentTrackingChatAgentApi:
|
||||
|
|
@ -96,7 +77,6 @@ class AttachmentTrackingChatAgentApi:
|
|||
self.calls: list[tuple[str, list[str] | None]] = []
|
||||
self.connect_calls = 0
|
||||
self.close_calls = 0
|
||||
self.last_tokens_used = 0
|
||||
|
||||
async def connect(self) -> None:
|
||||
self.connect_calls += 1
|
||||
|
|
@ -106,8 +86,20 @@ class AttachmentTrackingChatAgentApi:
|
|||
|
||||
async def send_message(self, text: str, attachments: list[str] | None = None):
|
||||
self.calls.append((text, attachments))
|
||||
yield FakeChunk(text)
|
||||
self.last_tokens_used = 5
|
||||
yield MsgEventTextChunk(text=text)
|
||||
|
||||
|
||||
class AttachmentTrackingAgentApiFactory:
|
||||
def __init__(self, chat_api_cls=AttachmentTrackingChatAgentApi) -> None:
|
||||
self.chat_api_cls = chat_api_cls
|
||||
self.created_chat_ids: list[str] = []
|
||||
self.instances: dict[str, AttachmentTrackingChatAgentApi] = {}
|
||||
|
||||
def for_chat(self, chat_id: str) -> AttachmentTrackingChatAgentApi:
|
||||
chat_api = self.chat_api_cls(chat_id)
|
||||
self.created_chat_ids.append(chat_id)
|
||||
self.instances[chat_id] = chat_api
|
||||
return chat_api
|
||||
|
||||
|
||||
class FlakyChatAgentApi:
|
||||
|
|
@ -127,22 +119,8 @@ class FlakyChatAgentApi:
|
|||
yield
|
||||
|
||||
|
||||
class SendFileEvent:
|
||||
def __init__(self, *, workspace_path: str, mime_type: str, filename: str, size: int) -> None:
|
||||
self.type = "AGENT_EVENT_SEND_FILE"
|
||||
self.workspace_path = workspace_path
|
||||
self.mime_type = mime_type
|
||||
self.filename = filename
|
||||
self.size = size
|
||||
|
||||
|
||||
class TextChunkEvent:
|
||||
def __init__(self, text: str) -> None:
|
||||
self.type = "AGENT_EVENT_TEXT_CHUNK"
|
||||
self.text = text
|
||||
|
||||
class MessageResponseWithAttachments(MessageResponse):
|
||||
attachments: list[Attachment] = []
|
||||
attachments: list[Attachment] = Field(default_factory=list)
|
||||
|
||||
|
||||
def test_agent_api_wrapper_normalizes_base_url_and_uses_modern_constructor(monkeypatch):
|
||||
|
|
@ -230,19 +208,19 @@ async def test_real_platform_client_send_message_uses_chat_bound_client():
|
|||
assert result == MessageResponse(
|
||||
message_id="@alice:example.org",
|
||||
response="hello",
|
||||
tokens_used=3,
|
||||
tokens_used=0,
|
||||
finished=True,
|
||||
)
|
||||
assert agent_api.created_chat_ids == ["chat-7"]
|
||||
assert agent_api.instances["chat-7"].chat_id == "chat-7"
|
||||
assert agent_api.instances["chat-7"].calls == ["hello"]
|
||||
assert agent_api.instances["chat-7"].connect_calls == 1
|
||||
assert await prototype_state.get_last_tokens_used_for_context("chat-7") == 3
|
||||
assert await prototype_state.get_last_tokens_used_for_context("chat-7") == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_forwards_attachments_to_chat_api():
|
||||
agent_api = AttachmentTrackingChatAgentApi("chat-7")
|
||||
agent_api = AttachmentTrackingAgentApiFactory()
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
|
|
@ -262,74 +240,49 @@ async def test_real_platform_client_forwards_attachments_to_chat_api():
|
|||
attachments=[attachment],
|
||||
)
|
||||
|
||||
assert agent_api.calls == [("hello", ["surfaces/matrix/alice/room/inbox/report.pdf"])]
|
||||
assert agent_api.instances["chat-7"].calls == [
|
||||
("hello", ["surfaces/matrix/alice/room/inbox/report.pdf"])
|
||||
]
|
||||
assert result.response == "hello"
|
||||
assert result.tokens_used == 5
|
||||
assert result.tokens_used == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_preserves_send_file_events_in_sync_result(monkeypatch):
|
||||
agent_api = AttachmentTrackingChatAgentApi("chat-7")
|
||||
class FileEventAgentApi(AttachmentTrackingChatAgentApi):
|
||||
async def send_message(self, text: str, attachments: list[str] | None = None):
|
||||
self.calls.append((text, attachments))
|
||||
yield MsgEventTextChunk(text="he")
|
||||
yield MsgEventSendFile(path="report.pdf")
|
||||
yield MsgEventTextChunk(text="llo")
|
||||
|
||||
agent_api = AttachmentTrackingAgentApiFactory(chat_api_cls=FileEventAgentApi)
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
|
||||
class FileEventAgentApi(AttachmentTrackingChatAgentApi):
|
||||
async def send_message(self, text: str, attachments: list[str] | None = None):
|
||||
self.calls.append((text, attachments))
|
||||
yield TextChunkEvent("he")
|
||||
yield SendFileEvent(
|
||||
workspace_path="/workspace/report.pdf",
|
||||
mime_type="application/pdf",
|
||||
filename="report.pdf",
|
||||
size=123,
|
||||
)
|
||||
yield TextChunkEvent("llo")
|
||||
self.last_tokens_used = 9
|
||||
|
||||
monkeypatch.setattr(
|
||||
"sdk.real.MessageResponse",
|
||||
MessageResponseWithAttachments,
|
||||
)
|
||||
client._agent_api = FileEventAgentApi("chat-7")
|
||||
|
||||
result = await client.send_message("@alice:example.org", "chat-7", "hello")
|
||||
|
||||
assert result.response == "hello"
|
||||
assert result.tokens_used == 9
|
||||
assert result.tokens_used == 0
|
||||
assert result.attachments == [
|
||||
Attachment(
|
||||
url="/workspace/report.pdf",
|
||||
mime_type="application/pdf",
|
||||
url="report.pdf",
|
||||
mime_type="application/octet-stream",
|
||||
filename="report.pdf",
|
||||
size=123,
|
||||
size=None,
|
||||
workspace_path="report.pdf",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_works_with_legacy_agent_api_without_for_chat():
|
||||
legacy_api = LegacyAgentApi()
|
||||
client = RealPlatformClient(
|
||||
agent_api=legacy_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
|
||||
result = await client.send_message("@alice:example.org", "chat-legacy", "hello")
|
||||
|
||||
assert result == MessageResponse(
|
||||
message_id="@alice:example.org",
|
||||
response="hello",
|
||||
tokens_used=7,
|
||||
finished=True,
|
||||
)
|
||||
assert legacy_api.calls == ["hello"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_reuses_cached_chat_client():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
|
|
@ -505,7 +458,7 @@ async def test_real_platform_client_stream_message_emits_final_tokens_chunk():
|
|||
message_id="@alice:example.org",
|
||||
delta="",
|
||||
finished=True,
|
||||
tokens_used=3,
|
||||
tokens_used=0,
|
||||
),
|
||||
]
|
||||
assert agent_api.created_chat_ids == ["chat-1"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue