fix: use direct agent api per request
This commit is contained in:
parent
7d270d3d31
commit
7d58dd1caf
14 changed files with 285 additions and 400 deletions
|
|
@ -908,34 +908,21 @@ async def test_prepare_live_sync_returns_next_batch_from_bootstrap_sync():
|
|||
|
||||
|
||||
async def test_build_runtime_uses_real_platform_when_matrix_backend_is_real(monkeypatch):
|
||||
bot_module = importlib.import_module("adapter.matrix.bot")
|
||||
|
||||
class FakeAgentApiWrapper:
|
||||
def __init__(self, agent_id: str, base_url: str) -> None:
|
||||
self.agent_id = agent_id
|
||||
self.base_url = base_url
|
||||
|
||||
monkeypatch.setattr(bot_module, "AgentApiWrapper", FakeAgentApiWrapper)
|
||||
monkeypatch.setenv("MATRIX_PLATFORM_BACKEND", "real")
|
||||
monkeypatch.setenv("AGENT_WS_URL", "ws://agent.example/agent_ws/")
|
||||
monkeypatch.setenv("AGENT_BASE_URL", "http://agent.example")
|
||||
|
||||
runtime = build_runtime()
|
||||
|
||||
assert isinstance(runtime.platform, RealPlatformClient)
|
||||
assert runtime.platform.agent_api.base_url == "ws://agent.example/agent_ws/"
|
||||
assert runtime.platform.agent_base_url == "http://agent.example"
|
||||
assert runtime.platform.agent_id == "matrix-bot"
|
||||
|
||||
|
||||
async def test_matrix_main_closes_platform_without_connecting_root_agent(monkeypatch):
|
||||
bot_module = importlib.import_module("adapter.matrix.bot")
|
||||
|
||||
platform_close = AsyncMock()
|
||||
agent_connect = AsyncMock()
|
||||
runtime = SimpleNamespace(
|
||||
platform=SimpleNamespace(
|
||||
close=platform_close,
|
||||
agent_api=SimpleNamespace(connect=agent_connect),
|
||||
)
|
||||
)
|
||||
runtime = SimpleNamespace(platform=SimpleNamespace(close=platform_close))
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
|
@ -959,7 +946,6 @@ async def test_matrix_main_closes_platform_without_connecting_root_agent(monkeyp
|
|||
|
||||
await bot_module.main()
|
||||
|
||||
agent_connect.assert_not_awaited()
|
||||
platform_close.assert_awaited_once()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ Smoke test: полный цикл через dispatcher + реальные manag
|
|||
Имитирует что делает адаптер (Telegram или Matrix) при получении события.
|
||||
"""
|
||||
import pytest
|
||||
from lambda_agent_api.server import MsgEventTextChunk
|
||||
|
||||
from core.auth import AuthManager
|
||||
from core.chat import ChatManager
|
||||
|
|
@ -23,10 +22,13 @@ from core.store import InMemoryStore
|
|||
from sdk.mock import MockPlatformClient
|
||||
from sdk.prototype_state import PrototypeStateStore
|
||||
from sdk.real import RealPlatformClient
|
||||
from sdk.upstream_agent_api import MsgEventTextChunk
|
||||
|
||||
|
||||
class FakeAgentApi:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
def __init__(self, agent_id: str, base_url: str, chat_id: str) -> None:
|
||||
self.agent_id = agent_id
|
||||
self.base_url = base_url
|
||||
self.chat_id = chat_id
|
||||
self.calls: list[tuple[str, list[str]]] = []
|
||||
self.connect_calls = 0
|
||||
|
|
@ -46,12 +48,12 @@ class FakeAgentApi:
|
|||
class FakeAgentApiFactory:
|
||||
def __init__(self) -> None:
|
||||
self.created_chat_ids: list[str] = []
|
||||
self.instances: dict[str, FakeAgentApi] = {}
|
||||
self.instances: dict[str, list[FakeAgentApi]] = {}
|
||||
|
||||
def for_chat(self, chat_id: str) -> FakeAgentApi:
|
||||
chat_api = FakeAgentApi(chat_id)
|
||||
def __call__(self, agent_id: str, base_url: str, chat_id: str) -> FakeAgentApi:
|
||||
chat_api = FakeAgentApi(agent_id, base_url, chat_id)
|
||||
self.created_chat_ids.append(chat_id)
|
||||
self.instances[chat_id] = chat_api
|
||||
self.instances.setdefault(chat_id, []).append(chat_api)
|
||||
return chat_api
|
||||
|
||||
|
||||
|
|
@ -73,7 +75,9 @@ def dispatcher():
|
|||
def real_dispatcher():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
platform = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
agent_id="matrix-bot",
|
||||
agent_base_url="http://platform-agent:8000",
|
||||
agent_api_cls=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
|
|
@ -147,7 +151,7 @@ async def test_toggle_skill_callback(dispatcher):
|
|||
assert any("browser" in r.text for r in result if isinstance(r, OutgoingMessage))
|
||||
|
||||
|
||||
async def test_full_flow_with_real_platform_uses_shared_agent_api(real_dispatcher):
|
||||
async def test_full_flow_with_real_platform_uses_direct_agent_api(real_dispatcher):
|
||||
dispatcher, agent_api = real_dispatcher
|
||||
|
||||
start = IncomingCommand(user_id="u1", platform="matrix", chat_id="C1", command="start")
|
||||
|
|
@ -160,7 +164,7 @@ async def test_full_flow_with_real_platform_uses_shared_agent_api(real_dispatche
|
|||
|
||||
assert texts == ["[REAL] Привет!"]
|
||||
assert agent_api.created_chat_ids == ["C1"]
|
||||
assert agent_api.instances["C1"].calls == [("Привет!", [])]
|
||||
assert [instance.calls for instance in agent_api.instances["C1"]] == [[("Привет!", [])]]
|
||||
|
||||
|
||||
async def test_full_flow_with_real_platform_forwards_workspace_attachment(real_dispatcher):
|
||||
|
|
@ -185,6 +189,6 @@ async def test_full_flow_with_real_platform_forwards_workspace_attachment(real_d
|
|||
)
|
||||
await dispatcher.dispatch(msg)
|
||||
|
||||
assert agent_api.instances["C1"].calls == [
|
||||
("Посмотри файл", ["surfaces/matrix/u1/room/inbox/report.pdf"])
|
||||
assert [instance.calls for instance in agent_api.instances["C1"]] == [
|
||||
[("Посмотри файл", ["surfaces/matrix/u1/room/inbox/report.pdf"])]
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,16 +1,10 @@
|
|||
"""Compatibility tests after the Phase 4 migration."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
_api_root = Path(__file__).resolve().parents[2] / "external" / "platform-agent_api"
|
||||
if str(_api_root) not in sys.path:
|
||||
sys.path.insert(0, str(_api_root))
|
||||
|
||||
|
||||
def test_lambda_agent_api_module_is_importable():
|
||||
from lambda_agent_api.agent_api import AgentApi
|
||||
from sdk.upstream_agent_api import AgentApi
|
||||
|
||||
assert AgentApi is not None
|
||||
|
||||
|
|
@ -18,4 +12,4 @@ def test_lambda_agent_api_module_is_importable():
|
|||
def test_agent_session_module_is_intentionally_stubbed():
|
||||
contents = Path(__file__).resolve().parents[2] / "sdk" / "agent_session.py"
|
||||
|
||||
assert "replaced by AgentApiWrapper" in contents.read_text()
|
||||
assert "replaced by direct AgentApi usage" in contents.read_text()
|
||||
|
|
|
|||
|
|
@ -1,20 +1,20 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
from lambda_agent_api.server import MsgEventSendFile, MsgEventTextChunk
|
||||
from pydantic import Field
|
||||
|
||||
import sdk.agent_api_wrapper as agent_api_wrapper_module
|
||||
from core.protocol import SettingsAction
|
||||
from sdk.agent_api_wrapper import AgentApiWrapper
|
||||
from sdk.interface import Attachment, MessageChunk, MessageResponse, PlatformError, UserSettings
|
||||
from sdk.prototype_state import PrototypeStateStore
|
||||
from sdk.real import RealPlatformClient
|
||||
from sdk.upstream_agent_api import MsgEventSendFile, MsgEventTextChunk
|
||||
|
||||
|
||||
class FakeChatAgentApi:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
def __init__(self, agent_id: str, base_url: str, chat_id: str) -> None:
|
||||
self.agent_id = agent_id
|
||||
self.base_url = base_url
|
||||
self.chat_id = str(chat_id)
|
||||
self.calls: list[str] = []
|
||||
self.connect_calls = 0
|
||||
self.close_calls = 0
|
||||
|
|
@ -33,155 +33,125 @@ class FakeChatAgentApi:
|
|||
|
||||
|
||||
class FakeAgentApiFactory:
|
||||
def __init__(self) -> None:
|
||||
self.created_chat_ids: list[str] = []
|
||||
self.instances: dict[str, FakeChatAgentApi] = {}
|
||||
def __init__(self, chat_api_cls=FakeChatAgentApi) -> None:
|
||||
self.chat_api_cls = chat_api_cls
|
||||
self.created_calls: list[tuple[str, str, str]] = []
|
||||
self.instances_by_chat: dict[str, list[FakeChatAgentApi]] = {}
|
||||
|
||||
def for_chat(self, chat_id: str) -> FakeChatAgentApi:
|
||||
chat_api = FakeChatAgentApi(chat_id)
|
||||
self.created_chat_ids.append(chat_id)
|
||||
self.instances[chat_id] = chat_api
|
||||
def __call__(self, agent_id: str, base_url: str, chat_id: str):
|
||||
chat_key = str(chat_id)
|
||||
chat_api = self.chat_api_cls(agent_id, base_url, chat_key)
|
||||
self.created_calls.append((agent_id, base_url, chat_key))
|
||||
self.instances_by_chat.setdefault(chat_key, []).append(chat_api)
|
||||
return chat_api
|
||||
|
||||
def latest(self, chat_id: str):
|
||||
return self.instances_by_chat[str(chat_id)][-1]
|
||||
|
||||
class BlockingChatAgentApi:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
self.calls: list[str] = []
|
||||
self.connect_calls = 0
|
||||
self.close_calls = 0
|
||||
|
||||
class BlockingTracker:
|
||||
def __init__(self) -> None:
|
||||
self.active_calls = 0
|
||||
self.max_active_calls = 0
|
||||
self.started = asyncio.Event()
|
||||
self.release = asyncio.Event()
|
||||
|
||||
async def connect(self) -> None:
|
||||
self.connect_calls += 1
|
||||
|
||||
async def close(self) -> None:
|
||||
self.close_calls += 1
|
||||
class BlockingChatAgentApi(FakeChatAgentApi):
|
||||
def __init__(
|
||||
self,
|
||||
agent_id: str,
|
||||
base_url: str,
|
||||
chat_id: str,
|
||||
*,
|
||||
tracker: BlockingTracker,
|
||||
) -> None:
|
||||
super().__init__(agent_id, base_url, chat_id)
|
||||
self._tracker = tracker
|
||||
|
||||
async def send_message(self, text: str, attachments: list[str] | None = None):
|
||||
self.calls.append(text)
|
||||
self.active_calls += 1
|
||||
self.max_active_calls = max(self.max_active_calls, self.active_calls)
|
||||
self.started.set()
|
||||
await self.release.wait()
|
||||
self.active_calls -= 1
|
||||
self._tracker.active_calls += 1
|
||||
self._tracker.max_active_calls = max(
|
||||
self._tracker.max_active_calls,
|
||||
self._tracker.active_calls,
|
||||
)
|
||||
self._tracker.started.set()
|
||||
await self._tracker.release.wait()
|
||||
self._tracker.active_calls -= 1
|
||||
yield MsgEventTextChunk(text=text)
|
||||
|
||||
|
||||
class AttachmentTrackingChatAgentApi:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
class BlockingAgentApiFactory(FakeAgentApiFactory):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.tracker = BlockingTracker()
|
||||
|
||||
def __call__(self, agent_id: str, base_url: str, chat_id: str):
|
||||
chat_key = str(chat_id)
|
||||
chat_api = BlockingChatAgentApi(
|
||||
agent_id,
|
||||
base_url,
|
||||
chat_key,
|
||||
tracker=self.tracker,
|
||||
)
|
||||
self.created_calls.append((agent_id, base_url, chat_key))
|
||||
self.instances_by_chat.setdefault(chat_key, []).append(chat_api)
|
||||
return chat_api
|
||||
|
||||
|
||||
class AttachmentTrackingChatAgentApi(FakeChatAgentApi):
|
||||
def __init__(self, agent_id: str, base_url: str, chat_id: str) -> None:
|
||||
super().__init__(agent_id, base_url, chat_id)
|
||||
self.calls: list[tuple[str, list[str] | None]] = []
|
||||
self.connect_calls = 0
|
||||
self.close_calls = 0
|
||||
|
||||
async def connect(self) -> None:
|
||||
self.connect_calls += 1
|
||||
|
||||
async def close(self) -> None:
|
||||
self.close_calls += 1
|
||||
|
||||
async def send_message(self, text: str, attachments: list[str] | None = None):
|
||||
self.calls.append((text, attachments))
|
||||
yield MsgEventTextChunk(text=text)
|
||||
|
||||
|
||||
class AttachmentTrackingAgentApiFactory:
|
||||
def __init__(self, chat_api_cls=AttachmentTrackingChatAgentApi) -> None:
|
||||
self.chat_api_cls = chat_api_cls
|
||||
self.created_chat_ids: list[str] = []
|
||||
self.instances: dict[str, AttachmentTrackingChatAgentApi] = {}
|
||||
|
||||
def for_chat(self, chat_id: str) -> AttachmentTrackingChatAgentApi:
|
||||
chat_api = self.chat_api_cls(chat_id)
|
||||
self.created_chat_ids.append(chat_id)
|
||||
self.instances[chat_id] = chat_api
|
||||
return chat_api
|
||||
|
||||
|
||||
class FlakyChatAgentApi:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
self.connect_calls = 0
|
||||
self.close_calls = 0
|
||||
|
||||
async def connect(self) -> None:
|
||||
self.connect_calls += 1
|
||||
|
||||
async def close(self) -> None:
|
||||
self.close_calls += 1
|
||||
|
||||
class FlakyChatAgentApi(FakeChatAgentApi):
|
||||
async def send_message(self, text: str, attachments: list[str] | None = None):
|
||||
raise ConnectionError("Connection closed")
|
||||
yield
|
||||
|
||||
|
||||
class ReuseSensitiveChatAgentApi(FakeChatAgentApi):
|
||||
def __init__(self, agent_id: str, base_url: str, chat_id: str) -> None:
|
||||
super().__init__(agent_id, base_url, chat_id)
|
||||
self._send_calls = 0
|
||||
|
||||
async def send_message(self, text: str, attachments: list[str] | None = None):
|
||||
self.calls.append(text)
|
||||
self._send_calls += 1
|
||||
if text == "first":
|
||||
yield MsgEventTextChunk(text="tool ok")
|
||||
return
|
||||
if text == "second" and self._send_calls == 1:
|
||||
yield MsgEventTextChunk(text="Missing")
|
||||
|
||||
|
||||
class MessageResponseWithAttachments(MessageResponse):
|
||||
attachments: list[Attachment] = Field(default_factory=list)
|
||||
|
||||
|
||||
def test_agent_api_wrapper_normalizes_base_url_and_uses_modern_constructor(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
|
||||
captured["agent_id"] = agent_id
|
||||
captured["base_url"] = base_url
|
||||
captured["chat_id"] = chat_id
|
||||
|
||||
monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)
|
||||
|
||||
wrapper = AgentApiWrapper(
|
||||
agent_id="agent-1",
|
||||
base_url="ws://platform-agent:8000/v1/agent_ws/",
|
||||
chat_id="41",
|
||||
def make_real_platform_client(
|
||||
agent_api_cls,
|
||||
*,
|
||||
prototype_state: PrototypeStateStore | None = None,
|
||||
) -> RealPlatformClient:
|
||||
return RealPlatformClient(
|
||||
agent_id="matrix-bot",
|
||||
agent_base_url="http://platform-agent:8000",
|
||||
agent_api_cls=agent_api_cls,
|
||||
prototype_state=prototype_state or PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
|
||||
assert wrapper.chat_id == "41"
|
||||
assert wrapper._base_url == "ws://platform-agent:8000"
|
||||
assert captured == {
|
||||
"agent_id": "agent-1",
|
||||
"base_url": "ws://platform-agent:8000",
|
||||
"chat_id": "41",
|
||||
}
|
||||
|
||||
|
||||
def test_agent_api_wrapper_for_chat_reuses_normalized_base_url(monkeypatch):
|
||||
init_calls = []
|
||||
|
||||
def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
|
||||
self.id = agent_id
|
||||
self.chat_id = chat_id
|
||||
self.url = base_url
|
||||
init_calls.append((agent_id, base_url, chat_id))
|
||||
|
||||
monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)
|
||||
|
||||
root = AgentApiWrapper(
|
||||
agent_id="agent-1",
|
||||
base_url="http://platform-agent:8000/v1/agent_ws/",
|
||||
chat_id="1",
|
||||
)
|
||||
|
||||
child = root.for_chat("99")
|
||||
|
||||
assert child is not root
|
||||
assert child.chat_id == "99"
|
||||
assert child._base_url == "http://platform-agent:8000"
|
||||
assert init_calls == [
|
||||
("agent-1", "http://platform-agent:8000", "1"),
|
||||
("agent-1", "http://platform-agent:8000", "99"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_get_or_create_user_uses_local_state():
|
||||
client = RealPlatformClient(
|
||||
agent_api=FakeAgentApiFactory(),
|
||||
prototype_state=PrototypeStateStore(),
|
||||
)
|
||||
client = make_real_platform_client(FakeAgentApiFactory())
|
||||
|
||||
first = await client.get_or_create_user("u1", "matrix", "Alice")
|
||||
second = await client.get_or_create_user("u1", "matrix")
|
||||
|
|
@ -194,14 +164,10 @@ async def test_real_platform_client_get_or_create_user_uses_local_state():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_send_message_uses_chat_bound_client():
|
||||
async def test_real_platform_client_send_message_uses_direct_agent_api_per_chat():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
prototype_state = PrototypeStateStore()
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=prototype_state,
|
||||
platform="matrix",
|
||||
)
|
||||
client = make_real_platform_client(agent_api, prototype_state=prototype_state)
|
||||
|
||||
result = await client.send_message("@alice:example.org", "chat-7", "hello")
|
||||
|
||||
|
|
@ -211,21 +177,18 @@ async def test_real_platform_client_send_message_uses_chat_bound_client():
|
|||
tokens_used=0,
|
||||
finished=True,
|
||||
)
|
||||
assert agent_api.created_chat_ids == ["chat-7"]
|
||||
assert agent_api.instances["chat-7"].chat_id == "chat-7"
|
||||
assert agent_api.instances["chat-7"].calls == ["hello"]
|
||||
assert agent_api.instances["chat-7"].connect_calls == 1
|
||||
assert agent_api.created_calls == [("matrix-bot", "http://platform-agent:8000", "chat-7")]
|
||||
assert agent_api.latest("chat-7").chat_id == "chat-7"
|
||||
assert agent_api.latest("chat-7").calls == ["hello"]
|
||||
assert agent_api.latest("chat-7").connect_calls == 1
|
||||
assert agent_api.latest("chat-7").close_calls == 1
|
||||
assert await prototype_state.get_last_tokens_used_for_context("chat-7") == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_forwards_attachments_to_chat_api():
|
||||
agent_api = AttachmentTrackingAgentApiFactory()
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
agent_api = FakeAgentApiFactory(chat_api_cls=AttachmentTrackingChatAgentApi)
|
||||
client = make_real_platform_client(agent_api)
|
||||
attachment = Attachment(
|
||||
workspace_path="surfaces/matrix/alice/room/inbox/report.pdf",
|
||||
mime_type="application/pdf",
|
||||
|
|
@ -240,7 +203,7 @@ async def test_real_platform_client_forwards_attachments_to_chat_api():
|
|||
attachments=[attachment],
|
||||
)
|
||||
|
||||
assert agent_api.instances["chat-7"].calls == [
|
||||
assert agent_api.latest("chat-7").calls == [
|
||||
("hello", ["surfaces/matrix/alice/room/inbox/report.pdf"])
|
||||
]
|
||||
assert result.response == "hello"
|
||||
|
|
@ -256,17 +219,10 @@ async def test_real_platform_client_preserves_send_file_events_in_sync_result(mo
|
|||
yield MsgEventSendFile(path="report.pdf")
|
||||
yield MsgEventTextChunk(text="llo")
|
||||
|
||||
agent_api = AttachmentTrackingAgentApiFactory(chat_api_cls=FileEventAgentApi)
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
agent_api = FakeAgentApiFactory(chat_api_cls=FileEventAgentApi)
|
||||
client = make_real_platform_client(agent_api)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"sdk.real.MessageResponse",
|
||||
MessageResponseWithAttachments,
|
||||
)
|
||||
monkeypatch.setattr("sdk.real.MessageResponse", MessageResponseWithAttachments)
|
||||
|
||||
result = await client.send_message("@alice:example.org", "chat-7", "hello")
|
||||
|
||||
|
|
@ -284,63 +240,61 @@ async def test_real_platform_client_preserves_send_file_events_in_sync_result(mo
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_reuses_cached_chat_client():
|
||||
async def test_real_platform_client_uses_fresh_agent_connection_per_request():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
client = make_real_platform_client(agent_api)
|
||||
|
||||
await client.send_message("@alice:example.org", "chat-1", "hello")
|
||||
await client.send_message("@alice:example.org", "chat-1", "again")
|
||||
|
||||
assert agent_api.created_chat_ids == ["chat-1"]
|
||||
assert agent_api.instances["chat-1"].calls == ["hello", "again"]
|
||||
assert agent_api.instances["chat-1"].connect_calls == 1
|
||||
assert agent_api.instances["chat-1"].close_calls == 0
|
||||
assert agent_api.created_calls == [
|
||||
("matrix-bot", "http://platform-agent:8000", "chat-1"),
|
||||
("matrix-bot", "http://platform-agent:8000", "chat-1"),
|
||||
]
|
||||
assert [instance.calls for instance in agent_api.instances_by_chat["chat-1"]] == [
|
||||
["hello"],
|
||||
["again"],
|
||||
]
|
||||
assert all(instance.connect_calls == 1 for instance in agent_api.instances_by_chat["chat-1"])
|
||||
assert all(instance.close_calls == 1 for instance in agent_api.instances_by_chat["chat-1"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_avoids_reuse_sensitive_second_message_loss():
|
||||
agent_api = FakeAgentApiFactory(chat_api_cls=ReuseSensitiveChatAgentApi)
|
||||
client = make_real_platform_client(agent_api)
|
||||
|
||||
first = await client.send_message("@alice:example.org", "chat-1", "first")
|
||||
second = await client.send_message("@alice:example.org", "chat-1", "second")
|
||||
|
||||
assert first.response == "tool ok"
|
||||
assert second.response == "Missing"
|
||||
assert len(agent_api.instances_by_chat["chat-1"]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_wraps_connection_closed_as_platform_error():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
agent_api.instances["chat-1"] = FlakyChatAgentApi("chat-1")
|
||||
agent_api.for_chat = lambda chat_id: agent_api.instances.setdefault(
|
||||
chat_id, FlakyChatAgentApi(chat_id)
|
||||
)
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
agent_api = FakeAgentApiFactory(chat_api_cls=FlakyChatAgentApi)
|
||||
client = make_real_platform_client(agent_api)
|
||||
|
||||
with pytest.raises(PlatformError, match="Connection closed") as exc_info:
|
||||
await client.send_message("@alice:example.org", "chat-1", "hello")
|
||||
|
||||
assert exc_info.value.code == "PLATFORM_CONNECTION_ERROR"
|
||||
assert "chat-1" not in client._chat_apis
|
||||
assert agent_api.instances["chat-1"].close_calls == 1
|
||||
assert agent_api.latest("chat-1").close_calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_reconnects_after_closed_chat_api():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
flaky = FlakyChatAgentApi("chat-1")
|
||||
healthy = AttachmentTrackingChatAgentApi("chat-1")
|
||||
provided = iter([flaky, healthy])
|
||||
async def test_real_platform_client_uses_fresh_connection_after_failure():
|
||||
class SometimesFlakyAgentApi(FakeChatAgentApi):
|
||||
async def send_message(self, text: str, attachments: list[str] | None = None):
|
||||
if text == "hello":
|
||||
raise ConnectionError("Connection closed")
|
||||
self.calls.append(text)
|
||||
yield MsgEventTextChunk(text=text)
|
||||
|
||||
def for_chat(chat_id: str):
|
||||
chat_api = next(provided)
|
||||
agent_api.created_chat_ids.append(chat_id)
|
||||
agent_api.instances[chat_id] = chat_api
|
||||
return chat_api
|
||||
|
||||
agent_api.for_chat = for_chat
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
agent_api = FakeAgentApiFactory(chat_api_cls=SometimesFlakyAgentApi)
|
||||
client = make_real_platform_client(agent_api)
|
||||
|
||||
with pytest.raises(PlatformError, match="Connection closed"):
|
||||
await client.send_message("@alice:example.org", "chat-1", "hello")
|
||||
|
|
@ -348,60 +302,17 @@ async def test_real_platform_client_reconnects_after_closed_chat_api():
|
|||
result = await client.send_message("@alice:example.org", "chat-1", "again")
|
||||
|
||||
assert result.response == "again"
|
||||
assert agent_api.created_chat_ids == ["chat-1", "chat-1"]
|
||||
assert healthy.calls == [("again", None)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_creates_chat_client_atomically_for_concurrent_requests():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
|
||||
results = await asyncio.gather(
|
||||
client.send_message("@alice:example.org", "chat-1", "hello"),
|
||||
client.send_message("@alice:example.org", "chat-1", "again"),
|
||||
)
|
||||
|
||||
assert [result.response for result in results] == ["hello", "again"]
|
||||
assert agent_api.created_chat_ids == ["chat-1"]
|
||||
assert agent_api.instances["chat-1"].connect_calls == 1
|
||||
assert agent_api.instances["chat-1"].calls == ["hello", "again"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_creates_distinct_clients_per_chat():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
|
||||
await client.send_message("@alice:example.org", "chat-1", "hello")
|
||||
await client.send_message("@alice:example.org", "chat-2", "world")
|
||||
|
||||
assert agent_api.created_chat_ids == ["chat-1", "chat-2"]
|
||||
assert agent_api.instances["chat-1"] is not agent_api.instances["chat-2"]
|
||||
assert agent_api.instances["chat-1"].calls == ["hello"]
|
||||
assert agent_api.instances["chat-2"].calls == ["world"]
|
||||
assert agent_api.created_calls == [
|
||||
("matrix-bot", "http://platform-agent:8000", "chat-1"),
|
||||
("matrix-bot", "http://platform-agent:8000", "chat-1"),
|
||||
]
|
||||
assert agent_api.latest("chat-1").calls == ["again"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_serializes_same_chat_streams_across_send_paths():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
agent_api.instances["chat-1"] = BlockingChatAgentApi("chat-1")
|
||||
agent_api.for_chat = lambda chat_id: agent_api.instances.setdefault(
|
||||
chat_id, BlockingChatAgentApi(chat_id)
|
||||
)
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
agent_api = BlockingAgentApiFactory()
|
||||
client = make_real_platform_client(agent_api)
|
||||
|
||||
async def consume_stream():
|
||||
chunks = []
|
||||
|
|
@ -410,32 +321,48 @@ async def test_real_platform_client_serializes_same_chat_streams_across_send_pat
|
|||
return chunks
|
||||
|
||||
stream_task = asyncio.create_task(consume_stream())
|
||||
await asyncio.wait_for(agent_api.instances["chat-1"].started.wait(), timeout=1)
|
||||
await asyncio.wait_for(agent_api.tracker.started.wait(), timeout=1)
|
||||
|
||||
send_task = asyncio.create_task(client.send_message("@alice:example.org", "chat-1", "again"))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert agent_api.instances["chat-1"].calls == ["hello"]
|
||||
assert agent_api.instances["chat-1"].max_active_calls == 1
|
||||
assert len(agent_api.instances_by_chat["chat-1"]) == 1
|
||||
assert agent_api.instances_by_chat["chat-1"][0].calls == ["hello"]
|
||||
assert agent_api.tracker.max_active_calls == 1
|
||||
|
||||
agent_api.instances["chat-1"].release.set()
|
||||
agent_api.tracker.release.set()
|
||||
stream_chunks = await stream_task
|
||||
send_result = await send_task
|
||||
|
||||
assert [chunk.delta for chunk in stream_chunks] == ["hello", ""]
|
||||
assert send_result.response == "again"
|
||||
assert agent_api.instances["chat-1"].calls == ["hello", "again"]
|
||||
assert agent_api.instances["chat-1"].max_active_calls == 1
|
||||
assert [instance.calls for instance in agent_api.instances_by_chat["chat-1"]] == [
|
||||
["hello"],
|
||||
["again"],
|
||||
]
|
||||
assert agent_api.tracker.max_active_calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_creates_distinct_connections_per_chat():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
client = make_real_platform_client(agent_api)
|
||||
|
||||
await client.send_message("@alice:example.org", "chat-1", "hello")
|
||||
await client.send_message("@alice:example.org", "chat-2", "world")
|
||||
|
||||
assert agent_api.created_calls == [
|
||||
("matrix-bot", "http://platform-agent:8000", "chat-1"),
|
||||
("matrix-bot", "http://platform-agent:8000", "chat-2"),
|
||||
]
|
||||
assert agent_api.latest("chat-1").calls == ["hello"]
|
||||
assert agent_api.latest("chat-2").calls == ["world"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_stream_message_emits_final_tokens_chunk():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
client = make_real_platform_client(agent_api)
|
||||
|
||||
chunks = []
|
||||
async for chunk in client.stream_message("@alice:example.org", "chat-1", "hello"):
|
||||
|
|
@ -461,17 +388,14 @@ async def test_real_platform_client_stream_message_emits_final_tokens_chunk():
|
|||
tokens_used=0,
|
||||
),
|
||||
]
|
||||
assert agent_api.created_chat_ids == ["chat-1"]
|
||||
assert agent_api.instances["chat-1"].calls == ["hello"]
|
||||
assert agent_api.created_calls == [("matrix-bot", "http://platform-agent:8000", "chat-1")]
|
||||
assert agent_api.latest("chat-1").calls == ["hello"]
|
||||
assert agent_api.latest("chat-1").close_calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_settings_are_local():
|
||||
client = RealPlatformClient(
|
||||
agent_api=FakeAgentApiFactory(),
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
client = make_real_platform_client(FakeAgentApiFactory())
|
||||
|
||||
await client.update_settings(
|
||||
"usr-matrix-u1",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue