fix prototype backend review issues

This commit is contained in:
Mikhail Putilovskij 2026-04-08 01:43:44 +03:00
parent 94bdb44b93
commit 37643a9695
9 changed files with 182 additions and 46 deletions

View file

@ -15,12 +15,14 @@ dependencies = [
"structlog>=24.1", "structlog>=24.1",
"python-dotenv>=1.0", "python-dotenv>=1.0",
"httpx>=0.27", "httpx>=0.27",
"aiohttp>=3.9",
] ]
[project.optional-dependencies] [project.optional-dependencies]
dev = [ dev = [
"pytest>=8.0", "pytest>=8.0",
"pytest-asyncio>=0.23", "pytest-asyncio>=0.23",
"pytest-aiohttp>=1.0",
"pytest-cov>=4.1", "pytest-cov>=4.1",
"ruff>=0.3", "ruff>=0.3",
"mypy>=1.8", "mypy>=1.8",

View file

@ -1,3 +1,9 @@
from sdk.real import RealPlatformClient
__all__ = ["RealPlatformClient"] __all__ = ["RealPlatformClient"]
def __getattr__(name: str):
if name == "RealPlatformClient":
from sdk.real import RealPlatformClient
return RealPlatformClient
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View file

@ -4,8 +4,6 @@ from dataclasses import dataclass
from typing import AsyncIterator from typing import AsyncIterator
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
import aiohttp
from sdk.interface import MessageChunk, MessageResponse, PlatformError from sdk.interface import MessageChunk, MessageResponse, PlatformError
@ -41,6 +39,8 @@ class AgentSessionClient:
) )
async def stream_message(self, *, thread_key: str, text: str) -> AsyncIterator[MessageChunk]: async def stream_message(self, *, thread_key: str, text: str) -> AsyncIterator[MessageChunk]:
import aiohttp
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.ws_connect( async with session.ws_connect(
self._ws_url(thread_key), self._ws_url(thread_key),

View file

@ -34,7 +34,6 @@ class PrototypeStateStore:
async def get_or_create_user( async def get_or_create_user(
self, self,
*,
external_id: str, external_id: str,
platform: str, platform: str,
display_name: str | None = None, display_name: str | None = None,
@ -54,14 +53,14 @@ class PrototypeStateStore:
created_at=datetime.now(UTC), created_at=datetime.now(UTC),
is_new=True, is_new=True,
) )
self._users[key] = user self._users[key] = user.model_copy(update={"is_new": False})
return user.model_copy() return user.model_copy()
async def get_settings(self, user_id: str) -> UserSettings: async def get_settings(self, user_id: str) -> UserSettings:
stored = self._settings.get(user_id, {}) stored = self._settings.get(user_id, {})
return UserSettings( return UserSettings(
skills={**DEFAULT_SKILLS, **stored.get("skills", {})}, skills={**DEFAULT_SKILLS, **stored.get("skills", {})},
connectors=stored.get("connectors", {}), connectors=dict(stored.get("connectors", {})),
soul={**DEFAULT_SOUL, **stored.get("soul", {})}, soul={**DEFAULT_SOUL, **stored.get("soul", {})},
safety={**DEFAULT_SAFETY, **stored.get("safety", {})}, safety={**DEFAULT_SAFETY, **stored.get("safety", {})},
plan={**DEFAULT_PLAN, **stored.get("plan", {})}, plan={**DEFAULT_PLAN, **stored.get("plan", {})},

View file

@ -1,18 +1,21 @@
from __future__ import annotations from __future__ import annotations
from typing import AsyncIterator from typing import TYPE_CHECKING, AsyncIterator
from sdk.agent_session import AgentSessionClient, build_thread_key from sdk.agent_session import build_thread_key
from sdk.interface import Attachment, MessageChunk, MessageResponse, PlatformClient, User, UserSettings from sdk.interface import Attachment, MessageChunk, MessageResponse, PlatformClient, User, UserSettings
from sdk.prototype_state import PrototypeStateStore from sdk.prototype_state import PrototypeStateStore
if TYPE_CHECKING:
from sdk.agent_session import AgentSessionClient
class RealPlatformClient(PlatformClient): class RealPlatformClient(PlatformClient):
def __init__( def __init__(
self, self,
agent_sessions: AgentSessionClient, agent_sessions: AgentSessionClient,
prototype_state: PrototypeStateStore, prototype_state: PrototypeStateStore,
platform: str, platform: str = "matrix",
) -> None: ) -> None:
self._agent_sessions = agent_sessions self._agent_sessions = agent_sessions
self._prototype_state = prototype_state self._prototype_state = prototype_state

View file

@ -1,9 +1,58 @@
import sys
from pathlib import Path
from types import ModuleType
import pytest import pytest
from aiohttp import web from aiohttp import web
from sdk.interface import MessageChunk, MessageResponse from sdk.interface import MessageChunk, MessageResponse
from sdk.agent_session import AgentSessionClient, AgentSessionConfig, build_thread_key from sdk.agent_session import AgentSessionClient, AgentSessionConfig, build_thread_key
AGENT_ROOT = Path(__file__).resolve().parents[2] / "external" / "platform-agent"
AGENT_API_ROOT = Path(__file__).resolve().parents[2] / "external" / "platform-agent_api"
for path in (AGENT_ROOT, AGENT_API_ROOT):
if str(path) not in sys.path:
sys.path.insert(0, str(path))
if "fastapi" not in sys.modules:
fastapi = ModuleType("fastapi")
class _Router:
def websocket(self, _path: str):
def decorator(fn):
return fn
return decorator
class _WebSocketDisconnect(Exception):
pass
def _depends(value):
return value
fastapi.APIRouter = _Router
fastapi.WebSocket = object
fastapi.WebSocketDisconnect = _WebSocketDisconnect
fastapi.Depends = _depends
sys.modules["fastapi"] = fastapi
if "src.agent" not in sys.modules:
agent_module = ModuleType("src.agent")
class _AgentService:
async def astream(self, text: str, thread_id: str):
yield text
def _get_agent_service():
return _AgentService()
agent_module.AgentService = _AgentService
agent_module.get_agent_service = _get_agent_service
sys.modules["src.agent"] = agent_module
from lambda_agent_api.client import MsgUserMessage # noqa: E402
from src.api.external import process_message # noqa: E402
def test_build_thread_key_uses_platform_user_and_chat_id(): def test_build_thread_key_uses_platform_user_and_chat_id():
assert build_thread_key("matrix", "@alice:example.org", "C1") == "6:matrix18:@alice:example.org2:C1" assert build_thread_key("matrix", "@alice:example.org", "C1") == "6:matrix18:@alice:example.org2:C1"
@ -18,11 +67,13 @@ def test_build_thread_key_does_not_collide_when_user_id_contains_colons():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stream_message_yields_text_chunks_and_end(aiohttp_server): async def test_stream_message_yields_text_chunks_and_end(aiohttp_server):
thread_key = build_thread_key("matrix", "@alice:example.org", "C1")
async def handler(request): async def handler(request):
ws = web.WebSocketResponse() ws = web.WebSocketResponse()
await ws.prepare(request) await ws.prepare(request)
assert request.query["thread_id"] == "matrix:@alice:example.org:C1" assert request.query["thread_id"] == thread_key
await ws.send_json({"type": "STATUS"}) await ws.send_json({"type": "STATUS"})
@ -43,25 +94,27 @@ async def test_stream_message_yields_text_chunks_and_end(aiohttp_server):
chunks = [] chunks = []
async for chunk in client.stream_message( async for chunk in client.stream_message(
thread_key="matrix:@alice:example.org:C1", thread_key=thread_key,
text="hello", text="hello",
): ):
chunks.append(chunk) chunks.append(chunk)
assert chunks == [ assert chunks == [
MessageChunk(message_id="matrix:@alice:example.org:C1", delta="hel", finished=False, tokens_used=0), MessageChunk(message_id=thread_key, delta="hel", finished=False, tokens_used=0),
MessageChunk(message_id="matrix:@alice:example.org:C1", delta="lo", finished=False, tokens_used=0), MessageChunk(message_id=thread_key, delta="lo", finished=False, tokens_used=0),
MessageChunk(message_id="matrix:@alice:example.org:C1", delta="", finished=True, tokens_used=7), MessageChunk(message_id=thread_key, delta="", finished=True, tokens_used=7),
] ]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_message_collects_streamed_chunks_and_tokens(aiohttp_server): async def test_send_message_collects_streamed_chunks_and_tokens(aiohttp_server):
thread_key = build_thread_key("matrix", "@alice:example.org", "C1")
async def handler(request): async def handler(request):
ws = web.WebSocketResponse() ws = web.WebSocketResponse()
await ws.prepare(request) await ws.prepare(request)
assert request.query["thread_id"] == "matrix:@alice:example.org:C1" assert request.query["thread_id"] == thread_key
await ws.send_json({"type": "STATUS"}) await ws.send_json({"type": "STATUS"})
@ -81,13 +134,60 @@ async def test_send_message_collects_streamed_chunks_and_tokens(aiohttp_server):
client = AgentSessionClient(AgentSessionConfig(base_ws_url=str(server.make_url("/agent_ws/")))) client = AgentSessionClient(AgentSessionConfig(base_ws_url=str(server.make_url("/agent_ws/"))))
result = await client.send_message( result = await client.send_message(
thread_key="matrix:@alice:example.org:C1", thread_key=thread_key,
text="hello world", text="hello world",
) )
assert result == MessageResponse( assert result == MessageResponse(
message_id="matrix:@alice:example.org:C1", message_id=thread_key,
response="hello world", response="hello world",
tokens_used=11, tokens_used=11,
finished=True, finished=True,
) )
@pytest.mark.asyncio
async def test_process_message_requires_thread_id_query_param():
class FakeWebSocket:
query_params = {}
async def send_text(self, text: str) -> None:
raise AssertionError(f"send_text should not be called: {text}")
class FakeAgentService:
async def astream(self, text: str, thread_id: str):
yield text
with pytest.raises(ValueError, match="thread_id query parameter is required"):
await process_message(
FakeWebSocket(),
MsgUserMessage(text="hello"),
FakeAgentService(),
)
@pytest.mark.asyncio
async def test_process_message_passes_thread_id_to_agent_service():
class FakeWebSocket:
def __init__(self) -> None:
self.query_params = {"thread_id": "6:matrix18:@alice:example.org2:C1"}
self.sent_messages: list[str] = []
async def send_text(self, text: str) -> None:
self.sent_messages.append(text)
class FakeAgentService:
def __init__(self) -> None:
self.calls: list[tuple[str, str]] = []
async def astream(self, text: str, thread_id: str):
self.calls.append((text, thread_id))
yield "hello"
ws = FakeWebSocket()
agent_service = FakeAgentService()
await process_message(ws, MsgUserMessage(text="hello"), agent_service)
assert agent_service.calls == [("hello", "6:matrix18:@alice:example.org2:C1")]
assert any("AGENT_EVENT_TEXT_CHUNK" in message for message in ws.sent_messages)
assert any("AGENT_EVENT_END" in message for message in ws.sent_messages)

View file

@ -9,18 +9,12 @@ from sdk.prototype_state import PrototypeStateStore
async def test_get_or_create_user_is_stable_per_surface_identity(): async def test_get_or_create_user_is_stable_per_surface_identity():
store = PrototypeStateStore() store = PrototypeStateStore()
first = await store.get_or_create_user( first = await store.get_or_create_user("@alice:example.org", "matrix", "Alice")
external_id="@alice:example.org", second = await store.get_or_create_user("@alice:example.org", "matrix")
platform="matrix",
display_name="Alice",
)
second = await store.get_or_create_user(
external_id="@alice:example.org",
platform="matrix",
)
assert first.user_id == "usr-matrix-@alice:example.org" assert first.user_id == "usr-matrix-@alice:example.org"
assert first.is_new is True assert first.is_new is True
assert store._users["matrix:@alice:example.org"].is_new is False
first.display_name = "Mallory" first.display_name = "Mallory"
first.is_new = False first.is_new = False
@ -56,6 +50,22 @@ async def test_settings_defaults_match_existing_mock_shape():
assert settings.plan == {"name": "Beta", "tokens_used": 0, "tokens_limit": 1000} assert settings.plan == {"name": "Beta", "tokens_used": 0, "tokens_limit": 1000}
@pytest.mark.asyncio
async def test_get_settings_returns_connectors_copy():
store = PrototypeStateStore()
store._settings["usr-matrix-@alice:example.org"] = {
"connectors": {"github": {"enabled": True}},
}
settings = await store.get_settings("usr-matrix-@alice:example.org")
settings.connectors["github"]["enabled"] = False
settings.connectors["slack"] = {"enabled": True}
assert store._settings["usr-matrix-@alice:example.org"]["connectors"] == {
"github": {"enabled": True},
}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_settings_supports_toggle_skill_and_setters(): async def test_update_settings_supports_toggle_skill_and_setters():
store = PrototypeStateStore() store = PrototypeStateStore()

View file

@ -1,6 +1,7 @@
import pytest import pytest
from core.protocol import SettingsAction from core.protocol import SettingsAction
from sdk.agent_session import build_thread_key
from sdk.interface import MessageChunk, MessageResponse, UserSettings from sdk.interface import MessageChunk, MessageResponse, UserSettings
from sdk.prototype_state import PrototypeStateStore from sdk.prototype_state import PrototypeStateStore
from sdk.real import RealPlatformClient from sdk.real import RealPlatformClient
@ -31,13 +32,12 @@ async def test_real_platform_client_get_or_create_user_uses_local_state():
client = RealPlatformClient( client = RealPlatformClient(
agent_sessions=FakeAgentSessionClient(), agent_sessions=FakeAgentSessionClient(),
prototype_state=PrototypeStateStore(), prototype_state=PrototypeStateStore(),
platform="telegram",
) )
first = await client.get_or_create_user("u1", "telegram", "Alice") first = await client.get_or_create_user("u1", "matrix", "Alice")
second = await client.get_or_create_user("u1", "telegram") second = await client.get_or_create_user("u1", "matrix")
assert first.user_id == "usr-telegram-u1" assert first.user_id == "usr-matrix-u1"
assert first.is_new is True assert first.is_new is True
assert second.user_id == first.user_id assert second.user_id == first.user_id
assert second.is_new is False assert second.is_new is False
@ -45,57 +45,55 @@ async def test_real_platform_client_get_or_create_user_uses_local_state():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_real_platform_client_send_message_uses_configured_platform(): async def test_real_platform_client_send_message_uses_surface_user_thread_identity():
agent_sessions = FakeAgentSessionClient() agent_sessions = FakeAgentSessionClient()
client = RealPlatformClient( client = RealPlatformClient(
agent_sessions=agent_sessions, agent_sessions=agent_sessions,
prototype_state=PrototypeStateStore(), prototype_state=PrototypeStateStore(),
platform="telegram", platform="matrix",
) )
result = await client.send_message("usr-telegram-u1", "C1", "hello") thread_key = build_thread_key("matrix", "@alice:example.org", "C1")
result = await client.send_message("@alice:example.org", "C1", "hello")
assert result == MessageResponse( assert result == MessageResponse(
message_id="8:telegram15:usr-telegram-u12:C1", message_id=thread_key,
response="echo:hello", response="echo:hello",
tokens_used=3, tokens_used=3,
finished=True, finished=True,
) )
assert agent_sessions.send_calls == [ assert agent_sessions.send_calls == [(thread_key, "hello")]
("8:telegram15:usr-telegram-u12:C1", "hello")
]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_real_platform_client_stream_message_uses_configured_platform(): async def test_real_platform_client_stream_message_uses_surface_user_thread_identity():
agent_sessions = FakeAgentSessionClient() agent_sessions = FakeAgentSessionClient()
client = RealPlatformClient( client = RealPlatformClient(
agent_sessions=agent_sessions, agent_sessions=agent_sessions,
prototype_state=PrototypeStateStore(), prototype_state=PrototypeStateStore(),
platform="telegram", platform="matrix",
) )
thread_key = build_thread_key("matrix", "@alice:example.org", "C1")
chunks = [] chunks = []
async for chunk in client.stream_message("usr-telegram-u1", "C1", "hello"): async for chunk in client.stream_message("@alice:example.org", "C1", "hello"):
chunks.append(chunk) chunks.append(chunk)
assert chunks == [ assert chunks == [
MessageChunk( MessageChunk(
message_id="8:telegram15:usr-telegram-u12:C1", message_id=thread_key,
delta="he", delta="he",
finished=False, finished=False,
tokens_used=0, tokens_used=0,
), ),
MessageChunk( MessageChunk(
message_id="8:telegram15:usr-telegram-u12:C1", message_id=thread_key,
delta="llo", delta="llo",
finished=True, finished=True,
tokens_used=3, tokens_used=3,
), ),
] ]
assert agent_sessions.stream_calls == [ assert agent_sessions.stream_calls == [(thread_key, "hello")]
("8:telegram15:usr-telegram-u12:C1", "hello")
]
@pytest.mark.asyncio @pytest.mark.asyncio

18
uv.lock generated
View file

@ -1095,6 +1095,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
] ]
[[package]]
name = "pytest-aiohttp"
version = "1.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiohttp" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
]
sdist = { url = "https://files.pythonhosted.org/packages/72/4b/d326890c153f2c4ce1bf45d07683c08c10a1766058a22934620bc6ac6592/pytest_aiohttp-1.1.0.tar.gz", hash = "sha256:147de8cb164f3fc9d7196967f109ab3c0b93ea3463ab50631e56438eab7b5adc", size = 12842, upload-time = "2025-01-23T12:44:04.465Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ba/0f/e6af71c02e0f1098eaf7d2dbf3ffdf0a69fc1e0ef174f96af05cef161f1b/pytest_aiohttp-1.1.0-py3-none-any.whl", hash = "sha256:f39a11693a0dce08dd6c542d241e199dd8047a6e6596b2bcfa60d373f143456d", size = 8932, upload-time = "2025-01-23T12:44:03.27Z" },
]
[[package]] [[package]]
name = "pytest-asyncio" name = "pytest-asyncio"
version = "1.3.0" version = "1.3.0"
@ -1302,6 +1316,7 @@ version = "0.1.0"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "aiogram" }, { name = "aiogram" },
{ name = "aiohttp" },
{ name = "httpx" }, { name = "httpx" },
{ name = "matrix-nio" }, { name = "matrix-nio" },
{ name = "pydantic" }, { name = "pydantic" },
@ -1313,6 +1328,7 @@ dependencies = [
dev = [ dev = [
{ name = "mypy" }, { name = "mypy" },
{ name = "pytest" }, { name = "pytest" },
{ name = "pytest-aiohttp" },
{ name = "pytest-asyncio" }, { name = "pytest-asyncio" },
{ name = "pytest-cov" }, { name = "pytest-cov" },
{ name = "ruff" }, { name = "ruff" },
@ -1321,11 +1337,13 @@ dev = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "aiogram", specifier = ">=3.4,<4" }, { name = "aiogram", specifier = ">=3.4,<4" },
{ name = "aiohttp", specifier = ">=3.9" },
{ name = "httpx", specifier = ">=0.27" }, { name = "httpx", specifier = ">=0.27" },
{ name = "matrix-nio", specifier = ">=0.21" }, { name = "matrix-nio", specifier = ">=0.21" },
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8" },
{ name = "pydantic", specifier = ">=2.5" }, { name = "pydantic", specifier = ">=2.5" },
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0" },
{ name = "pytest-aiohttp", marker = "extra == 'dev'", specifier = ">=1.0" },
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23" },
{ name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.1" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.1" },
{ name = "python-dotenv", specifier = ">=1.0" }, { name = "python-dotenv", specifier = ">=1.0" },