fix: use direct agent api per request
This commit is contained in:
parent
7d270d3d31
commit
7d58dd1caf
14 changed files with 285 additions and 400 deletions
|
|
@ -11,7 +11,6 @@ MATRIX_PLATFORM_BACKEND=real
|
|||
SURFACES_WORKSPACE_DIR=/workspace
|
||||
|
||||
# Compose-local platform-agent route
|
||||
AGENT_WS_URL=ws://platform-agent:8000/v1/agent_ws/
|
||||
AGENT_BASE_URL=http://platform-agent:8000
|
||||
|
||||
# platform-agent provider
|
||||
|
|
|
|||
|
|
@ -69,8 +69,8 @@ surfaces-bot/
|
|||
- **Диалог** — сообщения, вложения, подтверждения `!yes` / `!no` и routing через `EventDispatcher`
|
||||
- **Стабильность** — перед `sync_forever()` бот делает bootstrap sync и стартует с `since`, чтобы не переигрывать старую timeline после рестарта
|
||||
- **Текущее ограничение** — encrypted DM официально не поддержан; ручное тестирование Matrix ведётся в незашифрованных комнатах и зависит от локального state-store бота
|
||||
- **Backend selection** — `MATRIX_PLATFORM_BACKEND=mock` остаётся значением по умолчанию; `MATRIX_PLATFORM_BACKEND=real` использует `platform-agent` из compose и WebSocket contract `/v1/agent_ws/{chat_id}/`
|
||||
- **Ограничения real backend** — локальный runtime использует shared `/workspace`, файлы передаются как относительные пути в `attachments`, а transport layer со стороны `surfaces` использует pinned upstream `platform-agent_api.AgentApi` почти без локальной stream-логики; текущая реализация рабочая, но после tool/file flow остаётся подтверждённый upstream streaming bug, из-за которого начало ответа может пропадать
|
||||
- **Backend selection** — `MATRIX_PLATFORM_BACKEND=mock` остаётся значением по умолчанию; `MATRIX_PLATFORM_BACKEND=real` использует `platform-agent` из compose и upstream `AgentApi` по contract `/v1/agent_ws/{chat_id}/`
|
||||
- **Ограничения real backend** — локальный runtime использует shared `/workspace`, файлы передаются как относительные пути в `attachments`, а transport layer со стороны `surfaces` использует прямой upstream `platform-agent_api.AgentApi` без локального subclass; prod-default lifecycle открывает отдельное соединение на каждый запрос, но после tool/file flow всё ещё остаётся подтверждённый upstream streaming bug, из-за которого начало ответа может пропадать
|
||||
|
||||
---
|
||||
|
||||
|
|
@ -122,9 +122,6 @@ MATRIX_PASSWORD=... # или MATRIX_ACCESS_TOKEN=...
|
|||
MATRIX_PLATFORM_BACKEND=real
|
||||
|
||||
# compose runtime: platform-agent service name + shared /workspace
|
||||
# значение передаётся в thin wrapper как base URL; wrapper сам нормализует его
|
||||
# до upstream WS route /v1/agent_ws/{chat_id}/
|
||||
AGENT_WS_URL=ws://platform-agent:8000/v1/agent_ws/
|
||||
AGENT_BASE_URL=http://platform-agent:8000
|
||||
SURFACES_WORKSPACE_DIR=/workspace
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlsplit, urlunsplit
|
||||
|
||||
import structlog
|
||||
from dotenv import load_dotenv
|
||||
|
|
@ -63,7 +65,6 @@ from core.protocol import (
|
|||
)
|
||||
from core.settings import SettingsManager
|
||||
from core.store import InMemoryStore, SQLiteStore, StateStore
|
||||
from sdk.agent_api_wrapper import AgentApiWrapper
|
||||
from sdk.interface import PlatformClient, PlatformError
|
||||
from sdk.mock import MockPlatformClient
|
||||
from sdk.prototype_state import PrototypeStateStore
|
||||
|
|
@ -89,8 +90,7 @@ def build_event_dispatcher(platform: PlatformClient, store: StateStore) -> Event
|
|||
auth_mgr = AuthManager(platform, store)
|
||||
settings_mgr = SettingsManager(platform, store)
|
||||
prototype_state = getattr(platform, "_prototype_state", None)
|
||||
agent_api = getattr(platform, "_agent_api", None)
|
||||
agent_base_url = os.environ.get("AGENT_BASE_URL", "http://127.0.0.1:8000")
|
||||
agent_base_url = _agent_base_url_from_env()
|
||||
dispatcher = EventDispatcher(
|
||||
platform=platform, chat_mgr=chat_mgr, auth_mgr=auth_mgr, settings_mgr=settings_mgr
|
||||
)
|
||||
|
|
@ -98,19 +98,32 @@ def build_event_dispatcher(platform: PlatformClient, store: StateStore) -> Event
|
|||
register_matrix_handlers(
|
||||
dispatcher,
|
||||
store=store,
|
||||
agent_api=agent_api,
|
||||
prototype_state=prototype_state,
|
||||
agent_base_url=agent_base_url,
|
||||
)
|
||||
return dispatcher
|
||||
|
||||
|
||||
def _normalize_agent_base_url(url: str) -> str:
|
||||
parsed = urlsplit(url)
|
||||
path = re.sub(r"(?:/v1)?/agent_ws(?:/[^/]+)?/?$", "", parsed.path.rstrip("/"))
|
||||
return urlunsplit((parsed.scheme, parsed.netloc, path, "", ""))
|
||||
|
||||
|
||||
def _agent_base_url_from_env() -> str:
|
||||
if base_url := os.environ.get("AGENT_BASE_URL"):
|
||||
return base_url
|
||||
if ws_url := os.environ.get("AGENT_WS_URL"):
|
||||
return _normalize_agent_base_url(ws_url)
|
||||
return "http://127.0.0.1:8000"
|
||||
|
||||
|
||||
def _build_platform_from_env() -> PlatformClient:
|
||||
backend = os.environ.get("MATRIX_PLATFORM_BACKEND", "mock").strip().lower()
|
||||
if backend == "real":
|
||||
ws_url = os.environ["AGENT_WS_URL"]
|
||||
return RealPlatformClient(
|
||||
agent_api=AgentApiWrapper(agent_id="matrix-bot", base_url=ws_url),
|
||||
agent_id="matrix-bot",
|
||||
agent_base_url=_agent_base_url_from_env(),
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
|
|
@ -128,8 +141,7 @@ def build_runtime(
|
|||
auth_mgr = AuthManager(platform, store)
|
||||
settings_mgr = SettingsManager(platform, store)
|
||||
prototype_state = getattr(platform, "_prototype_state", None)
|
||||
agent_api = getattr(platform, "_agent_api", None)
|
||||
agent_base_url = os.environ.get("AGENT_BASE_URL", "http://127.0.0.1:8000")
|
||||
agent_base_url = _agent_base_url_from_env()
|
||||
dispatcher = EventDispatcher(
|
||||
platform=platform, chat_mgr=chat_mgr, auth_mgr=auth_mgr, settings_mgr=settings_mgr
|
||||
)
|
||||
|
|
@ -138,7 +150,6 @@ def build_runtime(
|
|||
dispatcher,
|
||||
client=client,
|
||||
store=store,
|
||||
agent_api=agent_api,
|
||||
prototype_state=prototype_state,
|
||||
agent_base_url=agent_base_url,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -34,7 +34,6 @@ def register_matrix_handlers(
|
|||
dispatcher: EventDispatcher,
|
||||
client=None,
|
||||
store=None,
|
||||
agent_api=None,
|
||||
prototype_state=None,
|
||||
agent_base_url: str = "http://127.0.0.1:8000",
|
||||
) -> None:
|
||||
|
|
@ -64,11 +63,11 @@ def register_matrix_handlers(
|
|||
dispatcher.register(IncomingCallback, "toggle_skill", handle_toggle_skill)
|
||||
dispatcher.register(IncomingCommand, "*", handle_unknown_command)
|
||||
|
||||
if agent_api is not None and prototype_state is not None:
|
||||
if prototype_state is not None:
|
||||
dispatcher.register(
|
||||
IncomingCommand,
|
||||
"save",
|
||||
make_handle_save(agent_api, store, prototype_state),
|
||||
make_handle_save(None, store, prototype_state),
|
||||
)
|
||||
dispatcher.register(IncomingCommand, "load", make_handle_load(store, prototype_state))
|
||||
dispatcher.register(IncomingCommand, "context", make_handle_context(store, prototype_state))
|
||||
|
|
|
|||
|
|
@ -27,7 +27,6 @@ services:
|
|||
env_file: .env
|
||||
environment:
|
||||
AGENT_BASE_URL: http://platform-agent:8000
|
||||
AGENT_WS_URL: ws://platform-agent:8000/v1/agent_ws/
|
||||
SURFACES_WORKSPACE_DIR: /workspace
|
||||
depends_on:
|
||||
- platform-agent
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@
|
|||
- переключение Matrix backend через env:
|
||||
- `MATRIX_PLATFORM_BACKEND=mock`
|
||||
- `MATRIX_PLATFORM_BACKEND=real`
|
||||
- прямую отправку текста в live agent через `AGENT_WS_URL`
|
||||
- прямую отправку текста в live agent через `AGENT_BASE_URL`
|
||||
- локальное хранение settings и user mapping
|
||||
- изоляцию backend memory по `thread_id`
|
||||
- исправление повторных invite: бот теперь сначала `join()`, а уже потом решает, нужно ли пере-провиженить Space/chat tree
|
||||
|
|
@ -154,7 +154,7 @@ ws://127.0.0.1:8000/agent_ws/
|
|||
cd /Users/a/MAI/sem2/lambda/surfaces-bot
|
||||
|
||||
export MATRIX_PLATFORM_BACKEND=real
|
||||
export AGENT_WS_URL=ws://127.0.0.1:8000/agent_ws/
|
||||
export AGENT_BASE_URL=http://127.0.0.1:8000
|
||||
export MATRIX_HOMESERVER=https://matrix.lambda.coredump.ru
|
||||
export MATRIX_USER_ID=@lambda_surface_test_bot:matrix.lambda.coredump.ru
|
||||
export MATRIX_PASSWORD='YOUR_PASSWORD'
|
||||
|
|
@ -193,7 +193,7 @@ uv run uvicorn src.main:app --host 0.0.0.0 --port 8000
|
|||
cd /Users/a/MAI/sem2/lambda/surfaces-bot
|
||||
|
||||
export MATRIX_PLATFORM_BACKEND=real
|
||||
export AGENT_WS_URL=ws://127.0.0.1:8000/agent_ws/
|
||||
export AGENT_BASE_URL=http://127.0.0.1:8000
|
||||
export MATRIX_HOMESERVER=https://matrix.lambda.coredump.ru
|
||||
export MATRIX_USER_ID=@lambda_surface_test_bot:matrix.lambda.coredump.ru
|
||||
export MATRIX_PASSWORD='YOUR_PASSWORD'
|
||||
|
|
|
|||
|
|
@ -1,48 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlsplit, urlunsplit
|
||||
|
||||
_api_root = Path(__file__).resolve().parents[1] / "external" / "platform-agent_api"
|
||||
if str(_api_root) not in sys.path:
|
||||
sys.path.insert(0, str(_api_root))
|
||||
|
||||
from lambda_agent_api.agent_api import AgentApi # noqa: E402
|
||||
|
||||
|
||||
class AgentApiWrapper(AgentApi):
|
||||
"""Thin construction/factory shim over the pinned upstream AgentApi."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_id: str,
|
||||
base_url: str,
|
||||
*,
|
||||
chat_id: int | str = 0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self._base_url = self._normalize_base_url(base_url)
|
||||
self._init_kwargs = dict(kwargs)
|
||||
self.chat_id = chat_id
|
||||
super().__init__(
|
||||
agent_id=agent_id,
|
||||
base_url=self._base_url,
|
||||
chat_id=chat_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_base_url(base_url: str) -> str:
|
||||
parsed = urlsplit(base_url)
|
||||
path = re.sub(r"(?:/v1)?/agent_ws(?:/[^/]+)?/?$", "", parsed.path.rstrip("/"))
|
||||
return urlunsplit((parsed.scheme, parsed.netloc, path, "", ""))
|
||||
|
||||
def for_chat(self, chat_id: int | str) -> AgentApiWrapper:
|
||||
return type(self)(
|
||||
agent_id=self.id,
|
||||
base_url=self._base_url,
|
||||
chat_id=chat_id,
|
||||
**self._init_kwargs,
|
||||
)
|
||||
|
|
@ -1 +1 @@
|
|||
"""Compatibility stub: AgentSessionClient was replaced by AgentApiWrapper in Phase 4."""
|
||||
"""Compatibility stub: AgentSessionClient was replaced by direct AgentApi usage in Phase 4."""
|
||||
|
|
|
|||
83
sdk/real.py
83
sdk/real.py
|
|
@ -4,9 +4,6 @@ import asyncio
|
|||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
|
||||
from lambda_agent_api.server import MsgEventSendFile, MsgEventTextChunk
|
||||
|
||||
from sdk.agent_api_wrapper import AgentApiWrapper
|
||||
from sdk.interface import (
|
||||
Attachment,
|
||||
MessageChunk,
|
||||
|
|
@ -17,37 +14,32 @@ from sdk.interface import (
|
|||
UserSettings,
|
||||
)
|
||||
from sdk.prototype_state import PrototypeStateStore
|
||||
from sdk.upstream_agent_api import AgentApi, MsgEventSendFile, MsgEventTextChunk
|
||||
|
||||
|
||||
class RealPlatformClient(PlatformClient):
|
||||
def __init__(
|
||||
self,
|
||||
agent_api: AgentApiWrapper,
|
||||
agent_id: str,
|
||||
agent_base_url: str,
|
||||
prototype_state: PrototypeStateStore,
|
||||
platform: str = "matrix",
|
||||
agent_api_cls=AgentApi,
|
||||
) -> None:
|
||||
self._agent_api = agent_api
|
||||
self._agent_id = agent_id
|
||||
self._agent_base_url = agent_base_url
|
||||
self._agent_api_cls = agent_api_cls
|
||||
self._prototype_state = prototype_state
|
||||
self._platform = platform
|
||||
self._chat_apis: dict[str, AgentApiWrapper] = {}
|
||||
self._chat_api_lock = asyncio.Lock()
|
||||
self._chat_send_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
@property
|
||||
def agent_api(self) -> AgentApiWrapper:
|
||||
return self._agent_api
|
||||
def agent_id(self) -> str:
|
||||
return self._agent_id
|
||||
|
||||
async def _get_chat_api(self, chat_id: str):
|
||||
chat_key = str(chat_id)
|
||||
chat_api = self._chat_apis.get(chat_key)
|
||||
if chat_api is None:
|
||||
async with self._chat_api_lock:
|
||||
chat_api = self._chat_apis.get(chat_key)
|
||||
if chat_api is None:
|
||||
chat_api = self._agent_api.for_chat(chat_key)
|
||||
await chat_api.connect()
|
||||
self._chat_apis[chat_key] = chat_api
|
||||
return chat_api
|
||||
@property
|
||||
def agent_base_url(self) -> str:
|
||||
return self._agent_base_url
|
||||
|
||||
def _get_chat_send_lock(self, chat_id: str) -> asyncio.Lock:
|
||||
chat_key = str(chat_id)
|
||||
|
|
@ -82,9 +74,9 @@ class RealPlatformClient(PlatformClient):
|
|||
|
||||
lock = self._get_chat_send_lock(chat_id)
|
||||
async with lock:
|
||||
chat_api = await self._get_chat_api(chat_id)
|
||||
|
||||
chat_api = self._build_chat_api(chat_id)
|
||||
try:
|
||||
await chat_api.connect()
|
||||
async for event in self._stream_agent_events(
|
||||
chat_api, text, attachments=attachments
|
||||
):
|
||||
|
|
@ -96,8 +88,9 @@ class RealPlatformClient(PlatformClient):
|
|||
if attachment is not None:
|
||||
sent_attachments.append(attachment)
|
||||
except Exception as exc:
|
||||
await self._handle_chat_api_failure(chat_id, exc)
|
||||
|
||||
raise self._to_platform_error(exc) from exc
|
||||
finally:
|
||||
await self._close_chat_api(chat_api)
|
||||
await self._prototype_state.set_last_tokens_used(str(chat_id), 0)
|
||||
|
||||
response_kwargs = {
|
||||
|
|
@ -118,8 +111,9 @@ class RealPlatformClient(PlatformClient):
|
|||
) -> AsyncIterator[MessageChunk]:
|
||||
lock = self._get_chat_send_lock(chat_id)
|
||||
async with lock:
|
||||
chat_api = await self._get_chat_api(chat_id)
|
||||
chat_api = self._build_chat_api(chat_id)
|
||||
try:
|
||||
await chat_api.connect()
|
||||
async for event in self._stream_agent_events(
|
||||
chat_api, text, attachments=attachments
|
||||
):
|
||||
|
|
@ -132,7 +126,9 @@ class RealPlatformClient(PlatformClient):
|
|||
elif isinstance(event, MsgEventSendFile):
|
||||
continue
|
||||
except Exception as exc:
|
||||
await self._handle_chat_api_failure(chat_id, exc)
|
||||
raise self._to_platform_error(exc) from exc
|
||||
finally:
|
||||
await self._close_chat_api(chat_api)
|
||||
await self._prototype_state.set_last_tokens_used(str(chat_id), 0)
|
||||
yield MessageChunk(
|
||||
message_id=user_id,
|
||||
|
|
@ -148,20 +144,9 @@ class RealPlatformClient(PlatformClient):
|
|||
await self._prototype_state.update_settings(user_id, action)
|
||||
|
||||
async def disconnect_chat(self, chat_id: str) -> None:
|
||||
chat_key = str(chat_id)
|
||||
chat_api = self._chat_apis.pop(chat_key, None)
|
||||
self._chat_send_locks.pop(chat_key, None)
|
||||
if chat_api is not None:
|
||||
close = getattr(chat_api, "close", None)
|
||||
if callable(close):
|
||||
await close()
|
||||
self._chat_send_locks.pop(str(chat_id), None)
|
||||
|
||||
async def close(self) -> None:
|
||||
for chat_api in list(self._chat_apis.values()):
|
||||
close = getattr(chat_api, "close", None)
|
||||
if callable(close):
|
||||
await close()
|
||||
self._chat_apis.clear()
|
||||
self._chat_send_locks.clear()
|
||||
|
||||
async def _stream_agent_events(
|
||||
|
|
@ -175,10 +160,26 @@ class RealPlatformClient(PlatformClient):
|
|||
async for event in event_stream:
|
||||
yield event
|
||||
|
||||
async def _handle_chat_api_failure(self, chat_id: str, exc: Exception) -> None:
|
||||
await self.disconnect_chat(chat_id)
|
||||
def _build_chat_api(self, chat_id: str):
|
||||
return self._agent_api_cls(
|
||||
agent_id=self._agent_id,
|
||||
base_url=self._agent_base_url,
|
||||
chat_id=str(chat_id),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _close_chat_api(chat_api) -> None:
|
||||
close = getattr(chat_api, "close", None)
|
||||
if callable(close):
|
||||
try:
|
||||
await close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _to_platform_error(exc: Exception) -> PlatformError:
|
||||
code = getattr(exc, "code", None) or "PLATFORM_CONNECTION_ERROR"
|
||||
raise PlatformError(str(exc), code=code) from exc
|
||||
return PlatformError(str(exc), code=code)
|
||||
|
||||
@staticmethod
|
||||
def _attachment_paths(attachments: list[Attachment] | None) -> list[str]:
|
||||
|
|
|
|||
19
sdk/upstream_agent_api.py
Normal file
19
sdk/upstream_agent_api.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
_api_root = Path(__file__).resolve().parents[1] / "external" / "platform-agent_api"
|
||||
if str(_api_root) not in sys.path:
|
||||
sys.path.insert(0, str(_api_root))
|
||||
|
||||
from lambda_agent_api.agent_api import AgentApi, AgentBusyException, AgentException # noqa: E402
|
||||
from lambda_agent_api.server import MsgEventSendFile, MsgEventTextChunk # noqa: E402
|
||||
|
||||
__all__ = [
|
||||
"AgentApi",
|
||||
"AgentBusyException",
|
||||
"AgentException",
|
||||
"MsgEventSendFile",
|
||||
"MsgEventTextChunk",
|
||||
]
|
||||
|
|
@ -908,34 +908,21 @@ async def test_prepare_live_sync_returns_next_batch_from_bootstrap_sync():
|
|||
|
||||
|
||||
async def test_build_runtime_uses_real_platform_when_matrix_backend_is_real(monkeypatch):
|
||||
bot_module = importlib.import_module("adapter.matrix.bot")
|
||||
|
||||
class FakeAgentApiWrapper:
|
||||
def __init__(self, agent_id: str, base_url: str) -> None:
|
||||
self.agent_id = agent_id
|
||||
self.base_url = base_url
|
||||
|
||||
monkeypatch.setattr(bot_module, "AgentApiWrapper", FakeAgentApiWrapper)
|
||||
monkeypatch.setenv("MATRIX_PLATFORM_BACKEND", "real")
|
||||
monkeypatch.setenv("AGENT_WS_URL", "ws://agent.example/agent_ws/")
|
||||
monkeypatch.setenv("AGENT_BASE_URL", "http://agent.example")
|
||||
|
||||
runtime = build_runtime()
|
||||
|
||||
assert isinstance(runtime.platform, RealPlatformClient)
|
||||
assert runtime.platform.agent_api.base_url == "ws://agent.example/agent_ws/"
|
||||
assert runtime.platform.agent_base_url == "http://agent.example"
|
||||
assert runtime.platform.agent_id == "matrix-bot"
|
||||
|
||||
|
||||
async def test_matrix_main_closes_platform_without_connecting_root_agent(monkeypatch):
|
||||
bot_module = importlib.import_module("adapter.matrix.bot")
|
||||
|
||||
platform_close = AsyncMock()
|
||||
agent_connect = AsyncMock()
|
||||
runtime = SimpleNamespace(
|
||||
platform=SimpleNamespace(
|
||||
close=platform_close,
|
||||
agent_api=SimpleNamespace(connect=agent_connect),
|
||||
)
|
||||
)
|
||||
runtime = SimpleNamespace(platform=SimpleNamespace(close=platform_close))
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
|
@ -959,7 +946,6 @@ async def test_matrix_main_closes_platform_without_connecting_root_agent(monkeyp
|
|||
|
||||
await bot_module.main()
|
||||
|
||||
agent_connect.assert_not_awaited()
|
||||
platform_close.assert_awaited_once()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ Smoke test: полный цикл через dispatcher + реальные manag
|
|||
Имитирует что делает адаптер (Telegram или Matrix) при получении события.
|
||||
"""
|
||||
import pytest
|
||||
from lambda_agent_api.server import MsgEventTextChunk
|
||||
|
||||
from core.auth import AuthManager
|
||||
from core.chat import ChatManager
|
||||
|
|
@ -23,10 +22,13 @@ from core.store import InMemoryStore
|
|||
from sdk.mock import MockPlatformClient
|
||||
from sdk.prototype_state import PrototypeStateStore
|
||||
from sdk.real import RealPlatformClient
|
||||
from sdk.upstream_agent_api import MsgEventTextChunk
|
||||
|
||||
|
||||
class FakeAgentApi:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
def __init__(self, agent_id: str, base_url: str, chat_id: str) -> None:
|
||||
self.agent_id = agent_id
|
||||
self.base_url = base_url
|
||||
self.chat_id = chat_id
|
||||
self.calls: list[tuple[str, list[str]]] = []
|
||||
self.connect_calls = 0
|
||||
|
|
@ -46,12 +48,12 @@ class FakeAgentApi:
|
|||
class FakeAgentApiFactory:
|
||||
def __init__(self) -> None:
|
||||
self.created_chat_ids: list[str] = []
|
||||
self.instances: dict[str, FakeAgentApi] = {}
|
||||
self.instances: dict[str, list[FakeAgentApi]] = {}
|
||||
|
||||
def for_chat(self, chat_id: str) -> FakeAgentApi:
|
||||
chat_api = FakeAgentApi(chat_id)
|
||||
def __call__(self, agent_id: str, base_url: str, chat_id: str) -> FakeAgentApi:
|
||||
chat_api = FakeAgentApi(agent_id, base_url, chat_id)
|
||||
self.created_chat_ids.append(chat_id)
|
||||
self.instances[chat_id] = chat_api
|
||||
self.instances.setdefault(chat_id, []).append(chat_api)
|
||||
return chat_api
|
||||
|
||||
|
||||
|
|
@ -73,7 +75,9 @@ def dispatcher():
|
|||
def real_dispatcher():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
platform = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
agent_id="matrix-bot",
|
||||
agent_base_url="http://platform-agent:8000",
|
||||
agent_api_cls=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
|
|
@ -147,7 +151,7 @@ async def test_toggle_skill_callback(dispatcher):
|
|||
assert any("browser" in r.text for r in result if isinstance(r, OutgoingMessage))
|
||||
|
||||
|
||||
async def test_full_flow_with_real_platform_uses_shared_agent_api(real_dispatcher):
|
||||
async def test_full_flow_with_real_platform_uses_direct_agent_api(real_dispatcher):
|
||||
dispatcher, agent_api = real_dispatcher
|
||||
|
||||
start = IncomingCommand(user_id="u1", platform="matrix", chat_id="C1", command="start")
|
||||
|
|
@ -160,7 +164,7 @@ async def test_full_flow_with_real_platform_uses_shared_agent_api(real_dispatche
|
|||
|
||||
assert texts == ["[REAL] Привет!"]
|
||||
assert agent_api.created_chat_ids == ["C1"]
|
||||
assert agent_api.instances["C1"].calls == [("Привет!", [])]
|
||||
assert [instance.calls for instance in agent_api.instances["C1"]] == [[("Привет!", [])]]
|
||||
|
||||
|
||||
async def test_full_flow_with_real_platform_forwards_workspace_attachment(real_dispatcher):
|
||||
|
|
@ -185,6 +189,6 @@ async def test_full_flow_with_real_platform_forwards_workspace_attachment(real_d
|
|||
)
|
||||
await dispatcher.dispatch(msg)
|
||||
|
||||
assert agent_api.instances["C1"].calls == [
|
||||
("Посмотри файл", ["surfaces/matrix/u1/room/inbox/report.pdf"])
|
||||
assert [instance.calls for instance in agent_api.instances["C1"]] == [
|
||||
[("Посмотри файл", ["surfaces/matrix/u1/room/inbox/report.pdf"])]
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,16 +1,10 @@
|
|||
"""Compatibility tests after the Phase 4 migration."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
_api_root = Path(__file__).resolve().parents[2] / "external" / "platform-agent_api"
|
||||
if str(_api_root) not in sys.path:
|
||||
sys.path.insert(0, str(_api_root))
|
||||
|
||||
|
||||
def test_lambda_agent_api_module_is_importable():
|
||||
from lambda_agent_api.agent_api import AgentApi
|
||||
from sdk.upstream_agent_api import AgentApi
|
||||
|
||||
assert AgentApi is not None
|
||||
|
||||
|
|
@ -18,4 +12,4 @@ def test_lambda_agent_api_module_is_importable():
|
|||
def test_agent_session_module_is_intentionally_stubbed():
|
||||
contents = Path(__file__).resolve().parents[2] / "sdk" / "agent_session.py"
|
||||
|
||||
assert "replaced by AgentApiWrapper" in contents.read_text()
|
||||
assert "replaced by direct AgentApi usage" in contents.read_text()
|
||||
|
|
|
|||
|
|
@ -1,20 +1,20 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
from lambda_agent_api.server import MsgEventSendFile, MsgEventTextChunk
|
||||
from pydantic import Field
|
||||
|
||||
import sdk.agent_api_wrapper as agent_api_wrapper_module
|
||||
from core.protocol import SettingsAction
|
||||
from sdk.agent_api_wrapper import AgentApiWrapper
|
||||
from sdk.interface import Attachment, MessageChunk, MessageResponse, PlatformError, UserSettings
|
||||
from sdk.prototype_state import PrototypeStateStore
|
||||
from sdk.real import RealPlatformClient
|
||||
from sdk.upstream_agent_api import MsgEventSendFile, MsgEventTextChunk
|
||||
|
||||
|
||||
class FakeChatAgentApi:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
def __init__(self, agent_id: str, base_url: str, chat_id: str) -> None:
|
||||
self.agent_id = agent_id
|
||||
self.base_url = base_url
|
||||
self.chat_id = str(chat_id)
|
||||
self.calls: list[str] = []
|
||||
self.connect_calls = 0
|
||||
self.close_calls = 0
|
||||
|
|
@ -33,155 +33,125 @@ class FakeChatAgentApi:
|
|||
|
||||
|
||||
class FakeAgentApiFactory:
|
||||
def __init__(self) -> None:
|
||||
self.created_chat_ids: list[str] = []
|
||||
self.instances: dict[str, FakeChatAgentApi] = {}
|
||||
def __init__(self, chat_api_cls=FakeChatAgentApi) -> None:
|
||||
self.chat_api_cls = chat_api_cls
|
||||
self.created_calls: list[tuple[str, str, str]] = []
|
||||
self.instances_by_chat: dict[str, list[FakeChatAgentApi]] = {}
|
||||
|
||||
def for_chat(self, chat_id: str) -> FakeChatAgentApi:
|
||||
chat_api = FakeChatAgentApi(chat_id)
|
||||
self.created_chat_ids.append(chat_id)
|
||||
self.instances[chat_id] = chat_api
|
||||
def __call__(self, agent_id: str, base_url: str, chat_id: str):
|
||||
chat_key = str(chat_id)
|
||||
chat_api = self.chat_api_cls(agent_id, base_url, chat_key)
|
||||
self.created_calls.append((agent_id, base_url, chat_key))
|
||||
self.instances_by_chat.setdefault(chat_key, []).append(chat_api)
|
||||
return chat_api
|
||||
|
||||
def latest(self, chat_id: str):
|
||||
return self.instances_by_chat[str(chat_id)][-1]
|
||||
|
||||
class BlockingChatAgentApi:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
self.calls: list[str] = []
|
||||
self.connect_calls = 0
|
||||
self.close_calls = 0
|
||||
|
||||
class BlockingTracker:
|
||||
def __init__(self) -> None:
|
||||
self.active_calls = 0
|
||||
self.max_active_calls = 0
|
||||
self.started = asyncio.Event()
|
||||
self.release = asyncio.Event()
|
||||
|
||||
async def connect(self) -> None:
|
||||
self.connect_calls += 1
|
||||
|
||||
async def close(self) -> None:
|
||||
self.close_calls += 1
|
||||
class BlockingChatAgentApi(FakeChatAgentApi):
|
||||
def __init__(
|
||||
self,
|
||||
agent_id: str,
|
||||
base_url: str,
|
||||
chat_id: str,
|
||||
*,
|
||||
tracker: BlockingTracker,
|
||||
) -> None:
|
||||
super().__init__(agent_id, base_url, chat_id)
|
||||
self._tracker = tracker
|
||||
|
||||
async def send_message(self, text: str, attachments: list[str] | None = None):
|
||||
self.calls.append(text)
|
||||
self.active_calls += 1
|
||||
self.max_active_calls = max(self.max_active_calls, self.active_calls)
|
||||
self.started.set()
|
||||
await self.release.wait()
|
||||
self.active_calls -= 1
|
||||
self._tracker.active_calls += 1
|
||||
self._tracker.max_active_calls = max(
|
||||
self._tracker.max_active_calls,
|
||||
self._tracker.active_calls,
|
||||
)
|
||||
self._tracker.started.set()
|
||||
await self._tracker.release.wait()
|
||||
self._tracker.active_calls -= 1
|
||||
yield MsgEventTextChunk(text=text)
|
||||
|
||||
|
||||
class AttachmentTrackingChatAgentApi:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
class BlockingAgentApiFactory(FakeAgentApiFactory):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.tracker = BlockingTracker()
|
||||
|
||||
def __call__(self, agent_id: str, base_url: str, chat_id: str):
|
||||
chat_key = str(chat_id)
|
||||
chat_api = BlockingChatAgentApi(
|
||||
agent_id,
|
||||
base_url,
|
||||
chat_key,
|
||||
tracker=self.tracker,
|
||||
)
|
||||
self.created_calls.append((agent_id, base_url, chat_key))
|
||||
self.instances_by_chat.setdefault(chat_key, []).append(chat_api)
|
||||
return chat_api
|
||||
|
||||
|
||||
class AttachmentTrackingChatAgentApi(FakeChatAgentApi):
|
||||
def __init__(self, agent_id: str, base_url: str, chat_id: str) -> None:
|
||||
super().__init__(agent_id, base_url, chat_id)
|
||||
self.calls: list[tuple[str, list[str] | None]] = []
|
||||
self.connect_calls = 0
|
||||
self.close_calls = 0
|
||||
|
||||
async def connect(self) -> None:
|
||||
self.connect_calls += 1
|
||||
|
||||
async def close(self) -> None:
|
||||
self.close_calls += 1
|
||||
|
||||
async def send_message(self, text: str, attachments: list[str] | None = None):
|
||||
self.calls.append((text, attachments))
|
||||
yield MsgEventTextChunk(text=text)
|
||||
|
||||
|
||||
class AttachmentTrackingAgentApiFactory:
|
||||
def __init__(self, chat_api_cls=AttachmentTrackingChatAgentApi) -> None:
|
||||
self.chat_api_cls = chat_api_cls
|
||||
self.created_chat_ids: list[str] = []
|
||||
self.instances: dict[str, AttachmentTrackingChatAgentApi] = {}
|
||||
|
||||
def for_chat(self, chat_id: str) -> AttachmentTrackingChatAgentApi:
|
||||
chat_api = self.chat_api_cls(chat_id)
|
||||
self.created_chat_ids.append(chat_id)
|
||||
self.instances[chat_id] = chat_api
|
||||
return chat_api
|
||||
|
||||
|
||||
class FlakyChatAgentApi:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
self.connect_calls = 0
|
||||
self.close_calls = 0
|
||||
|
||||
async def connect(self) -> None:
|
||||
self.connect_calls += 1
|
||||
|
||||
async def close(self) -> None:
|
||||
self.close_calls += 1
|
||||
|
||||
class FlakyChatAgentApi(FakeChatAgentApi):
|
||||
async def send_message(self, text: str, attachments: list[str] | None = None):
|
||||
raise ConnectionError("Connection closed")
|
||||
yield
|
||||
|
||||
|
||||
class ReuseSensitiveChatAgentApi(FakeChatAgentApi):
|
||||
def __init__(self, agent_id: str, base_url: str, chat_id: str) -> None:
|
||||
super().__init__(agent_id, base_url, chat_id)
|
||||
self._send_calls = 0
|
||||
|
||||
async def send_message(self, text: str, attachments: list[str] | None = None):
|
||||
self.calls.append(text)
|
||||
self._send_calls += 1
|
||||
if text == "first":
|
||||
yield MsgEventTextChunk(text="tool ok")
|
||||
return
|
||||
if text == "second" and self._send_calls == 1:
|
||||
yield MsgEventTextChunk(text="Missing")
|
||||
|
||||
|
||||
class MessageResponseWithAttachments(MessageResponse):
|
||||
attachments: list[Attachment] = Field(default_factory=list)
|
||||
|
||||
|
||||
def test_agent_api_wrapper_normalizes_base_url_and_uses_modern_constructor(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
|
||||
captured["agent_id"] = agent_id
|
||||
captured["base_url"] = base_url
|
||||
captured["chat_id"] = chat_id
|
||||
|
||||
monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)
|
||||
|
||||
wrapper = AgentApiWrapper(
|
||||
agent_id="agent-1",
|
||||
base_url="ws://platform-agent:8000/v1/agent_ws/",
|
||||
chat_id="41",
|
||||
def make_real_platform_client(
|
||||
agent_api_cls,
|
||||
*,
|
||||
prototype_state: PrototypeStateStore | None = None,
|
||||
) -> RealPlatformClient:
|
||||
return RealPlatformClient(
|
||||
agent_id="matrix-bot",
|
||||
agent_base_url="http://platform-agent:8000",
|
||||
agent_api_cls=agent_api_cls,
|
||||
prototype_state=prototype_state or PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
|
||||
assert wrapper.chat_id == "41"
|
||||
assert wrapper._base_url == "ws://platform-agent:8000"
|
||||
assert captured == {
|
||||
"agent_id": "agent-1",
|
||||
"base_url": "ws://platform-agent:8000",
|
||||
"chat_id": "41",
|
||||
}
|
||||
|
||||
|
||||
def test_agent_api_wrapper_for_chat_reuses_normalized_base_url(monkeypatch):
|
||||
init_calls = []
|
||||
|
||||
def fake_init(self, agent_id, base_url=None, chat_id=0, **kwargs):
|
||||
self.id = agent_id
|
||||
self.chat_id = chat_id
|
||||
self.url = base_url
|
||||
init_calls.append((agent_id, base_url, chat_id))
|
||||
|
||||
monkeypatch.setattr(agent_api_wrapper_module.AgentApi, "__init__", fake_init)
|
||||
|
||||
root = AgentApiWrapper(
|
||||
agent_id="agent-1",
|
||||
base_url="http://platform-agent:8000/v1/agent_ws/",
|
||||
chat_id="1",
|
||||
)
|
||||
|
||||
child = root.for_chat("99")
|
||||
|
||||
assert child is not root
|
||||
assert child.chat_id == "99"
|
||||
assert child._base_url == "http://platform-agent:8000"
|
||||
assert init_calls == [
|
||||
("agent-1", "http://platform-agent:8000", "1"),
|
||||
("agent-1", "http://platform-agent:8000", "99"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_get_or_create_user_uses_local_state():
|
||||
client = RealPlatformClient(
|
||||
agent_api=FakeAgentApiFactory(),
|
||||
prototype_state=PrototypeStateStore(),
|
||||
)
|
||||
client = make_real_platform_client(FakeAgentApiFactory())
|
||||
|
||||
first = await client.get_or_create_user("u1", "matrix", "Alice")
|
||||
second = await client.get_or_create_user("u1", "matrix")
|
||||
|
|
@ -194,14 +164,10 @@ async def test_real_platform_client_get_or_create_user_uses_local_state():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_send_message_uses_chat_bound_client():
|
||||
async def test_real_platform_client_send_message_uses_direct_agent_api_per_chat():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
prototype_state = PrototypeStateStore()
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=prototype_state,
|
||||
platform="matrix",
|
||||
)
|
||||
client = make_real_platform_client(agent_api, prototype_state=prototype_state)
|
||||
|
||||
result = await client.send_message("@alice:example.org", "chat-7", "hello")
|
||||
|
||||
|
|
@ -211,21 +177,18 @@ async def test_real_platform_client_send_message_uses_chat_bound_client():
|
|||
tokens_used=0,
|
||||
finished=True,
|
||||
)
|
||||
assert agent_api.created_chat_ids == ["chat-7"]
|
||||
assert agent_api.instances["chat-7"].chat_id == "chat-7"
|
||||
assert agent_api.instances["chat-7"].calls == ["hello"]
|
||||
assert agent_api.instances["chat-7"].connect_calls == 1
|
||||
assert agent_api.created_calls == [("matrix-bot", "http://platform-agent:8000", "chat-7")]
|
||||
assert agent_api.latest("chat-7").chat_id == "chat-7"
|
||||
assert agent_api.latest("chat-7").calls == ["hello"]
|
||||
assert agent_api.latest("chat-7").connect_calls == 1
|
||||
assert agent_api.latest("chat-7").close_calls == 1
|
||||
assert await prototype_state.get_last_tokens_used_for_context("chat-7") == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_forwards_attachments_to_chat_api():
|
||||
agent_api = AttachmentTrackingAgentApiFactory()
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
agent_api = FakeAgentApiFactory(chat_api_cls=AttachmentTrackingChatAgentApi)
|
||||
client = make_real_platform_client(agent_api)
|
||||
attachment = Attachment(
|
||||
workspace_path="surfaces/matrix/alice/room/inbox/report.pdf",
|
||||
mime_type="application/pdf",
|
||||
|
|
@ -240,7 +203,7 @@ async def test_real_platform_client_forwards_attachments_to_chat_api():
|
|||
attachments=[attachment],
|
||||
)
|
||||
|
||||
assert agent_api.instances["chat-7"].calls == [
|
||||
assert agent_api.latest("chat-7").calls == [
|
||||
("hello", ["surfaces/matrix/alice/room/inbox/report.pdf"])
|
||||
]
|
||||
assert result.response == "hello"
|
||||
|
|
@ -256,17 +219,10 @@ async def test_real_platform_client_preserves_send_file_events_in_sync_result(mo
|
|||
yield MsgEventSendFile(path="report.pdf")
|
||||
yield MsgEventTextChunk(text="llo")
|
||||
|
||||
agent_api = AttachmentTrackingAgentApiFactory(chat_api_cls=FileEventAgentApi)
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
agent_api = FakeAgentApiFactory(chat_api_cls=FileEventAgentApi)
|
||||
client = make_real_platform_client(agent_api)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"sdk.real.MessageResponse",
|
||||
MessageResponseWithAttachments,
|
||||
)
|
||||
monkeypatch.setattr("sdk.real.MessageResponse", MessageResponseWithAttachments)
|
||||
|
||||
result = await client.send_message("@alice:example.org", "chat-7", "hello")
|
||||
|
||||
|
|
@ -284,63 +240,61 @@ async def test_real_platform_client_preserves_send_file_events_in_sync_result(mo
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_reuses_cached_chat_client():
|
||||
async def test_real_platform_client_uses_fresh_agent_connection_per_request():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
client = make_real_platform_client(agent_api)
|
||||
|
||||
await client.send_message("@alice:example.org", "chat-1", "hello")
|
||||
await client.send_message("@alice:example.org", "chat-1", "again")
|
||||
|
||||
assert agent_api.created_chat_ids == ["chat-1"]
|
||||
assert agent_api.instances["chat-1"].calls == ["hello", "again"]
|
||||
assert agent_api.instances["chat-1"].connect_calls == 1
|
||||
assert agent_api.instances["chat-1"].close_calls == 0
|
||||
assert agent_api.created_calls == [
|
||||
("matrix-bot", "http://platform-agent:8000", "chat-1"),
|
||||
("matrix-bot", "http://platform-agent:8000", "chat-1"),
|
||||
]
|
||||
assert [instance.calls for instance in agent_api.instances_by_chat["chat-1"]] == [
|
||||
["hello"],
|
||||
["again"],
|
||||
]
|
||||
assert all(instance.connect_calls == 1 for instance in agent_api.instances_by_chat["chat-1"])
|
||||
assert all(instance.close_calls == 1 for instance in agent_api.instances_by_chat["chat-1"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_avoids_reuse_sensitive_second_message_loss():
|
||||
agent_api = FakeAgentApiFactory(chat_api_cls=ReuseSensitiveChatAgentApi)
|
||||
client = make_real_platform_client(agent_api)
|
||||
|
||||
first = await client.send_message("@alice:example.org", "chat-1", "first")
|
||||
second = await client.send_message("@alice:example.org", "chat-1", "second")
|
||||
|
||||
assert first.response == "tool ok"
|
||||
assert second.response == "Missing"
|
||||
assert len(agent_api.instances_by_chat["chat-1"]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_wraps_connection_closed_as_platform_error():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
agent_api.instances["chat-1"] = FlakyChatAgentApi("chat-1")
|
||||
agent_api.for_chat = lambda chat_id: agent_api.instances.setdefault(
|
||||
chat_id, FlakyChatAgentApi(chat_id)
|
||||
)
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
agent_api = FakeAgentApiFactory(chat_api_cls=FlakyChatAgentApi)
|
||||
client = make_real_platform_client(agent_api)
|
||||
|
||||
with pytest.raises(PlatformError, match="Connection closed") as exc_info:
|
||||
await client.send_message("@alice:example.org", "chat-1", "hello")
|
||||
|
||||
assert exc_info.value.code == "PLATFORM_CONNECTION_ERROR"
|
||||
assert "chat-1" not in client._chat_apis
|
||||
assert agent_api.instances["chat-1"].close_calls == 1
|
||||
assert agent_api.latest("chat-1").close_calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_reconnects_after_closed_chat_api():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
flaky = FlakyChatAgentApi("chat-1")
|
||||
healthy = AttachmentTrackingChatAgentApi("chat-1")
|
||||
provided = iter([flaky, healthy])
|
||||
async def test_real_platform_client_uses_fresh_connection_after_failure():
|
||||
class SometimesFlakyAgentApi(FakeChatAgentApi):
|
||||
async def send_message(self, text: str, attachments: list[str] | None = None):
|
||||
if text == "hello":
|
||||
raise ConnectionError("Connection closed")
|
||||
self.calls.append(text)
|
||||
yield MsgEventTextChunk(text=text)
|
||||
|
||||
def for_chat(chat_id: str):
|
||||
chat_api = next(provided)
|
||||
agent_api.created_chat_ids.append(chat_id)
|
||||
agent_api.instances[chat_id] = chat_api
|
||||
return chat_api
|
||||
|
||||
agent_api.for_chat = for_chat
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
agent_api = FakeAgentApiFactory(chat_api_cls=SometimesFlakyAgentApi)
|
||||
client = make_real_platform_client(agent_api)
|
||||
|
||||
with pytest.raises(PlatformError, match="Connection closed"):
|
||||
await client.send_message("@alice:example.org", "chat-1", "hello")
|
||||
|
|
@ -348,60 +302,17 @@ async def test_real_platform_client_reconnects_after_closed_chat_api():
|
|||
result = await client.send_message("@alice:example.org", "chat-1", "again")
|
||||
|
||||
assert result.response == "again"
|
||||
assert agent_api.created_chat_ids == ["chat-1", "chat-1"]
|
||||
assert healthy.calls == [("again", None)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_creates_chat_client_atomically_for_concurrent_requests():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
|
||||
results = await asyncio.gather(
|
||||
client.send_message("@alice:example.org", "chat-1", "hello"),
|
||||
client.send_message("@alice:example.org", "chat-1", "again"),
|
||||
)
|
||||
|
||||
assert [result.response for result in results] == ["hello", "again"]
|
||||
assert agent_api.created_chat_ids == ["chat-1"]
|
||||
assert agent_api.instances["chat-1"].connect_calls == 1
|
||||
assert agent_api.instances["chat-1"].calls == ["hello", "again"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_creates_distinct_clients_per_chat():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
|
||||
await client.send_message("@alice:example.org", "chat-1", "hello")
|
||||
await client.send_message("@alice:example.org", "chat-2", "world")
|
||||
|
||||
assert agent_api.created_chat_ids == ["chat-1", "chat-2"]
|
||||
assert agent_api.instances["chat-1"] is not agent_api.instances["chat-2"]
|
||||
assert agent_api.instances["chat-1"].calls == ["hello"]
|
||||
assert agent_api.instances["chat-2"].calls == ["world"]
|
||||
assert agent_api.created_calls == [
|
||||
("matrix-bot", "http://platform-agent:8000", "chat-1"),
|
||||
("matrix-bot", "http://platform-agent:8000", "chat-1"),
|
||||
]
|
||||
assert agent_api.latest("chat-1").calls == ["again"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_serializes_same_chat_streams_across_send_paths():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
agent_api.instances["chat-1"] = BlockingChatAgentApi("chat-1")
|
||||
agent_api.for_chat = lambda chat_id: agent_api.instances.setdefault(
|
||||
chat_id, BlockingChatAgentApi(chat_id)
|
||||
)
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
agent_api = BlockingAgentApiFactory()
|
||||
client = make_real_platform_client(agent_api)
|
||||
|
||||
async def consume_stream():
|
||||
chunks = []
|
||||
|
|
@ -410,32 +321,48 @@ async def test_real_platform_client_serializes_same_chat_streams_across_send_pat
|
|||
return chunks
|
||||
|
||||
stream_task = asyncio.create_task(consume_stream())
|
||||
await asyncio.wait_for(agent_api.instances["chat-1"].started.wait(), timeout=1)
|
||||
await asyncio.wait_for(agent_api.tracker.started.wait(), timeout=1)
|
||||
|
||||
send_task = asyncio.create_task(client.send_message("@alice:example.org", "chat-1", "again"))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert agent_api.instances["chat-1"].calls == ["hello"]
|
||||
assert agent_api.instances["chat-1"].max_active_calls == 1
|
||||
assert len(agent_api.instances_by_chat["chat-1"]) == 1
|
||||
assert agent_api.instances_by_chat["chat-1"][0].calls == ["hello"]
|
||||
assert agent_api.tracker.max_active_calls == 1
|
||||
|
||||
agent_api.instances["chat-1"].release.set()
|
||||
agent_api.tracker.release.set()
|
||||
stream_chunks = await stream_task
|
||||
send_result = await send_task
|
||||
|
||||
assert [chunk.delta for chunk in stream_chunks] == ["hello", ""]
|
||||
assert send_result.response == "again"
|
||||
assert agent_api.instances["chat-1"].calls == ["hello", "again"]
|
||||
assert agent_api.instances["chat-1"].max_active_calls == 1
|
||||
assert [instance.calls for instance in agent_api.instances_by_chat["chat-1"]] == [
|
||||
["hello"],
|
||||
["again"],
|
||||
]
|
||||
assert agent_api.tracker.max_active_calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_creates_distinct_connections_per_chat():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
client = make_real_platform_client(agent_api)
|
||||
|
||||
await client.send_message("@alice:example.org", "chat-1", "hello")
|
||||
await client.send_message("@alice:example.org", "chat-2", "world")
|
||||
|
||||
assert agent_api.created_calls == [
|
||||
("matrix-bot", "http://platform-agent:8000", "chat-1"),
|
||||
("matrix-bot", "http://platform-agent:8000", "chat-2"),
|
||||
]
|
||||
assert agent_api.latest("chat-1").calls == ["hello"]
|
||||
assert agent_api.latest("chat-2").calls == ["world"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_stream_message_emits_final_tokens_chunk():
|
||||
agent_api = FakeAgentApiFactory()
|
||||
client = RealPlatformClient(
|
||||
agent_api=agent_api,
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
client = make_real_platform_client(agent_api)
|
||||
|
||||
chunks = []
|
||||
async for chunk in client.stream_message("@alice:example.org", "chat-1", "hello"):
|
||||
|
|
@ -461,17 +388,14 @@ async def test_real_platform_client_stream_message_emits_final_tokens_chunk():
|
|||
tokens_used=0,
|
||||
),
|
||||
]
|
||||
assert agent_api.created_chat_ids == ["chat-1"]
|
||||
assert agent_api.instances["chat-1"].calls == ["hello"]
|
||||
assert agent_api.created_calls == [("matrix-bot", "http://platform-agent:8000", "chat-1")]
|
||||
assert agent_api.latest("chat-1").calls == ["hello"]
|
||||
assert agent_api.latest("chat-1").close_calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_platform_client_settings_are_local():
|
||||
client = RealPlatformClient(
|
||||
agent_api=FakeAgentApiFactory(),
|
||||
prototype_state=PrototypeStateStore(),
|
||||
platform="matrix",
|
||||
)
|
||||
client = make_real_platform_client(FakeAgentApiFactory())
|
||||
|
||||
await client.update_settings(
|
||||
"usr-matrix-u1",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue