feat: enforce agent routing and persist restart state
Task 4: stale room blocking + agent_id binding - MatrixBot._check_agent_routing: blocks normal messages when user has no selected agent or room is bound to a different agent - agent_routing_enabled flag on MatrixRuntime activates the check only in real multi-agent mode (RoutedPlatformClient) - make_handle_new_chat now writes agent_id into new room metadata when user already has a selected agent Task 5: durable restart state tests - test_restart_persistence.py proves selected_agent_id, room agent_id, platform_chat_id, and the sequence counter all survive SQLiteStore close/reopen; also covers clean startup with no prior state
This commit is contained in:
parent
74cf028e8f
commit
e733119d1e
4 changed files with 256 additions and 22 deletions
|
|
@ -30,7 +30,7 @@ from adapter.matrix.files import (
|
||||||
matrix_msgtype_for_attachment,
|
matrix_msgtype_for_attachment,
|
||||||
resolve_workspace_attachment_path,
|
resolve_workspace_attachment_path,
|
||||||
)
|
)
|
||||||
from adapter.matrix.agent_registry import AgentRegistryError, load_agent_registry
|
from adapter.matrix.agent_registry import AgentRegistry, AgentRegistryError, load_agent_registry
|
||||||
from adapter.matrix.handlers import register_matrix_handlers
|
from adapter.matrix.handlers import register_matrix_handlers
|
||||||
from adapter.matrix.handlers.auth import handle_invite, provision_workspace_chat
|
from adapter.matrix.handlers.auth import handle_invite, provision_workspace_chat
|
||||||
from adapter.matrix.handlers.context_commands import (
|
from adapter.matrix.handlers.context_commands import (
|
||||||
|
|
@ -44,11 +44,13 @@ from adapter.matrix.store import (
|
||||||
clear_staged_attachments,
|
clear_staged_attachments,
|
||||||
get_load_pending,
|
get_load_pending,
|
||||||
get_room_meta,
|
get_room_meta,
|
||||||
|
get_selected_agent_id,
|
||||||
get_staged_attachments,
|
get_staged_attachments,
|
||||||
next_platform_chat_id,
|
next_platform_chat_id,
|
||||||
remove_staged_attachment_at,
|
remove_staged_attachment_at,
|
||||||
set_pending_confirm,
|
set_pending_confirm,
|
||||||
set_platform_chat_id,
|
set_platform_chat_id,
|
||||||
|
set_room_agent_id,
|
||||||
set_room_meta,
|
set_room_meta,
|
||||||
)
|
)
|
||||||
from core.auth import AuthManager
|
from core.auth import AuthManager
|
||||||
|
|
@ -85,6 +87,7 @@ class MatrixRuntime:
|
||||||
auth_mgr: AuthManager
|
auth_mgr: AuthManager
|
||||||
settings_mgr: SettingsManager
|
settings_mgr: SettingsManager
|
||||||
dispatcher: EventDispatcher
|
dispatcher: EventDispatcher
|
||||||
|
agent_routing_enabled: bool = False
|
||||||
|
|
||||||
|
|
||||||
def build_event_dispatcher(platform: PlatformClient, store: StateStore) -> EventDispatcher:
|
def build_event_dispatcher(platform: PlatformClient, store: StateStore) -> EventDispatcher:
|
||||||
|
|
@ -93,6 +96,7 @@ def build_event_dispatcher(platform: PlatformClient, store: StateStore) -> Event
|
||||||
settings_mgr = SettingsManager(platform, store)
|
settings_mgr = SettingsManager(platform, store)
|
||||||
prototype_state = getattr(platform, "_prototype_state", None)
|
prototype_state = getattr(platform, "_prototype_state", None)
|
||||||
agent_base_url = _agent_base_url_from_env()
|
agent_base_url = _agent_base_url_from_env()
|
||||||
|
registry = _load_agent_registry_from_env()
|
||||||
dispatcher = EventDispatcher(
|
dispatcher = EventDispatcher(
|
||||||
platform=platform, chat_mgr=chat_mgr, auth_mgr=auth_mgr, settings_mgr=settings_mgr
|
platform=platform, chat_mgr=chat_mgr, auth_mgr=auth_mgr, settings_mgr=settings_mgr
|
||||||
)
|
)
|
||||||
|
|
@ -100,6 +104,7 @@ def build_event_dispatcher(platform: PlatformClient, store: StateStore) -> Event
|
||||||
register_matrix_handlers(
|
register_matrix_handlers(
|
||||||
dispatcher,
|
dispatcher,
|
||||||
store=store,
|
store=store,
|
||||||
|
registry=registry,
|
||||||
prototype_state=prototype_state,
|
prototype_state=prototype_state,
|
||||||
agent_base_url=agent_base_url,
|
agent_base_url=agent_base_url,
|
||||||
)
|
)
|
||||||
|
|
@ -120,19 +125,26 @@ def _agent_base_url_from_env() -> str:
|
||||||
return "http://127.0.0.1:8000"
|
return "http://127.0.0.1:8000"
|
||||||
|
|
||||||
|
|
||||||
|
def _load_agent_registry_from_env(required: bool = False) -> AgentRegistry | None:
|
||||||
|
registry_path = os.environ.get("MATRIX_AGENT_REGISTRY_PATH", "").strip()
|
||||||
|
if not registry_path:
|
||||||
|
if required:
|
||||||
|
raise RuntimeError(
|
||||||
|
"MATRIX_AGENT_REGISTRY_PATH is required when MATRIX_PLATFORM_BACKEND=real"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return load_agent_registry(registry_path)
|
||||||
|
except (AgentRegistryError, OSError) as exc:
|
||||||
|
raise RuntimeError(f"failed to load matrix agent registry: {registry_path}") from exc
|
||||||
|
|
||||||
|
|
||||||
def _build_platform_from_env(*, store: StateStore, chat_mgr: ChatManager) -> PlatformClient:
|
def _build_platform_from_env(*, store: StateStore, chat_mgr: ChatManager) -> PlatformClient:
|
||||||
backend = os.environ.get("MATRIX_PLATFORM_BACKEND", "mock").strip().lower()
|
backend = os.environ.get("MATRIX_PLATFORM_BACKEND", "mock").strip().lower()
|
||||||
if backend == "real":
|
if backend == "real":
|
||||||
prototype_state = PrototypeStateStore()
|
prototype_state = PrototypeStateStore()
|
||||||
registry_path = os.environ.get("MATRIX_AGENT_REGISTRY_PATH", "").strip()
|
registry = _load_agent_registry_from_env(required=True)
|
||||||
if not registry_path:
|
assert registry is not None
|
||||||
raise RuntimeError(
|
|
||||||
"MATRIX_AGENT_REGISTRY_PATH is required when MATRIX_PLATFORM_BACKEND=real"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
registry = load_agent_registry(registry_path)
|
|
||||||
except (AgentRegistryError, OSError) as exc:
|
|
||||||
raise RuntimeError(f"failed to load matrix agent registry: {registry_path}") from exc
|
|
||||||
delegates = {
|
delegates = {
|
||||||
agent.agent_id: RealPlatformClient(
|
agent.agent_id: RealPlatformClient(
|
||||||
agent_id=agent.agent_id,
|
agent_id=agent.agent_id,
|
||||||
|
|
@ -163,6 +175,7 @@ def build_runtime(
|
||||||
settings_mgr = SettingsManager(platform, store)
|
settings_mgr = SettingsManager(platform, store)
|
||||||
prototype_state = getattr(platform, "_prototype_state", None)
|
prototype_state = getattr(platform, "_prototype_state", None)
|
||||||
agent_base_url = _agent_base_url_from_env()
|
agent_base_url = _agent_base_url_from_env()
|
||||||
|
registry = _load_agent_registry_from_env()
|
||||||
dispatcher = EventDispatcher(
|
dispatcher = EventDispatcher(
|
||||||
platform=platform, chat_mgr=chat_mgr, auth_mgr=auth_mgr, settings_mgr=settings_mgr
|
platform=platform, chat_mgr=chat_mgr, auth_mgr=auth_mgr, settings_mgr=settings_mgr
|
||||||
)
|
)
|
||||||
|
|
@ -171,6 +184,7 @@ def build_runtime(
|
||||||
dispatcher,
|
dispatcher,
|
||||||
client=client,
|
client=client,
|
||||||
store=store,
|
store=store,
|
||||||
|
registry=registry,
|
||||||
prototype_state=prototype_state,
|
prototype_state=prototype_state,
|
||||||
agent_base_url=agent_base_url,
|
agent_base_url=agent_base_url,
|
||||||
)
|
)
|
||||||
|
|
@ -181,6 +195,7 @@ def build_runtime(
|
||||||
auth_mgr=auth_mgr,
|
auth_mgr=auth_mgr,
|
||||||
settings_mgr=settings_mgr,
|
settings_mgr=settings_mgr,
|
||||||
dispatcher=dispatcher,
|
dispatcher=dispatcher,
|
||||||
|
agent_routing_enabled=isinstance(platform, RoutedPlatformClient),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -244,6 +259,12 @@ class MatrixBot:
|
||||||
user=sender,
|
user=sender,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
if not body.startswith("!") and self.runtime.agent_routing_enabled:
|
||||||
|
block = await self._check_agent_routing(room.room_id, sender, room_meta)
|
||||||
|
if block is not None:
|
||||||
|
await self._send_all(room.room_id, block)
|
||||||
|
return
|
||||||
|
|
||||||
local_chat_id = await resolve_chat_id(self.runtime.store, room.room_id, sender)
|
local_chat_id = await resolve_chat_id(self.runtime.store, room.room_id, sender)
|
||||||
incoming = from_room_event(event, room_id=room.room_id, chat_id=local_chat_id)
|
incoming = from_room_event(event, room_id=room.room_id, chat_id=local_chat_id)
|
||||||
if incoming is None:
|
if incoming is None:
|
||||||
|
|
@ -574,6 +595,38 @@ class MatrixBot:
|
||||||
self.runtime.chat_mgr,
|
self.runtime.chat_mgr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _check_agent_routing(
|
||||||
|
self,
|
||||||
|
room_id: str,
|
||||||
|
sender: str,
|
||||||
|
room_meta: dict,
|
||||||
|
) -> list[OutgoingEvent] | None:
|
||||||
|
selected_agent_id = await get_selected_agent_id(self.runtime.store, sender)
|
||||||
|
if not selected_agent_id:
|
||||||
|
return [
|
||||||
|
OutgoingMessage(
|
||||||
|
chat_id=room_id,
|
||||||
|
text="Выбери агент через !agent прежде чем отправлять сообщения.",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
room_agent_id = room_meta.get("agent_id")
|
||||||
|
if room_agent_id and room_agent_id != selected_agent_id:
|
||||||
|
return [
|
||||||
|
OutgoingMessage(
|
||||||
|
chat_id=room_id,
|
||||||
|
text=(
|
||||||
|
f"Этот чат привязан к агенту «{room_agent_id}». "
|
||||||
|
"Создай новый чат командой !new."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if not room_agent_id:
|
||||||
|
await set_room_agent_id(self.runtime.store, room_id, selected_agent_id)
|
||||||
|
await self._ensure_platform_chat_id(
|
||||||
|
room_id, await get_room_meta(self.runtime.store, room_id)
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
async def _send_all(self, room_id: str, outgoing: list[OutgoingEvent]) -> None:
|
async def _send_all(self, room_id: str, outgoing: list[OutgoingEvent]) -> None:
|
||||||
for event in outgoing:
|
for event in outgoing:
|
||||||
await send_outgoing(self.client, room_id, event, store=self.runtime.store)
|
await send_outgoing(self.client, room_id, event, store=self.runtime.store)
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ from nio.api import RoomVisibility
|
||||||
from nio.responses import RoomCreateError
|
from nio.responses import RoomCreateError
|
||||||
|
|
||||||
from adapter.matrix.store import (
|
from adapter.matrix.store import (
|
||||||
|
get_selected_agent_id,
|
||||||
get_user_meta,
|
get_user_meta,
|
||||||
next_chat_id,
|
next_chat_id,
|
||||||
next_platform_chat_id,
|
next_platform_chat_id,
|
||||||
|
|
@ -104,18 +105,18 @@ def make_handle_new_chat(
|
||||||
state_key=room_id,
|
state_key=room_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
await set_room_meta(
|
selected_agent_id = await get_selected_agent_id(store, event.user_id)
|
||||||
store,
|
room_meta: dict = {
|
||||||
room_id,
|
"room_type": "chat",
|
||||||
{
|
"chat_id": chat_id,
|
||||||
"room_type": "chat",
|
"display_name": room_name,
|
||||||
"chat_id": chat_id,
|
"matrix_user_id": event.user_id,
|
||||||
"display_name": room_name,
|
"space_id": space_id,
|
||||||
"matrix_user_id": event.user_id,
|
"platform_chat_id": platform_chat_id,
|
||||||
"space_id": space_id,
|
}
|
||||||
"platform_chat_id": platform_chat_id,
|
if selected_agent_id:
|
||||||
},
|
room_meta["agent_id"] = selected_agent_id
|
||||||
)
|
await set_room_meta(store, room_id, room_meta)
|
||||||
ctx = await chat_mgr.get_or_create(
|
ctx = await chat_mgr.get_or_create(
|
||||||
user_id=event.user_id,
|
user_id=event.user_id,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
|
|
|
||||||
75
tests/adapter/matrix/test_restart_persistence.py
Normal file
75
tests/adapter/matrix/test_restart_persistence.py
Normal file
|
|
@ -0,0 +1,75 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.store import SQLiteStore
|
||||||
|
from adapter.matrix.store import (
|
||||||
|
PLATFORM_CHAT_SEQ_KEY,
|
||||||
|
get_room_meta,
|
||||||
|
get_selected_agent_id,
|
||||||
|
next_platform_chat_id,
|
||||||
|
set_room_meta,
|
||||||
|
set_selected_agent_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_selected_agent_id_survives_restart(tmp_path):
|
||||||
|
db = str(tmp_path / "state.db")
|
||||||
|
store = SQLiteStore(db)
|
||||||
|
await set_selected_agent_id(store, "@alice:example.org", "agent-2")
|
||||||
|
|
||||||
|
store2 = SQLiteStore(db)
|
||||||
|
assert await get_selected_agent_id(store2, "@alice:example.org") == "agent-2"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_room_agent_id_and_platform_chat_id_survive_restart(tmp_path):
|
||||||
|
db = str(tmp_path / "state.db")
|
||||||
|
store = SQLiteStore(db)
|
||||||
|
await set_room_meta(store, "!room:example.org", {
|
||||||
|
"room_type": "chat",
|
||||||
|
"agent_id": "agent-1",
|
||||||
|
"platform_chat_id": "42",
|
||||||
|
})
|
||||||
|
|
||||||
|
store2 = SQLiteStore(db)
|
||||||
|
meta = await get_room_meta(store2, "!room:example.org")
|
||||||
|
assert meta is not None
|
||||||
|
assert meta["agent_id"] == "agent-1"
|
||||||
|
assert meta["platform_chat_id"] == "42"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_platform_chat_seq_survives_restart(tmp_path):
|
||||||
|
db = str(tmp_path / "state.db")
|
||||||
|
store = SQLiteStore(db)
|
||||||
|
assert await next_platform_chat_id(store) == "1"
|
||||||
|
assert await next_platform_chat_id(store) == "2"
|
||||||
|
assert await next_platform_chat_id(store) == "3"
|
||||||
|
|
||||||
|
store2 = SQLiteStore(db)
|
||||||
|
assert await next_platform_chat_id(store2) == "4"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_routing_state_survives_restart_and_routes_correctly(tmp_path):
|
||||||
|
db = str(tmp_path / "state.db")
|
||||||
|
store = SQLiteStore(db)
|
||||||
|
await set_selected_agent_id(store, "@bob:example.org", "agent-1")
|
||||||
|
await set_room_meta(store, "!convo:example.org", {
|
||||||
|
"room_type": "chat",
|
||||||
|
"agent_id": "agent-1",
|
||||||
|
"platform_chat_id": "10",
|
||||||
|
})
|
||||||
|
|
||||||
|
store2 = SQLiteStore(db)
|
||||||
|
selected = await get_selected_agent_id(store2, "@bob:example.org")
|
||||||
|
meta = await get_room_meta(store2, "!convo:example.org")
|
||||||
|
assert selected == "agent-1"
|
||||||
|
assert meta is not None
|
||||||
|
assert meta["agent_id"] == selected
|
||||||
|
assert meta["platform_chat_id"] == "10"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_missing_durable_store_starts_clean(tmp_path):
|
||||||
|
db = str(tmp_path / "brand_new.db")
|
||||||
|
store = SQLiteStore(db)
|
||||||
|
assert await get_selected_agent_id(store, "@nobody:example.org") is None
|
||||||
|
assert await get_room_meta(store, "!nonexistent:example.org") is None
|
||||||
105
tests/adapter/matrix/test_routing_enforcement.py
Normal file
105
tests/adapter/matrix/test_routing_enforcement.py
Normal file
|
|
@ -0,0 +1,105 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
from adapter.matrix.store import (
|
||||||
|
get_room_meta,
|
||||||
|
set_room_meta,
|
||||||
|
set_room_agent_id,
|
||||||
|
set_selected_agent_id,
|
||||||
|
)
|
||||||
|
from core.protocol import IncomingCommand, OutgoingMessage
|
||||||
|
from core.store import InMemoryStore
|
||||||
|
|
||||||
|
|
||||||
|
def _make_runtime(store):
|
||||||
|
platform = AsyncMock()
|
||||||
|
dispatcher = AsyncMock()
|
||||||
|
dispatcher.dispatch.return_value = [OutgoingMessage(chat_id="!r:s", text="ok")]
|
||||||
|
runtime = MagicMock()
|
||||||
|
runtime.store = store
|
||||||
|
runtime.dispatcher = dispatcher
|
||||||
|
runtime.platform = platform
|
||||||
|
runtime.agent_routing_enabled = True
|
||||||
|
return runtime
|
||||||
|
|
||||||
|
|
||||||
|
def _make_bot(store):
|
||||||
|
from adapter.matrix.bot import MatrixBot
|
||||||
|
client = MagicMock()
|
||||||
|
client.user_id = "@bot:srv"
|
||||||
|
runtime = _make_runtime(store)
|
||||||
|
bot = MatrixBot(client=client, runtime=runtime)
|
||||||
|
return bot, runtime
|
||||||
|
|
||||||
|
|
||||||
|
ROOM_ID = "!room:srv"
|
||||||
|
USER_ID = "@alice:srv"
|
||||||
|
|
||||||
|
|
||||||
|
async def _send_message(bot, body):
|
||||||
|
from nio import RoomMessageText, MatrixRoom
|
||||||
|
room = MagicMock(spec=MatrixRoom)
|
||||||
|
room.room_id = ROOM_ID
|
||||||
|
event = MagicMock(spec=RoomMessageText)
|
||||||
|
event.sender = USER_ID
|
||||||
|
event.body = body
|
||||||
|
event.source = {}
|
||||||
|
bot._send_all = AsyncMock()
|
||||||
|
await bot.on_room_message(room, event)
|
||||||
|
return bot._send_all
|
||||||
|
|
||||||
|
|
||||||
|
async def test_stale_room_blocks_normal_message():
|
||||||
|
store = InMemoryStore()
|
||||||
|
await set_room_meta(store, ROOM_ID, {"room_type": "chat", "matrix_user_id": USER_ID,
|
||||||
|
"platform_chat_id": "1", "agent_id": "agent-1"})
|
||||||
|
await set_selected_agent_id(store, USER_ID, "agent-2")
|
||||||
|
bot, runtime = _make_bot(store)
|
||||||
|
send_all = await _send_message(bot, "hello")
|
||||||
|
runtime.dispatcher.dispatch.assert_not_called()
|
||||||
|
args = send_all.call_args[0]
|
||||||
|
assert any("agent-1" in m.text and "!new" in m.text for m in args[1])
|
||||||
|
|
||||||
|
|
||||||
|
async def test_stale_room_allows_commands():
|
||||||
|
store = InMemoryStore()
|
||||||
|
await set_room_meta(store, ROOM_ID, {"room_type": "chat", "matrix_user_id": USER_ID,
|
||||||
|
"platform_chat_id": "1", "agent_id": "agent-1"})
|
||||||
|
await set_selected_agent_id(store, USER_ID, "agent-2")
|
||||||
|
bot, runtime = _make_bot(store)
|
||||||
|
await _send_message(bot, "!help")
|
||||||
|
runtime.dispatcher.dispatch.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_no_selected_agent_blocks_normal_message():
|
||||||
|
store = InMemoryStore()
|
||||||
|
await set_room_meta(store, ROOM_ID, {"room_type": "chat", "matrix_user_id": USER_ID,
|
||||||
|
"platform_chat_id": "1"})
|
||||||
|
bot, runtime = _make_bot(store)
|
||||||
|
send_all = await _send_message(bot, "hello")
|
||||||
|
runtime.dispatcher.dispatch.assert_not_called()
|
||||||
|
args = send_all.call_args[0]
|
||||||
|
assert any("!agent" in m.text for m in args[1])
|
||||||
|
|
||||||
|
|
||||||
|
async def test_no_selected_agent_allows_commands():
|
||||||
|
store = InMemoryStore()
|
||||||
|
await set_room_meta(store, ROOM_ID, {"room_type": "chat", "matrix_user_id": USER_ID,
|
||||||
|
"platform_chat_id": "1"})
|
||||||
|
bot, runtime = _make_bot(store)
|
||||||
|
await _send_message(bot, "!agent")
|
||||||
|
runtime.dispatcher.dispatch.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_unbound_room_binds_on_message_when_agent_selected():
|
||||||
|
store = InMemoryStore()
|
||||||
|
await set_room_meta(store, ROOM_ID, {"room_type": "chat", "matrix_user_id": USER_ID,
|
||||||
|
"platform_chat_id": "1"})
|
||||||
|
await set_selected_agent_id(store, USER_ID, "agent-1")
|
||||||
|
bot, runtime = _make_bot(store)
|
||||||
|
await _send_message(bot, "hello")
|
||||||
|
meta = await get_room_meta(store, ROOM_ID)
|
||||||
|
assert meta["agent_id"] == "agent-1"
|
||||||
|
runtime.dispatcher.dispatch.assert_called_once()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue