diff --git a/.env.example b/.env.example index 54287aa..5c1cb66 100644 --- a/.env.example +++ b/.env.example @@ -11,7 +11,6 @@ MATRIX_PLATFORM_BACKEND=real SURFACES_WORKSPACE_DIR=/workspace # Compose-local platform-agent route -AGENT_WS_URL=ws://platform-agent:8000/v1/agent_ws/ AGENT_BASE_URL=http://platform-agent:8000 # platform-agent provider diff --git a/README.md b/README.md index 4c4a480..93782c6 100644 --- a/README.md +++ b/README.md @@ -69,8 +69,8 @@ surfaces-bot/ - **Диалог** — сообщения, вложения, подтверждения `!yes` / `!no` и routing через `EventDispatcher` - **Стабильность** — перед `sync_forever()` бот делает bootstrap sync и стартует с `since`, чтобы не переигрывать старую timeline после рестарта - **Текущее ограничение** — encrypted DM официально не поддержан; ручное тестирование Matrix ведётся в незашифрованных комнатах и зависит от локального state-store бота -- **Backend selection** — `MATRIX_PLATFORM_BACKEND=mock` остаётся значением по умолчанию; `MATRIX_PLATFORM_BACKEND=real` использует `platform-agent` из compose и WebSocket contract `/v1/agent_ws/{chat_id}/` -- **Ограничения real backend** — локальный runtime использует shared `/workspace`, файлы передаются как относительные пути в `attachments`, а transport layer со стороны `surfaces` использует pinned upstream `platform-agent_api.AgentApi` почти без локальной stream-логики; текущая реализация рабочая, но после tool/file flow остаётся подтверждённый upstream streaming bug, из-за которого начало ответа может пропадать +- **Backend selection** — `MATRIX_PLATFORM_BACKEND=mock` остаётся значением по умолчанию; `MATRIX_PLATFORM_BACKEND=real` использует `platform-agent` из compose и upstream `AgentApi` по contract `/v1/agent_ws/{chat_id}/` +- **Ограничения real backend** — локальный runtime использует shared `/workspace`, файлы передаются как относительные пути в `attachments`, а transport layer со стороны `surfaces` использует прямой upstream `platform-agent_api.AgentApi` без локального subclass; prod-default lifecycle открывает отдельное соединение на каждый запрос, но после tool/file flow всё ещё остаётся подтверждённый upstream streaming bug, из-за которого начало ответа может пропадать --- @@ -122,9 +122,6 @@ MATRIX_PASSWORD=... # или MATRIX_ACCESS_TOKEN=... MATRIX_PLATFORM_BACKEND=real # compose runtime: platform-agent service name + shared /workspace -# значение передаётся в thin wrapper как base URL; wrapper сам нормализует его -# до upstream WS route /v1/agent_ws/{chat_id}/ -AGENT_WS_URL=ws://platform-agent:8000/v1/agent_ws/ AGENT_BASE_URL=http://platform-agent:8000 SURFACES_WORKSPACE_DIR=/workspace diff --git a/adapter/matrix/bot.py b/adapter/matrix/bot.py index e7e68b2..debd2fa 100644 --- a/adapter/matrix/bot.py +++ b/adapter/matrix/bot.py @@ -2,8 +2,10 @@ from __future__ import annotations import asyncio import os +import re from dataclasses import dataclass from pathlib import Path +from urllib.parse import urlsplit, urlunsplit import structlog from dotenv import load_dotenv @@ -63,7 +65,6 @@ from core.protocol import ( ) from core.settings import SettingsManager from core.store import InMemoryStore, SQLiteStore, StateStore -from sdk.agent_api_wrapper import AgentApiWrapper from sdk.interface import PlatformClient, PlatformError from sdk.mock import MockPlatformClient from sdk.prototype_state import PrototypeStateStore @@ -89,8 +90,7 @@ def build_event_dispatcher(platform: PlatformClient, store: StateStore) -> Event auth_mgr = AuthManager(platform, store) settings_mgr = SettingsManager(platform, store) prototype_state = getattr(platform, "_prototype_state", None) - agent_api = getattr(platform, "_agent_api", None) - agent_base_url = os.environ.get("AGENT_BASE_URL", "http://127.0.0.1:8000") + agent_base_url = _agent_base_url_from_env() dispatcher = EventDispatcher( platform=platform, chat_mgr=chat_mgr, auth_mgr=auth_mgr, settings_mgr=settings_mgr ) @@ -98,19 +98,32 @@ def build_event_dispatcher(platform: PlatformClient, store: StateStore) -> Event register_matrix_handlers( dispatcher, store=store, - agent_api=agent_api, prototype_state=prototype_state, agent_base_url=agent_base_url, ) return dispatcher +def _normalize_agent_base_url(url: str) -> str: + parsed = urlsplit(url) + path = re.sub(r"(?:/v1)?/agent_ws(?:/[^/]+)?/?$", "", parsed.path.rstrip("/")) + return urlunsplit((parsed.scheme, parsed.netloc, path, "", "")) + + +def _agent_base_url_from_env() -> str: + if base_url := os.environ.get("AGENT_BASE_URL"): + return base_url + if ws_url := os.environ.get("AGENT_WS_URL"): + return _normalize_agent_base_url(ws_url) + return "http://127.0.0.1:8000" + + def _build_platform_from_env() -> PlatformClient: backend = os.environ.get("MATRIX_PLATFORM_BACKEND", "mock").strip().lower() if backend == "real": - ws_url = os.environ["AGENT_WS_URL"] return RealPlatformClient( - agent_api=AgentApiWrapper(agent_id="matrix-bot", base_url=ws_url), + agent_id="matrix-bot", + agent_base_url=_agent_base_url_from_env(), prototype_state=PrototypeStateStore(), platform="matrix", ) @@ -128,8 +141,7 @@ def build_runtime( auth_mgr = AuthManager(platform, store) settings_mgr = SettingsManager(platform, store) prototype_state = getattr(platform, "_prototype_state", None) - agent_api = getattr(platform, "_agent_api", None) - agent_base_url = os.environ.get("AGENT_BASE_URL", "http://127.0.0.1:8000") + agent_base_url = _agent_base_url_from_env() dispatcher = EventDispatcher( platform=platform, chat_mgr=chat_mgr, auth_mgr=auth_mgr, settings_mgr=settings_mgr ) @@ -138,7 +150,6 @@ def build_runtime( dispatcher, client=client, store=store, - agent_api=agent_api, prototype_state=prototype_state, agent_base_url=agent_base_url, ) diff --git a/adapter/matrix/handlers/__init__.py b/adapter/matrix/handlers/__init__.py index 28e70eb..c028735 100644 --- a/adapter/matrix/handlers/__init__.py +++ b/adapter/matrix/handlers/__init__.py @@ -34,7 +34,6 @@ def register_matrix_handlers( dispatcher: EventDispatcher, client=None, store=None, - agent_api=None, prototype_state=None, agent_base_url: str = "http://127.0.0.1:8000", ) -> None: @@ -64,11 +63,11 @@ def register_matrix_handlers( dispatcher.register(IncomingCallback, "toggle_skill", handle_toggle_skill) dispatcher.register(IncomingCommand, "*", handle_unknown_command) - if agent_api is not None and prototype_state is not None: + if prototype_state is not None: dispatcher.register( IncomingCommand, "save", - make_handle_save(agent_api, store, prototype_state), + make_handle_save(None, store, prototype_state), ) dispatcher.register(IncomingCommand, "load", make_handle_load(store, prototype_state)) dispatcher.register(IncomingCommand, "context", make_handle_context(store, prototype_state)) diff --git a/docker-compose.yml b/docker-compose.yml index d6c2e4d..4de9fac 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -27,7 +27,6 @@ services: env_file: .env environment: AGENT_BASE_URL: http://platform-agent:8000 - AGENT_WS_URL: ws://platform-agent:8000/v1/agent_ws/ SURFACES_WORKSPACE_DIR: /workspace depends_on: - platform-agent diff --git a/docs/matrix-direct-agent-prototype-ru.md b/docs/matrix-direct-agent-prototype-ru.md index 6729520..8f1dcee 100644 --- a/docs/matrix-direct-agent-prototype-ru.md +++ b/docs/matrix-direct-agent-prototype-ru.md @@ -33,7 +33,7 @@ - переключение Matrix backend через env: - `MATRIX_PLATFORM_BACKEND=mock` - `MATRIX_PLATFORM_BACKEND=real` -- прямую отправку текста в live agent через `AGENT_WS_URL` +- прямую отправку текста в live agent через `AGENT_BASE_URL` - локальное хранение settings и user mapping - изоляцию backend memory по `thread_id` - исправление повторных invite: бот теперь сначала `join()`, а уже потом решает, нужно ли пере-провиженить Space/chat tree @@ -154,7 +154,7 @@ ws://127.0.0.1:8000/agent_ws/ cd /Users/a/MAI/sem2/lambda/surfaces-bot export MATRIX_PLATFORM_BACKEND=real -export AGENT_WS_URL=ws://127.0.0.1:8000/agent_ws/ +export AGENT_BASE_URL=http://127.0.0.1:8000 export MATRIX_HOMESERVER=https://matrix.lambda.coredump.ru export MATRIX_USER_ID=@lambda_surface_test_bot:matrix.lambda.coredump.ru export MATRIX_PASSWORD='YOUR_PASSWORD' @@ -193,7 +193,7 @@ uv run uvicorn src.main:app --host 0.0.0.0 --port 8000 cd /Users/a/MAI/sem2/lambda/surfaces-bot export MATRIX_PLATFORM_BACKEND=real -export AGENT_WS_URL=ws://127.0.0.1:8000/agent_ws/ +export AGENT_BASE_URL=http://127.0.0.1:8000 export MATRIX_HOMESERVER=https://matrix.lambda.coredump.ru export MATRIX_USER_ID=@lambda_surface_test_bot:matrix.lambda.coredump.ru export MATRIX_PASSWORD='YOUR_PASSWORD' diff --git a/sdk/agent_api_wrapper.py b/sdk/agent_api_wrapper.py deleted file mode 100644 index 34fee46..0000000 --- a/sdk/agent_api_wrapper.py +++ /dev/null @@ -1,48 +0,0 @@ -from __future__ import annotations - -import re -import sys -from pathlib import Path -from urllib.parse import urlsplit, urlunsplit - -_api_root = Path(__file__).resolve().parents[1] / "external" / "platform-agent_api" -if str(_api_root) not in sys.path: - sys.path.insert(0, str(_api_root)) - -from lambda_agent_api.agent_api import AgentApi # noqa: E402 - - -class AgentApiWrapper(AgentApi): - """Thin construction/factory shim over the pinned upstream AgentApi.""" - - def __init__( - self, - agent_id: str, - base_url: str, - *, - chat_id: int | str = 0, - **kwargs, - ) -> None: - self._base_url = self._normalize_base_url(base_url) - self._init_kwargs = dict(kwargs) - self.chat_id = chat_id - super().__init__( - agent_id=agent_id, - base_url=self._base_url, - chat_id=chat_id, - **kwargs, - ) - - @staticmethod - def _normalize_base_url(base_url: str) -> str: - parsed = urlsplit(base_url) - path = re.sub(r"(?:/v1)?/agent_ws(?:/[^/]+)?/?$", "", parsed.path.rstrip("/")) - return urlunsplit((parsed.scheme, parsed.netloc, path, "", "")) - - def for_chat(self, chat_id: int | str) -> AgentApiWrapper: - return type(self)( - agent_id=self.id, - base_url=self._base_url, - chat_id=chat_id, - **self._init_kwargs, - ) diff --git a/sdk/agent_session.py b/sdk/agent_session.py index 63acdd1..187b88a 100644 --- a/sdk/agent_session.py +++ b/sdk/agent_session.py @@ -1 +1 @@ -"""Compatibility stub: AgentSessionClient was replaced by AgentApiWrapper in Phase 4.""" +"""Compatibility stub: AgentSessionClient was replaced by direct AgentApi usage in Phase 4.""" diff --git a/sdk/real.py b/sdk/real.py index 2b43056..0b7ef19 100644 --- a/sdk/real.py +++ b/sdk/real.py @@ -4,9 +4,6 @@ import asyncio from collections.abc import AsyncIterator from pathlib import Path -from lambda_agent_api.server import MsgEventSendFile, MsgEventTextChunk - -from sdk.agent_api_wrapper import AgentApiWrapper from sdk.interface import ( Attachment, MessageChunk, @@ -17,37 +14,32 @@ from sdk.interface import ( UserSettings, ) from sdk.prototype_state import PrototypeStateStore +from sdk.upstream_agent_api import AgentApi, MsgEventSendFile, MsgEventTextChunk class RealPlatformClient(PlatformClient): def __init__( self, - agent_api: AgentApiWrapper, + agent_id: str, + agent_base_url: str, prototype_state: PrototypeStateStore, platform: str = "matrix", + agent_api_cls=AgentApi, ) -> None: - self._agent_api = agent_api + self._agent_id = agent_id + self._agent_base_url = agent_base_url + self._agent_api_cls = agent_api_cls self._prototype_state = prototype_state self._platform = platform - self._chat_apis: dict[str, AgentApiWrapper] = {} - self._chat_api_lock = asyncio.Lock() self._chat_send_locks: dict[str, asyncio.Lock] = {} @property - def agent_api(self) -> AgentApiWrapper: - return self._agent_api + def agent_id(self) -> str: + return self._agent_id - async def _get_chat_api(self, chat_id: str): - chat_key = str(chat_id) - chat_api = self._chat_apis.get(chat_key) - if chat_api is None: - async with self._chat_api_lock: - chat_api = self._chat_apis.get(chat_key) - if chat_api is None: - chat_api = self._agent_api.for_chat(chat_key) - await chat_api.connect() - self._chat_apis[chat_key] = chat_api - return chat_api + @property + def agent_base_url(self) -> str: + return self._agent_base_url def _get_chat_send_lock(self, chat_id: str) -> asyncio.Lock: chat_key = str(chat_id) @@ -82,9 +74,9 @@ class RealPlatformClient(PlatformClient): lock = self._get_chat_send_lock(chat_id) async with lock: - chat_api = await self._get_chat_api(chat_id) - + chat_api = self._build_chat_api(chat_id) try: + await chat_api.connect() async for event in self._stream_agent_events( chat_api, text, attachments=attachments ): @@ -96,8 +88,9 @@ class RealPlatformClient(PlatformClient): if attachment is not None: sent_attachments.append(attachment) except Exception as exc: - await self._handle_chat_api_failure(chat_id, exc) - + raise self._to_platform_error(exc) from exc + finally: + await self._close_chat_api(chat_api) await self._prototype_state.set_last_tokens_used(str(chat_id), 0) response_kwargs = { @@ -118,8 +111,9 @@ class RealPlatformClient(PlatformClient): ) -> AsyncIterator[MessageChunk]: lock = self._get_chat_send_lock(chat_id) async with lock: - chat_api = await self._get_chat_api(chat_id) + chat_api = self._build_chat_api(chat_id) try: + await chat_api.connect() async for event in self._stream_agent_events( chat_api, text, attachments=attachments ): @@ -132,7 +126,9 @@ class RealPlatformClient(PlatformClient): elif isinstance(event, MsgEventSendFile): continue except Exception as exc: - await self._handle_chat_api_failure(chat_id, exc) + raise self._to_platform_error(exc) from exc + finally: + await self._close_chat_api(chat_api) await self._prototype_state.set_last_tokens_used(str(chat_id), 0) yield MessageChunk( message_id=user_id, @@ -148,20 +144,9 @@ class RealPlatformClient(PlatformClient): await self._prototype_state.update_settings(user_id, action) async def disconnect_chat(self, chat_id: str) -> None: - chat_key = str(chat_id) - chat_api = self._chat_apis.pop(chat_key, None) - self._chat_send_locks.pop(chat_key, None) - if chat_api is not None: - close = getattr(chat_api, "close", None) - if callable(close): - await close() + self._chat_send_locks.pop(str(chat_id), None) async def close(self) -> None: - for chat_api in list(self._chat_apis.values()): - close = getattr(chat_api, "close", None) - if callable(close): - await close() - self._chat_apis.clear() self._chat_send_locks.clear() async def _stream_agent_events( @@ -175,10 +160,26 @@ class RealPlatformClient(PlatformClient): async for event in event_stream: yield event - async def _handle_chat_api_failure(self, chat_id: str, exc: Exception) -> None: - await self.disconnect_chat(chat_id) + def _build_chat_api(self, chat_id: str): + return self._agent_api_cls( + agent_id=self._agent_id, + base_url=self._agent_base_url, + chat_id=str(chat_id), + ) + + @staticmethod + async def _close_chat_api(chat_api) -> None: + close = getattr(chat_api, "close", None) + if callable(close): + try: + await close() + except Exception: + pass + + @staticmethod + def _to_platform_error(exc: Exception) -> PlatformError: code = getattr(exc, "code", None) or "PLATFORM_CONNECTION_ERROR" - raise PlatformError(str(exc), code=code) from exc + return PlatformError(str(exc), code=code) @staticmethod def _attachment_paths(attachments: list[Attachment] | None) -> list[str]: diff --git a/sdk/upstream_agent_api.py b/sdk/upstream_agent_api.py new file mode 100644 index 0000000..d0bfdd7 --- /dev/null +++ b/sdk/upstream_agent_api.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +import sys +from pathlib import Path + +_api_root = Path(__file__).resolve().parents[1] / "external" / "platform-agent_api" +if str(_api_root) not in sys.path: + sys.path.insert(0, str(_api_root)) + +from lambda_agent_api.agent_api import AgentApi, AgentBusyException, AgentException # noqa: E402 +from lambda_agent_api.server import MsgEventSendFile, MsgEventTextChunk # noqa: E402 + +__all__ = [ + "AgentApi", + "AgentBusyException", + "AgentException", + "MsgEventSendFile", + "MsgEventTextChunk", +] diff --git a/tests/adapter/matrix/test_dispatcher.py b/tests/adapter/matrix/test_dispatcher.py index 01b35da..7fa7a47 100644 --- a/tests/adapter/matrix/test_dispatcher.py +++ b/tests/adapter/matrix/test_dispatcher.py @@ -908,34 +908,21 @@ async def test_prepare_live_sync_returns_next_batch_from_bootstrap_sync(): async def test_build_runtime_uses_real_platform_when_matrix_backend_is_real(monkeypatch): - bot_module = importlib.import_module("adapter.matrix.bot") - - class FakeAgentApiWrapper: - def __init__(self, agent_id: str, base_url: str) -> None: - self.agent_id = agent_id - self.base_url = base_url - - monkeypatch.setattr(bot_module, "AgentApiWrapper", FakeAgentApiWrapper) monkeypatch.setenv("MATRIX_PLATFORM_BACKEND", "real") - monkeypatch.setenv("AGENT_WS_URL", "ws://agent.example/agent_ws/") + monkeypatch.setenv("AGENT_BASE_URL", "http://agent.example") runtime = build_runtime() assert isinstance(runtime.platform, RealPlatformClient) - assert runtime.platform.agent_api.base_url == "ws://agent.example/agent_ws/" + assert runtime.platform.agent_base_url == "http://agent.example" + assert runtime.platform.agent_id == "matrix-bot" async def test_matrix_main_closes_platform_without_connecting_root_agent(monkeypatch): bot_module = importlib.import_module("adapter.matrix.bot") platform_close = AsyncMock() - agent_connect = AsyncMock() - runtime = SimpleNamespace( - platform=SimpleNamespace( - close=platform_close, - agent_api=SimpleNamespace(connect=agent_connect), - ) - ) + runtime = SimpleNamespace(platform=SimpleNamespace(close=platform_close)) class FakeAsyncClient: def __init__(self, *args, **kwargs): @@ -959,7 +946,6 @@ async def test_matrix_main_closes_platform_without_connecting_root_agent(monkeyp await bot_module.main() - agent_connect.assert_not_awaited() platform_close.assert_awaited_once() diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 5287074..9260ec8 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -4,7 +4,6 @@ Smoke test: полный цикл через dispatcher + реальные manag Имитирует что делает адаптер (Telegram или Matrix) при получении события. """ import pytest -from lambda_agent_api.server import MsgEventTextChunk from core.auth import AuthManager from core.chat import ChatManager @@ -23,10 +22,13 @@ from core.store import InMemoryStore from sdk.mock import MockPlatformClient from sdk.prototype_state import PrototypeStateStore from sdk.real import RealPlatformClient +from sdk.upstream_agent_api import MsgEventTextChunk class FakeAgentApi: - def __init__(self, chat_id: str) -> None: + def __init__(self, agent_id: str, base_url: str, chat_id: str) -> None: + self.agent_id = agent_id + self.base_url = base_url self.chat_id = chat_id self.calls: list[tuple[str, list[str]]] = [] self.connect_calls = 0 @@ -46,12 +48,12 @@ class FakeAgentApi: class FakeAgentApiFactory: def __init__(self) -> None: self.created_chat_ids: list[str] = [] - self.instances: dict[str, FakeAgentApi] = {} + self.instances: dict[str, list[FakeAgentApi]] = {} - def for_chat(self, chat_id: str) -> FakeAgentApi: - chat_api = FakeAgentApi(chat_id) + def __call__(self, agent_id: str, base_url: str, chat_id: str) -> FakeAgentApi: + chat_api = FakeAgentApi(agent_id, base_url, chat_id) self.created_chat_ids.append(chat_id) - self.instances[chat_id] = chat_api + self.instances.setdefault(chat_id, []).append(chat_api) return chat_api @@ -73,7 +75,9 @@ def dispatcher(): def real_dispatcher(): agent_api = FakeAgentApiFactory() platform = RealPlatformClient( - agent_api=agent_api, + agent_id="matrix-bot", + agent_base_url="http://platform-agent:8000", + agent_api_cls=agent_api, prototype_state=PrototypeStateStore(), platform="matrix", ) @@ -147,7 +151,7 @@ async def test_toggle_skill_callback(dispatcher): assert any("browser" in r.text for r in result if isinstance(r, OutgoingMessage)) -async def test_full_flow_with_real_platform_uses_shared_agent_api(real_dispatcher): +async def test_full_flow_with_real_platform_uses_direct_agent_api(real_dispatcher): dispatcher, agent_api = real_dispatcher start = IncomingCommand(user_id="u1", platform="matrix", chat_id="C1", command="start") @@ -160,7 +164,7 @@ async def test_full_flow_with_real_platform_uses_shared_agent_api(real_dispatche assert texts == ["[REAL] Привет!"] assert agent_api.created_chat_ids == ["C1"] - assert agent_api.instances["C1"].calls == [("Привет!", [])] + assert [instance.calls for instance in agent_api.instances["C1"]] == [[("Привет!", [])]] async def test_full_flow_with_real_platform_forwards_workspace_attachment(real_dispatcher): @@ -185,6 +189,6 @@ async def test_full_flow_with_real_platform_forwards_workspace_attachment(real_d ) await dispatcher.dispatch(msg) - assert agent_api.instances["C1"].calls == [ - ("Посмотри файл", ["surfaces/matrix/u1/room/inbox/report.pdf"]) + assert [instance.calls for instance in agent_api.instances["C1"]] == [ + [("Посмотри файл", ["surfaces/matrix/u1/room/inbox/report.pdf"])] ] diff --git a/tests/platform/test_agent_session.py b/tests/platform/test_agent_session.py index 7f419e8..bda5cfe 100644 --- a/tests/platform/test_agent_session.py +++ b/tests/platform/test_agent_session.py @@ -1,16 +1,10 @@ """Compatibility tests after the Phase 4 migration.""" -import sys from pathlib import Path -_api_root = Path(__file__).resolve().parents[2] / "external" / "platform-agent_api" -if str(_api_root) not in sys.path: - sys.path.insert(0, str(_api_root)) - - def test_lambda_agent_api_module_is_importable(): - from lambda_agent_api.agent_api import AgentApi + from sdk.upstream_agent_api import AgentApi assert AgentApi is not None @@ -18,4 +12,4 @@ def test_lambda_agent_api_module_is_importable(): def test_agent_session_module_is_intentionally_stubbed(): contents = Path(__file__).resolve().parents[2] / "sdk" / "agent_session.py" - assert "replaced by AgentApiWrapper" in contents.read_text() + assert "replaced by direct AgentApi usage" in contents.read_text() diff --git a/tests/platform/test_real.py b/tests/platform/test_real.py index 38b19e3..7a2e37e 100644 --- a/tests/platform/test_real.py +++ b/tests/platform/test_real.py @@ -1,20 +1,20 @@ import asyncio import pytest -from lambda_agent_api.server import MsgEventSendFile, MsgEventTextChunk from pydantic import Field -import sdk.agent_api_wrapper as agent_api_wrapper_module from core.protocol import SettingsAction -from sdk.agent_api_wrapper import AgentApiWrapper from sdk.interface import Attachment, MessageChunk, MessageResponse, PlatformError, UserSettings from sdk.prototype_state import PrototypeStateStore from sdk.real import RealPlatformClient +from sdk.upstream_agent_api import MsgEventSendFile, MsgEventTextChunk class FakeChatAgentApi: - def __init__(self, chat_id: str) -> None: - self.chat_id = chat_id + def __init__(self, agent_id: str, base_url: str, chat_id: str) -> None: + self.agent_id = agent_id + self.base_url = base_url + self.chat_id = str(chat_id) self.calls: list[str] = [] self.connect_calls = 0 self.close_calls = 0 @@ -33,155 +33,125 @@ class FakeChatAgentApi: class FakeAgentApiFactory: - def __init__(self) -> None: - self.created_chat_ids: list[str] = [] - self.instances: dict[str, FakeChatAgentApi] = {} + def __init__(self, chat_api_cls=FakeChatAgentApi) -> None: + self.chat_api_cls = chat_api_cls + self.created_calls: list[tuple[str, str, str]] = [] + self.instances_by_chat: dict[str, list[FakeChatAgentApi]] = {} - def for_chat(self, chat_id: str) -> FakeChatAgentApi: - chat_api = FakeChatAgentApi(chat_id) - self.created_chat_ids.append(chat_id) - self.instances[chat_id] = chat_api + def __call__(self, agent_id: str, base_url: str, chat_id: str): + chat_key = str(chat_id) + chat_api = self.chat_api_cls(agent_id, base_url, chat_key) + self.created_calls.append((agent_id, base_url, chat_key)) + self.instances_by_chat.setdefault(chat_key, []).append(chat_api) return chat_api + def latest(self, chat_id: str): + return self.instances_by_chat[str(chat_id)][-1] -class BlockingChatAgentApi: - def __init__(self, chat_id: str) -> None: - self.chat_id = chat_id - self.calls: list[str] = [] - self.connect_calls = 0 - self.close_calls = 0 + +class BlockingTracker: + def __init__(self) -> None: self.active_calls = 0 self.max_active_calls = 0 self.started = asyncio.Event() self.release = asyncio.Event() - async def connect(self) -> None: - self.connect_calls += 1 - async def close(self) -> None: - self.close_calls += 1 +class BlockingChatAgentApi(FakeChatAgentApi): + def __init__( + self, + agent_id: str, + base_url: str, + chat_id: str, + *, + tracker: BlockingTracker, + ) -> None: + super().__init__(agent_id, base_url, chat_id) + self._tracker = tracker async def send_message(self, text: str, attachments: list[str] | None = None): self.calls.append(text) - self.active_calls += 1 - self.max_active_calls = max(self.max_active_calls, self.active_calls) - self.started.set() - await self.release.wait() - self.active_calls -= 1 + self._tracker.active_calls += 1 + self._tracker.max_active_calls = max( + self._tracker.max_active_calls, + self._tracker.active_calls, + ) + self._tracker.started.set() + await self._tracker.release.wait() + self._tracker.active_calls -= 1 yield MsgEventTextChunk(text=text) -class AttachmentTrackingChatAgentApi: - def __init__(self, chat_id: str) -> None: - self.chat_id = chat_id +class BlockingAgentApiFactory(FakeAgentApiFactory): + def __init__(self) -> None: + super().__init__() + self.tracker = BlockingTracker() + + def __call__(self, agent_id: str, base_url: str, chat_id: str): + chat_key = str(chat_id) + chat_api = BlockingChatAgentApi( + agent_id, + base_url, + chat_key, + tracker=self.tracker, + ) + self.created_calls.append((agent_id, base_url, chat_key)) + self.instances_by_chat.setdefault(chat_key, []).append(chat_api) + return chat_api + + +class AttachmentTrackingChatAgentApi(FakeChatAgentApi): + def __init__(self, agent_id: str, base_url: str, chat_id: str) -> None: + super().__init__(agent_id, base_url, chat_id) self.calls: list[tuple[str, list[str] | None]] = [] - self.connect_calls = 0 - self.close_calls = 0 - - async def connect(self) -> None: - self.connect_calls += 1 - - async def close(self) -> None: - self.close_calls += 1 async def send_message(self, text: str, attachments: list[str] | None = None): self.calls.append((text, attachments)) yield MsgEventTextChunk(text=text) -class AttachmentTrackingAgentApiFactory: - def __init__(self, chat_api_cls=AttachmentTrackingChatAgentApi) -> None: - self.chat_api_cls = chat_api_cls - self.created_chat_ids: list[str] = [] - self.instances: dict[str, AttachmentTrackingChatAgentApi] = {} - - def for_chat(self, chat_id: str) -> AttachmentTrackingChatAgentApi: - chat_api = self.chat_api_cls(chat_id) - self.created_chat_ids.append(chat_id) - self.instances[chat_id] = chat_api - return chat_api - - -class FlakyChatAgentApi: - def __init__(self, chat_id: str) -> None: - self.chat_id = chat_id - self.connect_calls = 0 - self.close_calls = 0 - - async def connect(self) -> None: - self.connect_calls += 1 - - async def close(self) -> None: - self.close_calls += 1 - +class FlakyChatAgentApi(FakeChatAgentApi): async def send_message(self, text: str, attachments: list[str] | None = None): raise ConnectionError("Connection closed") yield +class ReuseSensitiveChatAgentApi(FakeChatAgentApi): + def __init__(self, agent_id: str, base_url: str, chat_id: str) -> None: + super().__init__(agent_id, base_url, chat_id) + self._send_calls = 0 + + async def send_message(self, text: str, attachments: list[str] | None = None): + self.calls.append(text) + self._send_calls += 1 + if text == "first": + yield MsgEventTextChunk(text="tool ok") + return + if text == "second" and self._send_calls == 1: + yield MsgEventTextChunk(text="Missing") + + class MessageResponseWithAttachments(MessageResponse): attachments: list[Attachment] = Field(default_factory=list) -def test_agent_api_wrapper_normalizes_base_url_and_uses_modern_constructor(monkeypatch): - captured = {} - - def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs): - captured["agent_id"] = agent_id - captured["base_url"] = base_url - captured["chat_id"] = chat_id - - monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init) - - wrapper = AgentApiWrapper( - agent_id="agent-1", - base_url="ws://platform-agent:8000/v1/agent_ws/", - chat_id="41", +def make_real_platform_client( + agent_api_cls, + *, + prototype_state: PrototypeStateStore | None = None, +) -> RealPlatformClient: + return RealPlatformClient( + agent_id="matrix-bot", + agent_base_url="http://platform-agent:8000", + agent_api_cls=agent_api_cls, + prototype_state=prototype_state or PrototypeStateStore(), + platform="matrix", ) - assert wrapper.chat_id == "41" - assert wrapper._base_url == "ws://platform-agent:8000" - assert captured == { - "agent_id": "agent-1", - "base_url": "ws://platform-agent:8000", - "chat_id": "41", - } - - -def test_agent_api_wrapper_for_chat_reuses_normalized_base_url(monkeypatch): - init_calls = [] - - def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs): - self.id = agent_id - self.chat_id = chat_id - self.url = base_url - init_calls.append((agent_id, base_url, chat_id)) - - monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init) - - root = AgentApiWrapper( - agent_id="agent-1", - base_url="http://platform-agent:8000/v1/agent_ws/", - chat_id="1", - ) - - child = root.for_chat("99") - - assert child is not root - assert child.chat_id == "99" - assert child._base_url == "http://platform-agent:8000" - assert init_calls == [ - ("agent-1", "http://platform-agent:8000", "1"), - ("agent-1", "http://platform-agent:8000", "99"), - ] - @pytest.mark.asyncio async def test_real_platform_client_get_or_create_user_uses_local_state(): - client = RealPlatformClient( - agent_api=FakeAgentApiFactory(), - prototype_state=PrototypeStateStore(), - ) + client = make_real_platform_client(FakeAgentApiFactory()) first = await client.get_or_create_user("u1", "matrix", "Alice") second = await client.get_or_create_user("u1", "matrix") @@ -194,14 +164,10 @@ async def test_real_platform_client_get_or_create_user_uses_local_state(): @pytest.mark.asyncio -async def test_real_platform_client_send_message_uses_chat_bound_client(): +async def test_real_platform_client_send_message_uses_direct_agent_api_per_chat(): agent_api = FakeAgentApiFactory() prototype_state = PrototypeStateStore() - client = RealPlatformClient( - agent_api=agent_api, - prototype_state=prototype_state, - platform="matrix", - ) + client = make_real_platform_client(agent_api, prototype_state=prototype_state) result = await client.send_message("@alice:example.org", "chat-7", "hello") @@ -211,21 +177,18 @@ async def test_real_platform_client_send_message_uses_chat_bound_client(): tokens_used=0, finished=True, ) - assert agent_api.created_chat_ids == ["chat-7"] - assert agent_api.instances["chat-7"].chat_id == "chat-7" - assert agent_api.instances["chat-7"].calls == ["hello"] - assert agent_api.instances["chat-7"].connect_calls == 1 + assert agent_api.created_calls == [("matrix-bot", "http://platform-agent:8000", "chat-7")] + assert agent_api.latest("chat-7").chat_id == "chat-7" + assert agent_api.latest("chat-7").calls == ["hello"] + assert agent_api.latest("chat-7").connect_calls == 1 + assert agent_api.latest("chat-7").close_calls == 1 assert await prototype_state.get_last_tokens_used_for_context("chat-7") == 0 @pytest.mark.asyncio async def test_real_platform_client_forwards_attachments_to_chat_api(): - agent_api = AttachmentTrackingAgentApiFactory() - client = RealPlatformClient( - agent_api=agent_api, - prototype_state=PrototypeStateStore(), - platform="matrix", - ) + agent_api = FakeAgentApiFactory(chat_api_cls=AttachmentTrackingChatAgentApi) + client = make_real_platform_client(agent_api) attachment = Attachment( workspace_path="surfaces/matrix/alice/room/inbox/report.pdf", mime_type="application/pdf", @@ -240,7 +203,7 @@ async def test_real_platform_client_forwards_attachments_to_chat_api(): attachments=[attachment], ) - assert agent_api.instances["chat-7"].calls == [ + assert agent_api.latest("chat-7").calls == [ ("hello", ["surfaces/matrix/alice/room/inbox/report.pdf"]) ] assert result.response == "hello" @@ -256,17 +219,10 @@ async def test_real_platform_client_preserves_send_file_events_in_sync_result(mo yield MsgEventSendFile(path="report.pdf") yield MsgEventTextChunk(text="llo") - agent_api = AttachmentTrackingAgentApiFactory(chat_api_cls=FileEventAgentApi) - client = RealPlatformClient( - agent_api=agent_api, - prototype_state=PrototypeStateStore(), - platform="matrix", - ) + agent_api = FakeAgentApiFactory(chat_api_cls=FileEventAgentApi) + client = make_real_platform_client(agent_api) - monkeypatch.setattr( - "sdk.real.MessageResponse", - MessageResponseWithAttachments, - ) + monkeypatch.setattr("sdk.real.MessageResponse", MessageResponseWithAttachments) result = await client.send_message("@alice:example.org", "chat-7", "hello") @@ -284,63 +240,61 @@ async def test_real_platform_client_preserves_send_file_events_in_sync_result(mo @pytest.mark.asyncio -async def test_real_platform_client_reuses_cached_chat_client(): +async def test_real_platform_client_uses_fresh_agent_connection_per_request(): agent_api = FakeAgentApiFactory() - client = RealPlatformClient( - agent_api=agent_api, - prototype_state=PrototypeStateStore(), - platform="matrix", - ) + client = make_real_platform_client(agent_api) await client.send_message("@alice:example.org", "chat-1", "hello") await client.send_message("@alice:example.org", "chat-1", "again") - assert agent_api.created_chat_ids == ["chat-1"] - assert agent_api.instances["chat-1"].calls == ["hello", "again"] - assert agent_api.instances["chat-1"].connect_calls == 1 - assert agent_api.instances["chat-1"].close_calls == 0 + assert agent_api.created_calls == [ + ("matrix-bot", "http://platform-agent:8000", "chat-1"), + ("matrix-bot", "http://platform-agent:8000", "chat-1"), + ] + assert [instance.calls for instance in agent_api.instances_by_chat["chat-1"]] == [ + ["hello"], + ["again"], + ] + assert all(instance.connect_calls == 1 for instance in agent_api.instances_by_chat["chat-1"]) + assert all(instance.close_calls == 1 for instance in agent_api.instances_by_chat["chat-1"]) + + +@pytest.mark.asyncio +async def test_real_platform_client_avoids_reuse_sensitive_second_message_loss(): + agent_api = FakeAgentApiFactory(chat_api_cls=ReuseSensitiveChatAgentApi) + client = make_real_platform_client(agent_api) + + first = await client.send_message("@alice:example.org", "chat-1", "first") + second = await client.send_message("@alice:example.org", "chat-1", "second") + + assert first.response == "tool ok" + assert second.response == "Missing" + assert len(agent_api.instances_by_chat["chat-1"]) == 2 @pytest.mark.asyncio async def test_real_platform_client_wraps_connection_closed_as_platform_error(): - agent_api = FakeAgentApiFactory() - agent_api.instances["chat-1"] = FlakyChatAgentApi("chat-1") - agent_api.for_chat = lambda chat_id: agent_api.instances.setdefault( - chat_id, FlakyChatAgentApi(chat_id) - ) - client = RealPlatformClient( - agent_api=agent_api, - prototype_state=PrototypeStateStore(), - platform="matrix", - ) + agent_api = FakeAgentApiFactory(chat_api_cls=FlakyChatAgentApi) + client = make_real_platform_client(agent_api) with pytest.raises(PlatformError, match="Connection closed") as exc_info: await client.send_message("@alice:example.org", "chat-1", "hello") assert exc_info.value.code == "PLATFORM_CONNECTION_ERROR" - assert "chat-1" not in client._chat_apis - assert agent_api.instances["chat-1"].close_calls == 1 + assert agent_api.latest("chat-1").close_calls == 1 @pytest.mark.asyncio -async def test_real_platform_client_reconnects_after_closed_chat_api(): - agent_api = FakeAgentApiFactory() - flaky = FlakyChatAgentApi("chat-1") - healthy = AttachmentTrackingChatAgentApi("chat-1") - provided = iter([flaky, healthy]) +async def test_real_platform_client_uses_fresh_connection_after_failure(): + class SometimesFlakyAgentApi(FakeChatAgentApi): + async def send_message(self, text: str, attachments: list[str] | None = None): + if text == "hello": + raise ConnectionError("Connection closed") + self.calls.append(text) + yield MsgEventTextChunk(text=text) - def for_chat(chat_id: str): - chat_api = next(provided) - agent_api.created_chat_ids.append(chat_id) - agent_api.instances[chat_id] = chat_api - return chat_api - - agent_api.for_chat = for_chat - client = RealPlatformClient( - agent_api=agent_api, - prototype_state=PrototypeStateStore(), - platform="matrix", - ) + agent_api = FakeAgentApiFactory(chat_api_cls=SometimesFlakyAgentApi) + client = make_real_platform_client(agent_api) with pytest.raises(PlatformError, match="Connection closed"): await client.send_message("@alice:example.org", "chat-1", "hello") @@ -348,60 +302,17 @@ async def test_real_platform_client_reconnects_after_closed_chat_api(): result = await client.send_message("@alice:example.org", "chat-1", "again") assert result.response == "again" - assert agent_api.created_chat_ids == ["chat-1", "chat-1"] - assert healthy.calls == [("again", None)] - - -@pytest.mark.asyncio -async def test_real_platform_client_creates_chat_client_atomically_for_concurrent_requests(): - agent_api = FakeAgentApiFactory() - client = RealPlatformClient( - agent_api=agent_api, - prototype_state=PrototypeStateStore(), - platform="matrix", - ) - - results = await asyncio.gather( - client.send_message("@alice:example.org", "chat-1", "hello"), - client.send_message("@alice:example.org", "chat-1", "again"), - ) - - assert [result.response for result in results] == ["hello", "again"] - assert agent_api.created_chat_ids == ["chat-1"] - assert agent_api.instances["chat-1"].connect_calls == 1 - assert agent_api.instances["chat-1"].calls == ["hello", "again"] - - -@pytest.mark.asyncio -async def test_real_platform_client_creates_distinct_clients_per_chat(): - agent_api = FakeAgentApiFactory() - client = RealPlatformClient( - agent_api=agent_api, - prototype_state=PrototypeStateStore(), - platform="matrix", - ) - - await client.send_message("@alice:example.org", "chat-1", "hello") - await client.send_message("@alice:example.org", "chat-2", "world") - - assert agent_api.created_chat_ids == ["chat-1", "chat-2"] - assert agent_api.instances["chat-1"] is not agent_api.instances["chat-2"] - assert agent_api.instances["chat-1"].calls == ["hello"] - assert agent_api.instances["chat-2"].calls == ["world"] + assert agent_api.created_calls == [ + ("matrix-bot", "http://platform-agent:8000", "chat-1"), + ("matrix-bot", "http://platform-agent:8000", "chat-1"), + ] + assert agent_api.latest("chat-1").calls == ["again"] @pytest.mark.asyncio async def test_real_platform_client_serializes_same_chat_streams_across_send_paths(): - agent_api = FakeAgentApiFactory() - agent_api.instances["chat-1"] = BlockingChatAgentApi("chat-1") - agent_api.for_chat = lambda chat_id: agent_api.instances.setdefault( - chat_id, BlockingChatAgentApi(chat_id) - ) - client = RealPlatformClient( - agent_api=agent_api, - prototype_state=PrototypeStateStore(), - platform="matrix", - ) + agent_api = BlockingAgentApiFactory() + client = make_real_platform_client(agent_api) async def consume_stream(): chunks = [] @@ -410,32 +321,48 @@ async def test_real_platform_client_serializes_same_chat_streams_across_send_pat return chunks stream_task = asyncio.create_task(consume_stream()) - await asyncio.wait_for(agent_api.instances["chat-1"].started.wait(), timeout=1) + await asyncio.wait_for(agent_api.tracker.started.wait(), timeout=1) send_task = asyncio.create_task(client.send_message("@alice:example.org", "chat-1", "again")) await asyncio.sleep(0) - assert agent_api.instances["chat-1"].calls == ["hello"] - assert agent_api.instances["chat-1"].max_active_calls == 1 + assert len(agent_api.instances_by_chat["chat-1"]) == 1 + assert agent_api.instances_by_chat["chat-1"][0].calls == ["hello"] + assert agent_api.tracker.max_active_calls == 1 - agent_api.instances["chat-1"].release.set() + agent_api.tracker.release.set() stream_chunks = await stream_task send_result = await send_task assert [chunk.delta for chunk in stream_chunks] == ["hello", ""] assert send_result.response == "again" - assert agent_api.instances["chat-1"].calls == ["hello", "again"] - assert agent_api.instances["chat-1"].max_active_calls == 1 + assert [instance.calls for instance in agent_api.instances_by_chat["chat-1"]] == [ + ["hello"], + ["again"], + ] + assert agent_api.tracker.max_active_calls == 1 + + +@pytest.mark.asyncio +async def test_real_platform_client_creates_distinct_connections_per_chat(): + agent_api = FakeAgentApiFactory() + client = make_real_platform_client(agent_api) + + await client.send_message("@alice:example.org", "chat-1", "hello") + await client.send_message("@alice:example.org", "chat-2", "world") + + assert agent_api.created_calls == [ + ("matrix-bot", "http://platform-agent:8000", "chat-1"), + ("matrix-bot", "http://platform-agent:8000", "chat-2"), + ] + assert agent_api.latest("chat-1").calls == ["hello"] + assert agent_api.latest("chat-2").calls == ["world"] @pytest.mark.asyncio async def test_real_platform_client_stream_message_emits_final_tokens_chunk(): agent_api = FakeAgentApiFactory() - client = RealPlatformClient( - agent_api=agent_api, - prototype_state=PrototypeStateStore(), - platform="matrix", - ) + client = make_real_platform_client(agent_api) chunks = [] async for chunk in client.stream_message("@alice:example.org", "chat-1", "hello"): @@ -461,17 +388,14 @@ async def test_real_platform_client_stream_message_emits_final_tokens_chunk(): tokens_used=0, ), ] - assert agent_api.created_chat_ids == ["chat-1"] - assert agent_api.instances["chat-1"].calls == ["hello"] + assert agent_api.created_calls == [("matrix-bot", "http://platform-agent:8000", "chat-1")] + assert agent_api.latest("chat-1").calls == ["hello"] + assert agent_api.latest("chat-1").close_calls == 1 @pytest.mark.asyncio async def test_real_platform_client_settings_are_local(): - client = RealPlatformClient( - agent_api=FakeAgentApiFactory(), - prototype_state=PrototypeStateStore(), - platform="matrix", - ) + client = make_real_platform_client(FakeAgentApiFactory()) await client.update_settings( "usr-matrix-u1",