diff --git a/adapter/matrix/bot.py b/adapter/matrix/bot.py index 636d6ef..cf8adb1 100644 --- a/adapter/matrix/bot.py +++ b/adapter/matrix/bot.py @@ -30,7 +30,7 @@ from adapter.matrix.files import ( matrix_msgtype_for_attachment, resolve_workspace_attachment_path, ) -from adapter.matrix.agent_registry import AgentRegistryError, load_agent_registry +from adapter.matrix.agent_registry import AgentRegistry, AgentRegistryError, load_agent_registry from adapter.matrix.handlers import register_matrix_handlers from adapter.matrix.handlers.auth import handle_invite, provision_workspace_chat from adapter.matrix.handlers.context_commands import ( @@ -44,11 +44,13 @@ from adapter.matrix.store import ( clear_staged_attachments, get_load_pending, get_room_meta, + get_selected_agent_id, get_staged_attachments, next_platform_chat_id, remove_staged_attachment_at, set_pending_confirm, set_platform_chat_id, + set_room_agent_id, set_room_meta, ) from core.auth import AuthManager @@ -85,6 +87,7 @@ class MatrixRuntime: auth_mgr: AuthManager settings_mgr: SettingsManager dispatcher: EventDispatcher + agent_routing_enabled: bool = False def build_event_dispatcher(platform: PlatformClient, store: StateStore) -> EventDispatcher: @@ -93,6 +96,7 @@ def build_event_dispatcher(platform: PlatformClient, store: StateStore) -> Event settings_mgr = SettingsManager(platform, store) prototype_state = getattr(platform, "_prototype_state", None) agent_base_url = _agent_base_url_from_env() + registry = _load_agent_registry_from_env() dispatcher = EventDispatcher( platform=platform, chat_mgr=chat_mgr, auth_mgr=auth_mgr, settings_mgr=settings_mgr ) @@ -100,6 +104,7 @@ def build_event_dispatcher(platform: PlatformClient, store: StateStore) -> Event register_matrix_handlers( dispatcher, store=store, + registry=registry, prototype_state=prototype_state, agent_base_url=agent_base_url, ) @@ -120,19 +125,26 @@ def _agent_base_url_from_env() -> str: return "http://127.0.0.1:8000" +def _load_agent_registry_from_env(required: bool = False) -> AgentRegistry | None: + registry_path = os.environ.get("MATRIX_AGENT_REGISTRY_PATH", "").strip() + if not registry_path: + if required: + raise RuntimeError( + "MATRIX_AGENT_REGISTRY_PATH is required when MATRIX_PLATFORM_BACKEND=real" + ) + return None + try: + return load_agent_registry(registry_path) + except (AgentRegistryError, OSError) as exc: + raise RuntimeError(f"failed to load matrix agent registry: {registry_path}") from exc + + def _build_platform_from_env(*, store: StateStore, chat_mgr: ChatManager) -> PlatformClient: backend = os.environ.get("MATRIX_PLATFORM_BACKEND", "mock").strip().lower() if backend == "real": prototype_state = PrototypeStateStore() - registry_path = os.environ.get("MATRIX_AGENT_REGISTRY_PATH", "").strip() - if not registry_path: - raise RuntimeError( - "MATRIX_AGENT_REGISTRY_PATH is required when MATRIX_PLATFORM_BACKEND=real" - ) - try: - registry = load_agent_registry(registry_path) - except (AgentRegistryError, OSError) as exc: - raise RuntimeError(f"failed to load matrix agent registry: {registry_path}") from exc + registry = _load_agent_registry_from_env(required=True) + assert registry is not None delegates = { agent.agent_id: RealPlatformClient( agent_id=agent.agent_id, @@ -163,6 +175,7 @@ def build_runtime( settings_mgr = SettingsManager(platform, store) prototype_state = getattr(platform, "_prototype_state", None) agent_base_url = _agent_base_url_from_env() + registry = _load_agent_registry_from_env() dispatcher = EventDispatcher( platform=platform, chat_mgr=chat_mgr, auth_mgr=auth_mgr, settings_mgr=settings_mgr ) @@ -171,6 +184,7 @@ def build_runtime( dispatcher, client=client, store=store, + registry=registry, prototype_state=prototype_state, agent_base_url=agent_base_url, ) @@ -181,6 +195,7 @@ def build_runtime( auth_mgr=auth_mgr, settings_mgr=settings_mgr, dispatcher=dispatcher, + agent_routing_enabled=isinstance(platform, RoutedPlatformClient), ) @@ -244,6 +259,12 @@ class MatrixBot: user=sender, ) return + if not body.startswith("!") and self.runtime.agent_routing_enabled: + block = await self._check_agent_routing(room.room_id, sender, room_meta) + if block is not None: + await self._send_all(room.room_id, block) + return + local_chat_id = await resolve_chat_id(self.runtime.store, room.room_id, sender) incoming = from_room_event(event, room_id=room.room_id, chat_id=local_chat_id) if incoming is None: @@ -574,6 +595,38 @@ class MatrixBot: self.runtime.chat_mgr, ) + async def _check_agent_routing( + self, + room_id: str, + sender: str, + room_meta: dict, + ) -> list[OutgoingEvent] | None: + selected_agent_id = await get_selected_agent_id(self.runtime.store, sender) + if not selected_agent_id: + return [ + OutgoingMessage( + chat_id=room_id, + text="Выбери агент через !agent прежде чем отправлять сообщения.", + ) + ] + room_agent_id = room_meta.get("agent_id") + if room_agent_id and room_agent_id != selected_agent_id: + return [ + OutgoingMessage( + chat_id=room_id, + text=( + f"Этот чат привязан к агенту «{room_agent_id}». " + "Создай новый чат командой !new." + ), + ) + ] + if not room_agent_id: + await set_room_agent_id(self.runtime.store, room_id, selected_agent_id) + await self._ensure_platform_chat_id( + room_id, await get_room_meta(self.runtime.store, room_id) + ) + return None + async def _send_all(self, room_id: str, outgoing: list[OutgoingEvent]) -> None: for event in outgoing: await send_outgoing(self.client, room_id, event, store=self.runtime.store) diff --git a/adapter/matrix/handlers/chat.py b/adapter/matrix/handlers/chat.py index 6ce267c..b5c5dee 100644 --- a/adapter/matrix/handlers/chat.py +++ b/adapter/matrix/handlers/chat.py @@ -8,6 +8,7 @@ from nio.api import RoomVisibility from nio.responses import RoomCreateError from adapter.matrix.store import ( + get_selected_agent_id, get_user_meta, next_chat_id, next_platform_chat_id, @@ -104,18 +105,18 @@ def make_handle_new_chat( state_key=room_id, ) - await set_room_meta( - store, - room_id, - { - "room_type": "chat", - "chat_id": chat_id, - "display_name": room_name, - "matrix_user_id": event.user_id, - "space_id": space_id, - "platform_chat_id": platform_chat_id, - }, - ) + selected_agent_id = await get_selected_agent_id(store, event.user_id) + room_meta: dict = { + "room_type": "chat", + "chat_id": chat_id, + "display_name": room_name, + "matrix_user_id": event.user_id, + "space_id": space_id, + "platform_chat_id": platform_chat_id, + } + if selected_agent_id: + room_meta["agent_id"] = selected_agent_id + await set_room_meta(store, room_id, room_meta) ctx = await chat_mgr.get_or_create( user_id=event.user_id, chat_id=chat_id, diff --git a/tests/adapter/matrix/test_restart_persistence.py b/tests/adapter/matrix/test_restart_persistence.py new file mode 100644 index 0000000..492a94a --- /dev/null +++ b/tests/adapter/matrix/test_restart_persistence.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import pytest + +from core.store import SQLiteStore +from adapter.matrix.store import ( + PLATFORM_CHAT_SEQ_KEY, + get_room_meta, + get_selected_agent_id, + next_platform_chat_id, + set_room_meta, + set_selected_agent_id, +) + + +async def test_selected_agent_id_survives_restart(tmp_path): + db = str(tmp_path / "state.db") + store = SQLiteStore(db) + await set_selected_agent_id(store, "@alice:example.org", "agent-2") + + store2 = SQLiteStore(db) + assert await get_selected_agent_id(store2, "@alice:example.org") == "agent-2" + + +async def test_room_agent_id_and_platform_chat_id_survive_restart(tmp_path): + db = str(tmp_path / "state.db") + store = SQLiteStore(db) + await set_room_meta(store, "!room:example.org", { + "room_type": "chat", + "agent_id": "agent-1", + "platform_chat_id": "42", + }) + + store2 = SQLiteStore(db) + meta = await get_room_meta(store2, "!room:example.org") + assert meta is not None + assert meta["agent_id"] == "agent-1" + assert meta["platform_chat_id"] == "42" + + +async def test_platform_chat_seq_survives_restart(tmp_path): + db = str(tmp_path / "state.db") + store = SQLiteStore(db) + assert await next_platform_chat_id(store) == "1" + assert await next_platform_chat_id(store) == "2" + assert await next_platform_chat_id(store) == "3" + + store2 = SQLiteStore(db) + assert await next_platform_chat_id(store2) == "4" + + +async def test_routing_state_survives_restart_and_routes_correctly(tmp_path): + db = str(tmp_path / "state.db") + store = SQLiteStore(db) + await set_selected_agent_id(store, "@bob:example.org", "agent-1") + await set_room_meta(store, "!convo:example.org", { + "room_type": "chat", + "agent_id": "agent-1", + "platform_chat_id": "10", + }) + + store2 = SQLiteStore(db) + selected = await get_selected_agent_id(store2, "@bob:example.org") + meta = await get_room_meta(store2, "!convo:example.org") + assert selected == "agent-1" + assert meta is not None + assert meta["agent_id"] == selected + assert meta["platform_chat_id"] == "10" + + +async def test_missing_durable_store_starts_clean(tmp_path): + db = str(tmp_path / "brand_new.db") + store = SQLiteStore(db) + assert await get_selected_agent_id(store, "@nobody:example.org") is None + assert await get_room_meta(store, "!nonexistent:example.org") is None diff --git a/tests/adapter/matrix/test_routing_enforcement.py b/tests/adapter/matrix/test_routing_enforcement.py new file mode 100644 index 0000000..c9a7869 --- /dev/null +++ b/tests/adapter/matrix/test_routing_enforcement.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from adapter.matrix.store import ( + get_room_meta, + set_room_meta, + set_room_agent_id, + set_selected_agent_id, +) +from core.protocol import IncomingCommand, OutgoingMessage +from core.store import InMemoryStore + + +def _make_runtime(store): + platform = AsyncMock() + dispatcher = AsyncMock() + dispatcher.dispatch.return_value = [OutgoingMessage(chat_id="!r:s", text="ok")] + runtime = MagicMock() + runtime.store = store + runtime.dispatcher = dispatcher + runtime.platform = platform + runtime.agent_routing_enabled = True + return runtime + + +def _make_bot(store): + from adapter.matrix.bot import MatrixBot + client = MagicMock() + client.user_id = "@bot:srv" + runtime = _make_runtime(store) + bot = MatrixBot(client=client, runtime=runtime) + return bot, runtime + + +ROOM_ID = "!room:srv" +USER_ID = "@alice:srv" + + +async def _send_message(bot, body): + from nio import RoomMessageText, MatrixRoom + room = MagicMock(spec=MatrixRoom) + room.room_id = ROOM_ID + event = MagicMock(spec=RoomMessageText) + event.sender = USER_ID + event.body = body + event.source = {} + bot._send_all = AsyncMock() + await bot.on_room_message(room, event) + return bot._send_all + + +async def test_stale_room_blocks_normal_message(): + store = InMemoryStore() + await set_room_meta(store, ROOM_ID, {"room_type": "chat", "matrix_user_id": USER_ID, + "platform_chat_id": "1", "agent_id": "agent-1"}) + await set_selected_agent_id(store, USER_ID, "agent-2") + bot, runtime = _make_bot(store) + send_all = await _send_message(bot, "hello") + runtime.dispatcher.dispatch.assert_not_called() + args = send_all.call_args[0] + assert any("agent-1" in m.text and "!new" in m.text for m in args[1]) + + +async def test_stale_room_allows_commands(): + store = InMemoryStore() + await set_room_meta(store, ROOM_ID, {"room_type": "chat", "matrix_user_id": USER_ID, + "platform_chat_id": "1", "agent_id": "agent-1"}) + await set_selected_agent_id(store, USER_ID, "agent-2") + bot, runtime = _make_bot(store) + await _send_message(bot, "!help") + runtime.dispatcher.dispatch.assert_called_once() + + +async def test_no_selected_agent_blocks_normal_message(): + store = InMemoryStore() + await set_room_meta(store, ROOM_ID, {"room_type": "chat", "matrix_user_id": USER_ID, + "platform_chat_id": "1"}) + bot, runtime = _make_bot(store) + send_all = await _send_message(bot, "hello") + runtime.dispatcher.dispatch.assert_not_called() + args = send_all.call_args[0] + assert any("!agent" in m.text for m in args[1]) + + +async def test_no_selected_agent_allows_commands(): + store = InMemoryStore() + await set_room_meta(store, ROOM_ID, {"room_type": "chat", "matrix_user_id": USER_ID, + "platform_chat_id": "1"}) + bot, runtime = _make_bot(store) + await _send_message(bot, "!agent") + runtime.dispatcher.dispatch.assert_called_once() + + +async def test_unbound_room_binds_on_message_when_agent_selected(): + store = InMemoryStore() + await set_room_meta(store, ROOM_ID, {"room_type": "chat", "matrix_user_id": USER_ID, + "platform_chat_id": "1"}) + await set_selected_agent_id(store, USER_ID, "agent-1") + bot, runtime = _make_bot(store) + await _send_message(bot, "hello") + meta = await get_room_meta(store, ROOM_ID) + assert meta["agent_id"] == "agent-1" + runtime.dispatcher.dispatch.assert_called_once()