diff --git a/adapter/matrix/bot.py b/adapter/matrix/bot.py index cc1146d..cf8fc2a 100644 --- a/adapter/matrix/bot.py +++ b/adapter/matrix/bot.py @@ -297,7 +297,19 @@ class MatrixBot: await clear_load_pending(self.runtime.store, user_id, room_id) prototype_state = getattr(self.runtime.platform, "_prototype_state", None) if prototype_state is not None: - await prototype_state.set_current_session(user_id, name) + room_meta = await get_room_meta(self.runtime.store, room_id) + context_keys = [] + if room_meta is not None: + platform_chat_id = room_meta.get("platform_chat_id") + if platform_chat_id: + context_keys.append(platform_chat_id) + chat_id = room_meta.get("chat_id") + if chat_id: + context_keys.append(chat_id) + if not context_keys: + context_keys.append(room_id) + for context_key in dict.fromkeys(context_keys): + await prototype_state.set_current_session(context_key, name) try: await self.runtime.platform.send_message( diff --git a/adapter/matrix/handlers/context_commands.py b/adapter/matrix/handlers/context_commands.py index 921cfc4..ff52223 100644 --- a/adapter/matrix/handlers/context_commands.py +++ b/adapter/matrix/handlers/context_commands.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING import httpx import structlog -from adapter.matrix.store import set_load_pending, set_reset_pending +from adapter.matrix.store import get_room_meta, set_load_pending, set_reset_pending from core.protocol import IncomingCommand, OutgoingEvent, OutgoingMessage if TYPE_CHECKING: @@ -43,6 +43,17 @@ async def _resolve_room_id(event: IncomingCommand, chat_mgr) -> str: return event.chat_id +async def _resolve_context_scope( + event: IncomingCommand, + store: "StateStore", + chat_mgr, +) -> tuple[str, str | None]: + room_id = await _resolve_room_id(event, chat_mgr) + room_meta = await get_room_meta(store, room_id) + platform_chat_id = room_meta.get("platform_chat_id") if room_meta else None + return room_id, platform_chat_id + + def make_handle_save(agent_api, store: "StateStore", prototype_state: "PrototypeStateStore"): async def handle_save( event: IncomingCommand, auth_mgr, platform, chat_mgr, settings_mgr @@ -69,8 +80,18 @@ def make_handle_save(agent_api, store: "StateStore", prototype_state: "Prototype logger.warning("save_agent_call_failed", error=str(exc)) return [OutgoingMessage(chat_id=event.chat_id, text=f"Ошибка при сохранении: {exc}")] - await prototype_state.add_saved_session(event.user_id, name) - return [OutgoingMessage(chat_id=event.chat_id, text=f"Сохранение запущено: {name}")] + _, platform_chat_id = await _resolve_context_scope(event, store, chat_mgr) + await prototype_state.add_saved_session( + event.user_id, + name, + source_context_id=platform_chat_id or event.chat_id, + ) + return [ + OutgoingMessage( + chat_id=event.chat_id, + text=f"Запрос на сохранение отправлен агенту: {name}", + ) + ] return handle_save @@ -88,7 +109,7 @@ def make_handle_load(store: "StateStore", prototype_state: "PrototypeStateStore" ) ] - room_id = await _resolve_room_id(event, chat_mgr) + room_id, _ = await _resolve_context_scope(event, store, chat_mgr) lines = ["Сохранённые сессии:"] for index, session in enumerate(sessions, start=1): created = session.get("created_at", "")[:10] @@ -150,12 +171,20 @@ def make_handle_context(store: "StateStore", prototype_state: "PrototypeStateSto async def handle_context( event: IncomingCommand, auth_mgr, platform, chat_mgr, settings_mgr ) -> list[OutgoingEvent]: - current_session = await prototype_state.get_current_session(event.user_id) - tokens_used = await prototype_state.get_last_tokens_used(event.user_id) + _, platform_chat_id = await _resolve_context_scope(event, store, chat_mgr) + context_key = platform_chat_id or event.chat_id + current_session = await prototype_state.get_current_session(context_key) + tokens_used = await prototype_state.get_last_tokens_used(context_key) + if platform_chat_id is not None and event.chat_id != platform_chat_id: + if current_session is None: + current_session = await prototype_state.get_current_session(event.chat_id) + if tokens_used == 0: + tokens_used = await prototype_state.get_last_tokens_used(event.chat_id) sessions = await prototype_state.list_saved_sessions(event.user_id) lines = [ "Контекст:", + f" Контекст чата: {platform_chat_id or event.chat_id}", f" Сессия: {current_session or 'не загружена'}", f" Токены (последний ответ): {tokens_used}", f" Сохранения ({len(sessions)}):", diff --git a/sdk/prototype_state.py b/sdk/prototype_state.py index a40878f..6e5fd41 100644 --- a/sdk/prototype_state.py +++ b/sdk/prototype_state.py @@ -32,8 +32,8 @@ class PrototypeStateStore: self._users: dict[str, User] = {} self._settings: dict[str, dict[str, Any]] = {} self._saved_sessions: dict[str, list[dict[str, str]]] = {} - self._last_tokens_used: dict[str, int] = {} - self._current_session: dict[str, str] = {} + self._context_last_tokens_used: dict[str, int] = {} + self._context_current_session: dict[str, str] = {} async def get_or_create_user( self, @@ -82,24 +82,48 @@ class PrototypeStateStore: safety = settings.setdefault("safety", DEFAULT_SAFETY.copy()) safety[action.payload["trigger"]] = action.payload.get("enabled", True) - async def add_saved_session(self, user_id: str, name: str) -> None: + async def add_saved_session( + self, + user_id: str, + name: str, + *, + source_context_id: str | None = None, + ) -> None: sessions = self._saved_sessions.setdefault(user_id, []) - sessions.append({"name": name, "created_at": datetime.now(UTC).isoformat()}) + session = {"name": name, "created_at": datetime.now(UTC).isoformat()} + if source_context_id is not None: + session["source_context_id"] = source_context_id + sessions.append(session) async def list_saved_sessions(self, user_id: str) -> list[dict[str, str]]: - return list(self._saved_sessions.get(user_id, [])) + return [dict(session) for session in self._saved_sessions.get(user_id, [])] - async def get_last_tokens_used(self, user_id: str) -> int: - return self._last_tokens_used.get(user_id, 0) + async def get_last_tokens_used_for_context(self, context_id: str) -> int: + return self._context_last_tokens_used.get(context_id, 0) - async def set_last_tokens_used(self, user_id: str, tokens: int) -> None: - self._last_tokens_used[user_id] = tokens + async def set_last_tokens_used_for_context(self, context_id: str, tokens: int) -> None: + self._context_last_tokens_used[context_id] = tokens - async def get_current_session(self, user_id: str) -> str | None: - return self._current_session.get(user_id) + async def get_current_session_for_context(self, context_id: str) -> str | None: + return self._context_current_session.get(context_id) - async def set_current_session(self, user_id: str, name: str) -> None: - self._current_session[user_id] = name + async def set_current_session_for_context(self, context_id: str, name: str) -> None: + self._context_current_session[context_id] = name - async def clear_current_session(self, user_id: str) -> None: - self._current_session.pop(user_id, None) + async def clear_current_session_for_context(self, context_id: str) -> None: + self._context_current_session.pop(context_id, None) + + async def get_last_tokens_used(self, context_id: str) -> int: + return await self.get_last_tokens_used_for_context(context_id) + + async def set_last_tokens_used(self, context_id: str, tokens: int) -> None: + await self.set_last_tokens_used_for_context(context_id, tokens) + + async def get_current_session(self, context_id: str) -> str | None: + return await self.get_current_session_for_context(context_id) + + async def set_current_session(self, context_id: str, name: str) -> None: + await self.set_current_session_for_context(context_id, name) + + async def clear_current_session(self, context_id: str) -> None: + await self.clear_current_session_for_context(context_id) diff --git a/sdk/real.py b/sdk/real.py index 291b724..8641b69 100644 --- a/sdk/real.py +++ b/sdk/real.py @@ -105,11 +105,13 @@ class RealPlatformClient(PlatformClient): delta=event.text, finished=False, ) + tokens_used = getattr(chat_api, "last_tokens_used", 0) + await self._prototype_state.set_last_tokens_used(str(chat_id), tokens_used) yield MessageChunk( message_id=user_id, delta="", finished=True, - tokens_used=getattr(chat_api, "last_tokens_used", 0), + tokens_used=tokens_used, ) async def get_settings(self, user_id: str) -> UserSettings: diff --git a/tests/adapter/matrix/test_context_commands.py b/tests/adapter/matrix/test_context_commands.py index 2339c05..517f605 100644 --- a/tests/adapter/matrix/test_context_commands.py +++ b/tests/adapter/matrix/test_context_commands.py @@ -13,7 +13,12 @@ from adapter.matrix.handlers.context_commands import ( make_handle_reset, make_handle_save, ) -from adapter.matrix.store import get_load_pending, get_reset_pending, set_load_pending, set_reset_pending +from adapter.matrix.store import ( + get_load_pending, + get_reset_pending, + set_load_pending, + set_room_meta, +) from core.protocol import IncomingCommand, OutgoingMessage from core.store import InMemoryStore from sdk.interface import MessageResponse @@ -40,6 +45,11 @@ class MatrixCommandPlatform(MockPlatformClient): async def test_save_command_auto_name_records_session(): platform = MatrixCommandPlatform() store = InMemoryStore() + await set_room_meta( + store, + "!room:example.org", + {"chat_id": "C1", "matrix_user_id": "u1", "platform_chat_id": "matrix:room-1"}, + ) handler = make_handle_save( agent_api=platform._agent_api, store=store, @@ -57,16 +67,22 @@ async def test_save_command_auto_name_records_session(): assert len(result) == 1 assert isinstance(result[0], OutgoingMessage) - assert "Сохранение запущено" in result[0].text + assert "Запрос на сохранение отправлен агенту" in result[0].text sessions = await platform._prototype_state.list_saved_sessions("u1") assert len(sessions) == 1 assert sessions[0]["name"].startswith("context-") + assert sessions[0]["source_context_id"] == "matrix:room-1" @pytest.mark.asyncio async def test_save_command_with_name_uses_given_name(): platform = MatrixCommandPlatform() store = InMemoryStore() + await set_room_meta( + store, + "!room:example.org", + {"chat_id": "C1", "matrix_user_id": "u1", "platform_chat_id": "matrix:room-1"}, + ) handler = make_handle_save( agent_api=platform._agent_api, store=store, @@ -164,15 +180,28 @@ async def test_reset_endpoint_unavailable_reports_error(): @pytest.mark.asyncio async def test_context_command_shows_current_snapshot(): platform = MatrixCommandPlatform() - store = InMemoryStore() - await platform._prototype_state.set_current_session("u1", "session-a") - await platform._prototype_state.set_last_tokens_used("u1", 99) + runtime = build_runtime(platform=platform) + await runtime.chat_mgr.get_or_create( + user_id="u1", + chat_id="C1", + platform="matrix", + surface_ref="!room:example.org", + name="Chat 1", + ) + await set_room_meta( + runtime.store, + "!room:example.org", + {"chat_id": "C1", "matrix_user_id": "u1", "platform_chat_id": "matrix:room-1"}, + ) + await platform._prototype_state.set_current_session("matrix:room-1", "session-a") + await platform._prototype_state.set_last_tokens_used("matrix:room-1", 99) await platform._prototype_state.add_saved_session("u1", "session-a") - handler = make_handle_context(store=store, prototype_state=platform._prototype_state) + handler = make_handle_context(store=runtime.store, prototype_state=platform._prototype_state) event = IncomingCommand(user_id="u1", platform="matrix", chat_id="C1", command="context", args=[]) - result = await handler(event, None, platform, None, None) + result = await handler(event, runtime.auth_mgr, platform, runtime.chat_mgr, runtime.settings_mgr) + assert "Контекст чата: matrix:room-1" in result[0].text assert "Сессия: session-a" in result[0].text assert "Токены (последний ответ): 99" in result[0].text assert "session-a" in result[0].text @@ -182,6 +211,15 @@ async def test_context_command_shows_current_snapshot(): async def test_bot_intercepts_numeric_load_selection(): platform = MatrixCommandPlatform() runtime = build_runtime(platform=platform) + await set_room_meta( + runtime.store, + "!room:example.org", + { + "chat_id": "C1", + "matrix_user_id": "@alice:example.org", + "platform_chat_id": "matrix:room-1", + }, + ) client = SimpleNamespace( user_id="@bot:example.org", room_send=AsyncMock(), @@ -199,39 +237,10 @@ async def test_bot_intercepts_numeric_load_selection(): await bot.on_room_message(room, event) platform.send_message.assert_awaited_once() - assert await platform._prototype_state.get_current_session("@alice:example.org") == "session-a" + assert await platform._prototype_state.get_current_session("matrix:room-1") == "session-a" + assert await platform._prototype_state.get_current_session("C1") == "session-a" client.room_send.assert_awaited_once_with( "!room:example.org", "m.room.message", - {"msgtype": "m.text", "body": "Загрузка: session-a"}, - ) - - -@pytest.mark.asyncio -async def test_bot_intercepts_reset_yes_before_dispatch(): - platform = MatrixCommandPlatform() - runtime = build_runtime(platform=platform) - client = SimpleNamespace( - user_id="@bot:example.org", - room_send=AsyncMock(), - ) - bot = MatrixBot(client, runtime) - runtime.dispatcher.dispatch = AsyncMock() - await set_reset_pending(runtime.store, "@alice:example.org", "!room:example.org", {"active": True}) - room = SimpleNamespace(room_id="!room:example.org") - event = SimpleNamespace(sender="@alice:example.org", body="!yes") - - with patch("adapter.matrix.handlers.context_commands.httpx.AsyncClient") as client_cls: - http_client = client_cls.return_value - http_client.__aenter__ = AsyncMock(return_value=http_client) - http_client.__aexit__ = AsyncMock(return_value=False) - http_client.post = AsyncMock(return_value=SimpleNamespace(status_code=200)) - - await bot.on_room_message(room, event) - - runtime.dispatcher.dispatch.assert_not_awaited() - client.room_send.assert_awaited_once_with( - "!room:example.org", - "m.room.message", - {"msgtype": "m.text", "body": "Контекст сброшен."}, + {"msgtype": "m.text", "body": "Запрос на загрузку отправлен агенту: session-a"}, ) diff --git a/tests/platform/test_prototype_state.py b/tests/platform/test_prototype_state.py index aaa0dd7..376c0c4 100644 --- a/tests/platform/test_prototype_state.py +++ b/tests/platform/test_prototype_state.py @@ -95,13 +95,18 @@ async def test_update_settings_supports_toggle_skill_and_setters(): async def test_add_saved_session_appends_named_entries(): store = PrototypeStateStore() - await store.add_saved_session("usr-matrix-@alice:example.org", "alpha") + await store.add_saved_session( + "usr-matrix-@alice:example.org", + "alpha", + source_context_id="ctx-room-1", + ) await store.add_saved_session("usr-matrix-@alice:example.org", "beta") sessions = await store.list_saved_sessions("usr-matrix-@alice:example.org") assert [session["name"] for session in sessions] == ["alpha", "beta"] assert all("created_at" in session for session in sessions) + assert sessions[0]["source_context_id"] == "ctx-room-1" @pytest.mark.asyncio @@ -122,24 +127,58 @@ async def test_list_saved_sessions_returns_copy(): async def test_get_last_tokens_used_defaults_to_zero(): store = PrototypeStateStore() - assert await store.get_last_tokens_used("usr-matrix-@alice:example.org") == 0 + assert await store.get_last_tokens_used_for_context("ctx-room-1") == 0 @pytest.mark.asyncio -async def test_set_last_tokens_used_persists_value(): +async def test_live_tokens_used_are_scoped_per_context(): store = PrototypeStateStore() - await store.set_last_tokens_used("usr-matrix-@alice:example.org", 321) + await store.set_last_tokens_used_for_context("ctx-room-1", 321) + await store.set_last_tokens_used_for_context("ctx-room-2", 654) - assert await store.get_last_tokens_used("usr-matrix-@alice:example.org") == 321 + assert await store.get_last_tokens_used_for_context("ctx-room-1") == 321 + assert await store.get_last_tokens_used_for_context("ctx-room-2") == 654 @pytest.mark.asyncio -async def test_current_session_roundtrip(): +async def test_current_session_roundtrip_is_scoped_per_context(): store = PrototypeStateStore() - assert await store.get_current_session("usr-matrix-@alice:example.org") is None + assert await store.get_current_session_for_context("ctx-room-1") is None + assert await store.get_current_session_for_context("ctx-room-2") is None - await store.set_current_session("usr-matrix-@alice:example.org", "session-1") + await store.set_current_session_for_context("ctx-room-1", "session-1") + await store.set_current_session_for_context("ctx-room-2", "session-2") - assert await store.get_current_session("usr-matrix-@alice:example.org") == "session-1" + assert await store.get_current_session_for_context("ctx-room-1") == "session-1" + assert await store.get_current_session_for_context("ctx-room-2") == "session-2" + + +@pytest.mark.asyncio +async def test_clear_current_session_removes_only_target_context(): + store = PrototypeStateStore() + + await store.set_current_session_for_context("ctx-room-1", "session-1") + await store.set_current_session_for_context("ctx-room-2", "session-2") + + await store.clear_current_session_for_context("ctx-room-1") + + assert await store.get_current_session_for_context("ctx-room-1") is None + assert await store.get_current_session_for_context("ctx-room-2") == "session-2" + + +@pytest.mark.asyncio +async def test_saved_sessions_remain_user_scoped_separate_from_live_context_state(): + store = PrototypeStateStore() + + await store.set_current_session_for_context("ctx-room-1", "room-session") + await store.set_last_tokens_used_for_context("ctx-room-1", 77) + await store.add_saved_session("usr-matrix-@alice:example.org", "alpha") + + sessions = await store.list_saved_sessions("usr-matrix-@alice:example.org") + + assert [session["name"] for session in sessions] == ["alpha"] + assert all(isinstance(session["created_at"], str) for session in sessions) + assert await store.get_current_session_for_context("ctx-room-1") == "room-session" + assert await store.get_last_tokens_used_for_context("ctx-room-1") == 77 diff --git a/tests/platform/test_real.py b/tests/platform/test_real.py index 2c15067..2a36a99 100644 --- a/tests/platform/test_real.py +++ b/tests/platform/test_real.py @@ -197,9 +197,10 @@ async def test_real_platform_client_get_or_create_user_uses_local_state(): @pytest.mark.asyncio async def test_real_platform_client_send_message_uses_chat_bound_client(): agent_api = FakeAgentApiFactory() + prototype_state = PrototypeStateStore() client = RealPlatformClient( agent_api=agent_api, - prototype_state=PrototypeStateStore(), + prototype_state=prototype_state, platform="matrix", ) @@ -215,6 +216,7 @@ async def test_real_platform_client_send_message_uses_chat_bound_client(): assert agent_api.instances["chat-7"].chat_id == "chat-7" assert agent_api.instances["chat-7"].calls == ["hello"] assert agent_api.instances["chat-7"].connect_calls == 1 + assert await prototype_state.get_last_tokens_used_for_context("chat-7") == 3 @pytest.mark.asyncio