feat(04-02): add matrix context management commands
- add save/load/reset/context handlers and matrix interception flows - persist current session and last token usage in prototype state
This commit is contained in:
parent
da0b76882e
commit
b52fdc4670
7 changed files with 638 additions and 21 deletions
|
|
@ -19,9 +19,22 @@ from dotenv import load_dotenv
|
||||||
|
|
||||||
from adapter.matrix.converter import from_room_event
|
from adapter.matrix.converter import from_room_event
|
||||||
from adapter.matrix.handlers import register_matrix_handlers
|
from adapter.matrix.handlers import register_matrix_handlers
|
||||||
|
from adapter.matrix.handlers.context_commands import (
|
||||||
|
LOAD_PROMPT,
|
||||||
|
SAVE_PROMPT,
|
||||||
|
_call_reset_endpoint,
|
||||||
|
_sanitize_session_name,
|
||||||
|
)
|
||||||
from adapter.matrix.handlers.auth import handle_invite
|
from adapter.matrix.handlers.auth import handle_invite
|
||||||
from adapter.matrix.room_router import resolve_chat_id
|
from adapter.matrix.room_router import resolve_chat_id
|
||||||
from adapter.matrix.store import get_room_meta, set_pending_confirm
|
from adapter.matrix.store import (
|
||||||
|
clear_load_pending,
|
||||||
|
clear_reset_pending,
|
||||||
|
get_load_pending,
|
||||||
|
get_reset_pending,
|
||||||
|
get_room_meta,
|
||||||
|
set_pending_confirm,
|
||||||
|
)
|
||||||
from core.auth import AuthManager
|
from core.auth import AuthManager
|
||||||
from core.chat import ChatManager
|
from core.chat import ChatManager
|
||||||
from core.handler import EventDispatcher
|
from core.handler import EventDispatcher
|
||||||
|
|
@ -35,8 +48,8 @@ from core.protocol import (
|
||||||
)
|
)
|
||||||
from core.settings import SettingsManager
|
from core.settings import SettingsManager
|
||||||
from core.store import InMemoryStore, SQLiteStore, StateStore
|
from core.store import InMemoryStore, SQLiteStore, StateStore
|
||||||
from sdk.agent_session import AgentSessionClient, AgentSessionConfig
|
from sdk.agent_api_wrapper import AgentApiWrapper
|
||||||
from sdk.interface import PlatformClient
|
from sdk.interface import PlatformClient, PlatformError
|
||||||
from sdk.mock import MockPlatformClient
|
from sdk.mock import MockPlatformClient
|
||||||
from sdk.prototype_state import PrototypeStateStore
|
from sdk.prototype_state import PrototypeStateStore
|
||||||
from sdk.real import RealPlatformClient
|
from sdk.real import RealPlatformClient
|
||||||
|
|
@ -60,11 +73,20 @@ def build_event_dispatcher(platform: PlatformClient, store: StateStore) -> Event
|
||||||
chat_mgr = ChatManager(platform, store)
|
chat_mgr = ChatManager(platform, store)
|
||||||
auth_mgr = AuthManager(platform, store)
|
auth_mgr = AuthManager(platform, store)
|
||||||
settings_mgr = SettingsManager(platform, store)
|
settings_mgr = SettingsManager(platform, store)
|
||||||
|
prototype_state = getattr(platform, "_prototype_state", None)
|
||||||
|
agent_api = getattr(platform, "_agent_api", None)
|
||||||
|
agent_base_url = os.environ.get("AGENT_BASE_URL", "http://127.0.0.1:8000")
|
||||||
dispatcher = EventDispatcher(
|
dispatcher = EventDispatcher(
|
||||||
platform=platform, chat_mgr=chat_mgr, auth_mgr=auth_mgr, settings_mgr=settings_mgr
|
platform=platform, chat_mgr=chat_mgr, auth_mgr=auth_mgr, settings_mgr=settings_mgr
|
||||||
)
|
)
|
||||||
register_all(dispatcher)
|
register_all(dispatcher)
|
||||||
register_matrix_handlers(dispatcher, store=store)
|
register_matrix_handlers(
|
||||||
|
dispatcher,
|
||||||
|
store=store,
|
||||||
|
agent_api=agent_api,
|
||||||
|
prototype_state=prototype_state,
|
||||||
|
agent_base_url=agent_base_url,
|
||||||
|
)
|
||||||
return dispatcher
|
return dispatcher
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -73,7 +95,7 @@ def _build_platform_from_env() -> PlatformClient:
|
||||||
if backend == "real":
|
if backend == "real":
|
||||||
ws_url = os.environ["AGENT_WS_URL"]
|
ws_url = os.environ["AGENT_WS_URL"]
|
||||||
return RealPlatformClient(
|
return RealPlatformClient(
|
||||||
agent_sessions=AgentSessionClient(AgentSessionConfig(base_ws_url=ws_url)),
|
agent_api=AgentApiWrapper(agent_id="matrix-bot", url=ws_url),
|
||||||
prototype_state=PrototypeStateStore(),
|
prototype_state=PrototypeStateStore(),
|
||||||
platform="matrix",
|
platform="matrix",
|
||||||
)
|
)
|
||||||
|
|
@ -90,11 +112,21 @@ def build_runtime(
|
||||||
chat_mgr = ChatManager(platform, store)
|
chat_mgr = ChatManager(platform, store)
|
||||||
auth_mgr = AuthManager(platform, store)
|
auth_mgr = AuthManager(platform, store)
|
||||||
settings_mgr = SettingsManager(platform, store)
|
settings_mgr = SettingsManager(platform, store)
|
||||||
|
prototype_state = getattr(platform, "_prototype_state", None)
|
||||||
|
agent_api = getattr(platform, "_agent_api", None)
|
||||||
|
agent_base_url = os.environ.get("AGENT_BASE_URL", "http://127.0.0.1:8000")
|
||||||
dispatcher = EventDispatcher(
|
dispatcher = EventDispatcher(
|
||||||
platform=platform, chat_mgr=chat_mgr, auth_mgr=auth_mgr, settings_mgr=settings_mgr
|
platform=platform, chat_mgr=chat_mgr, auth_mgr=auth_mgr, settings_mgr=settings_mgr
|
||||||
)
|
)
|
||||||
register_all(dispatcher)
|
register_all(dispatcher)
|
||||||
register_matrix_handlers(dispatcher, client=client, store=store)
|
register_matrix_handlers(
|
||||||
|
dispatcher,
|
||||||
|
client=client,
|
||||||
|
store=store,
|
||||||
|
agent_api=agent_api,
|
||||||
|
prototype_state=prototype_state,
|
||||||
|
agent_base_url=agent_base_url,
|
||||||
|
)
|
||||||
return MatrixRuntime(
|
return MatrixRuntime(
|
||||||
platform=platform,
|
platform=platform,
|
||||||
store=store,
|
store=store,
|
||||||
|
|
@ -113,13 +145,118 @@ class MatrixBot:
|
||||||
async def on_room_message(self, room: MatrixRoom, event: RoomMessageText) -> None:
|
async def on_room_message(self, room: MatrixRoom, event: RoomMessageText) -> None:
|
||||||
if getattr(event, "sender", None) == self.client.user_id:
|
if getattr(event, "sender", None) == self.client.user_id:
|
||||||
return
|
return
|
||||||
chat_id = await resolve_chat_id(self.runtime.store, room.room_id, event.sender)
|
sender = getattr(event, "sender", None)
|
||||||
|
body = (getattr(event, "body", None) or "").strip()
|
||||||
|
load_pending = await get_load_pending(self.runtime.store, sender, room.room_id)
|
||||||
|
if load_pending is not None and (body.isdigit() or body == "!cancel"):
|
||||||
|
outgoing = await self._handle_load_selection(sender, room.room_id, body, load_pending)
|
||||||
|
await self._send_all(room.room_id, outgoing)
|
||||||
|
return
|
||||||
|
|
||||||
|
reset_pending = await get_reset_pending(self.runtime.store, sender, room.room_id)
|
||||||
|
if reset_pending is not None and (body in {"!yes", "!no"} or body.startswith("!save ")):
|
||||||
|
outgoing = await self._handle_reset_selection(sender, room.room_id, body)
|
||||||
|
await self._send_all(room.room_id, outgoing)
|
||||||
|
return
|
||||||
|
|
||||||
|
chat_id = await resolve_chat_id(self.runtime.store, room.room_id, sender)
|
||||||
incoming = from_room_event(event, room_id=room.room_id, chat_id=chat_id)
|
incoming = from_room_event(event, room_id=room.room_id, chat_id=chat_id)
|
||||||
if incoming is None:
|
if incoming is None:
|
||||||
return
|
return
|
||||||
|
try:
|
||||||
outgoing = await self.runtime.dispatcher.dispatch(incoming)
|
outgoing = await self.runtime.dispatcher.dispatch(incoming)
|
||||||
|
except PlatformError as exc:
|
||||||
|
logger.warning(
|
||||||
|
"matrix_message_platform_error",
|
||||||
|
room_id=room.room_id,
|
||||||
|
sender=getattr(event, "sender", None),
|
||||||
|
code=exc.code,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
outgoing = [
|
||||||
|
OutgoingMessage(
|
||||||
|
chat_id=chat_id,
|
||||||
|
text="Сервис временно недоступен. Попробуйте ещё раз позже."
|
||||||
|
)
|
||||||
|
]
|
||||||
await self._send_all(room.room_id, outgoing)
|
await self._send_all(room.room_id, outgoing)
|
||||||
|
|
||||||
|
async def _handle_load_selection(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
room_id: str,
|
||||||
|
text: str,
|
||||||
|
pending: dict,
|
||||||
|
) -> list[OutgoingEvent]:
|
||||||
|
saves = pending.get("saves", [])
|
||||||
|
if text in {"0", "!cancel"}:
|
||||||
|
await clear_load_pending(self.runtime.store, user_id, room_id)
|
||||||
|
return [OutgoingMessage(chat_id=room_id, text="Отменено.")]
|
||||||
|
|
||||||
|
index = int(text) - 1
|
||||||
|
if index < 0 or index >= len(saves):
|
||||||
|
return [
|
||||||
|
OutgoingMessage(
|
||||||
|
chat_id=room_id,
|
||||||
|
text=f"Неверный номер. Введи от 1 до {len(saves)} или 0 для отмены.",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
name = saves[index]["name"]
|
||||||
|
await clear_load_pending(self.runtime.store, user_id, room_id)
|
||||||
|
prototype_state = getattr(self.runtime.platform, "_prototype_state", None)
|
||||||
|
if prototype_state is not None:
|
||||||
|
await prototype_state.set_current_session(user_id, name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.runtime.platform.send_message(
|
||||||
|
user_id,
|
||||||
|
room_id,
|
||||||
|
LOAD_PROMPT.format(name=name),
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("load_agent_call_failed", error=str(exc))
|
||||||
|
return [OutgoingMessage(chat_id=room_id, text=f"Ошибка при загрузке: {exc}")]
|
||||||
|
return [OutgoingMessage(chat_id=room_id, text=f"Загрузка: {name}")]
|
||||||
|
|
||||||
|
async def _handle_reset_selection(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
room_id: str,
|
||||||
|
text: str,
|
||||||
|
) -> list[OutgoingEvent]:
|
||||||
|
agent_base_url = os.environ.get("AGENT_BASE_URL", "http://127.0.0.1:8000")
|
||||||
|
prototype_state = getattr(self.runtime.platform, "_prototype_state", None)
|
||||||
|
await clear_reset_pending(self.runtime.store, user_id, room_id)
|
||||||
|
|
||||||
|
if text == "!no":
|
||||||
|
return [OutgoingMessage(chat_id=room_id, text="Отменено.")]
|
||||||
|
|
||||||
|
if text.startswith("!save "):
|
||||||
|
name = _sanitize_session_name(text[len("!save ") :].strip())
|
||||||
|
if name is None:
|
||||||
|
return [
|
||||||
|
OutgoingMessage(
|
||||||
|
chat_id=room_id,
|
||||||
|
text="Имя сохранения может содержать только буквы, цифры, _ и -.",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
await self.runtime.platform.send_message(
|
||||||
|
user_id,
|
||||||
|
room_id,
|
||||||
|
SAVE_PROMPT.format(name=name),
|
||||||
|
)
|
||||||
|
if prototype_state is not None:
|
||||||
|
await prototype_state.add_saved_session(user_id, name)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("save_before_reset_failed", error=str(exc))
|
||||||
|
return [OutgoingMessage(chat_id=room_id, text=f"Ошибка при сохранении: {exc}")]
|
||||||
|
|
||||||
|
if prototype_state is not None:
|
||||||
|
await prototype_state.clear_current_session(user_id)
|
||||||
|
return await _call_reset_endpoint(agent_base_url, room_id)
|
||||||
|
|
||||||
async def on_member(self, room: MatrixRoom, event: RoomMemberEvent) -> None:
|
async def on_member(self, room: MatrixRoom, event: RoomMemberEvent) -> None:
|
||||||
if getattr(event, "sender", None) == self.client.user_id:
|
if getattr(event, "sender", None) == self.client.user_id:
|
||||||
return
|
return
|
||||||
|
|
@ -236,8 +373,12 @@ async def main() -> None:
|
||||||
request_timeout=client_config.request_timeout,
|
request_timeout=client_config.request_timeout,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
if isinstance(runtime.platform, RealPlatformClient):
|
||||||
|
await runtime.platform.agent_api.connect()
|
||||||
await client.sync_forever(timeout=30000, since=since_token)
|
await client.sync_forever(timeout=30000, since=since_token)
|
||||||
finally:
|
finally:
|
||||||
|
if isinstance(runtime.platform, RealPlatformClient):
|
||||||
|
await runtime.platform.agent_api.close()
|
||||||
await client.close()
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,12 @@ from adapter.matrix.handlers.chat import (
|
||||||
make_handle_rename,
|
make_handle_rename,
|
||||||
)
|
)
|
||||||
from adapter.matrix.handlers.confirm import make_handle_cancel, make_handle_confirm
|
from adapter.matrix.handlers.confirm import make_handle_cancel, make_handle_confirm
|
||||||
|
from adapter.matrix.handlers.context_commands import (
|
||||||
|
make_handle_context,
|
||||||
|
make_handle_load,
|
||||||
|
make_handle_reset,
|
||||||
|
make_handle_save,
|
||||||
|
)
|
||||||
from adapter.matrix.handlers.settings import (
|
from adapter.matrix.handlers.settings import (
|
||||||
handle_help,
|
handle_help,
|
||||||
handle_settings,
|
handle_settings,
|
||||||
|
|
@ -23,7 +29,14 @@ from core.handler import EventDispatcher
|
||||||
from core.protocol import IncomingCallback, IncomingCommand
|
from core.protocol import IncomingCallback, IncomingCommand
|
||||||
|
|
||||||
|
|
||||||
def register_matrix_handlers(dispatcher: EventDispatcher, client=None, store=None) -> None:
|
def register_matrix_handlers(
|
||||||
|
dispatcher: EventDispatcher,
|
||||||
|
client=None,
|
||||||
|
store=None,
|
||||||
|
agent_api=None,
|
||||||
|
prototype_state=None,
|
||||||
|
agent_base_url: str = "http://127.0.0.1:8000",
|
||||||
|
) -> None:
|
||||||
dispatcher.register(IncomingCommand, "new", make_handle_new_chat(client, store))
|
dispatcher.register(IncomingCommand, "new", make_handle_new_chat(client, store))
|
||||||
dispatcher.register(IncomingCommand, "chats", handle_list_chats)
|
dispatcher.register(IncomingCommand, "chats", handle_list_chats)
|
||||||
dispatcher.register(IncomingCommand, "rename", make_handle_rename(client, store))
|
dispatcher.register(IncomingCommand, "rename", make_handle_rename(client, store))
|
||||||
|
|
@ -41,3 +54,9 @@ def register_matrix_handlers(dispatcher: EventDispatcher, client=None, store=Non
|
||||||
dispatcher.register(IncomingCallback, "confirm", make_handle_confirm(store))
|
dispatcher.register(IncomingCallback, "confirm", make_handle_confirm(store))
|
||||||
dispatcher.register(IncomingCallback, "cancel", make_handle_cancel(store))
|
dispatcher.register(IncomingCallback, "cancel", make_handle_cancel(store))
|
||||||
dispatcher.register(IncomingCallback, "toggle_skill", handle_toggle_skill)
|
dispatcher.register(IncomingCallback, "toggle_skill", handle_toggle_skill)
|
||||||
|
|
||||||
|
if agent_api is not None and prototype_state is not None:
|
||||||
|
dispatcher.register(IncomingCommand, "save", make_handle_save(agent_api, store, prototype_state))
|
||||||
|
dispatcher.register(IncomingCommand, "load", make_handle_load(store, prototype_state))
|
||||||
|
dispatcher.register(IncomingCommand, "reset", make_handle_reset(store, agent_base_url))
|
||||||
|
dispatcher.register(IncomingCommand, "context", make_handle_context(store, prototype_state))
|
||||||
|
|
|
||||||
172
adapter/matrix/handlers/context_commands.py
Normal file
172
adapter/matrix/handlers/context_commands.py
Normal file
|
|
@ -0,0 +1,172 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import structlog
|
||||||
|
|
||||||
|
from adapter.matrix.store import set_load_pending, set_reset_pending
|
||||||
|
from core.protocol import IncomingCommand, OutgoingEvent, OutgoingMessage
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.store import StateStore
|
||||||
|
from sdk.prototype_state import PrototypeStateStore
|
||||||
|
|
||||||
|
logger = structlog.get_logger(__name__)
|
||||||
|
|
||||||
|
SAVE_PROMPT = (
|
||||||
|
"Summarize our conversation and save to /workspace/contexts/{name}.md. "
|
||||||
|
"Reply only with: Saved: {name}"
|
||||||
|
)
|
||||||
|
LOAD_PROMPT = (
|
||||||
|
"Load context from /workspace/contexts/{name}.md and use it as background "
|
||||||
|
"for our conversation. Reply: Loaded: {name}"
|
||||||
|
)
|
||||||
|
_VALID_NAME = re.compile(r"^[A-Za-z0-9_-]+$")
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_session_name(raw_name: str) -> str | None:
|
||||||
|
name = raw_name.strip()
|
||||||
|
if not name or not _VALID_NAME.fullmatch(name):
|
||||||
|
return None
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_room_id(event: IncomingCommand, chat_mgr) -> str:
|
||||||
|
if chat_mgr is None:
|
||||||
|
return event.chat_id
|
||||||
|
ctx = await chat_mgr.get(event.chat_id, user_id=event.user_id)
|
||||||
|
if ctx is not None and ctx.surface_ref:
|
||||||
|
return ctx.surface_ref
|
||||||
|
return event.chat_id
|
||||||
|
|
||||||
|
|
||||||
|
def make_handle_save(agent_api, store: "StateStore", prototype_state: "PrototypeStateStore"):
|
||||||
|
async def handle_save(
|
||||||
|
event: IncomingCommand, auth_mgr, platform, chat_mgr, settings_mgr
|
||||||
|
) -> list[OutgoingEvent]:
|
||||||
|
if event.args:
|
||||||
|
name = _sanitize_session_name(event.args[0])
|
||||||
|
if name is None:
|
||||||
|
return [
|
||||||
|
OutgoingMessage(
|
||||||
|
chat_id=event.chat_id,
|
||||||
|
text="Имя сохранения может содержать только буквы, цифры, _ и -.",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
name = f"context-{datetime.now(UTC).strftime('%Y%m%d-%H%M%S')}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
await platform.send_message(
|
||||||
|
event.user_id,
|
||||||
|
event.chat_id,
|
||||||
|
SAVE_PROMPT.format(name=name),
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("save_agent_call_failed", error=str(exc))
|
||||||
|
return [OutgoingMessage(chat_id=event.chat_id, text=f"Ошибка при сохранении: {exc}")]
|
||||||
|
|
||||||
|
await prototype_state.add_saved_session(event.user_id, name)
|
||||||
|
return [OutgoingMessage(chat_id=event.chat_id, text=f"Сохранение запущено: {name}")]
|
||||||
|
|
||||||
|
return handle_save
|
||||||
|
|
||||||
|
|
||||||
|
def make_handle_load(store: "StateStore", prototype_state: "PrototypeStateStore"):
|
||||||
|
async def handle_load(
|
||||||
|
event: IncomingCommand, auth_mgr, platform, chat_mgr, settings_mgr
|
||||||
|
) -> list[OutgoingEvent]:
|
||||||
|
sessions = await prototype_state.list_saved_sessions(event.user_id)
|
||||||
|
if not sessions:
|
||||||
|
return [
|
||||||
|
OutgoingMessage(
|
||||||
|
chat_id=event.chat_id,
|
||||||
|
text="Нет сохранённых сессий. Используй !save [имя].",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
room_id = await _resolve_room_id(event, chat_mgr)
|
||||||
|
lines = ["Сохранённые сессии:"]
|
||||||
|
for index, session in enumerate(sessions, start=1):
|
||||||
|
created = session.get("created_at", "")[:10]
|
||||||
|
lines.append(f" {index}. {session['name']} ({created})")
|
||||||
|
lines.append("")
|
||||||
|
lines.append("Введи номер или 0 / !cancel для отмены.")
|
||||||
|
|
||||||
|
await set_load_pending(store, event.user_id, room_id, {"saves": sessions})
|
||||||
|
return [OutgoingMessage(chat_id=event.chat_id, text="\n".join(lines))]
|
||||||
|
|
||||||
|
return handle_load
|
||||||
|
|
||||||
|
|
||||||
|
def make_handle_reset(store: "StateStore", agent_base_url: str):
|
||||||
|
async def handle_reset(
|
||||||
|
event: IncomingCommand, auth_mgr, platform, chat_mgr, settings_mgr
|
||||||
|
) -> list[OutgoingEvent]:
|
||||||
|
room_id = await _resolve_room_id(event, chat_mgr)
|
||||||
|
await set_reset_pending(store, event.user_id, room_id, {"active": True})
|
||||||
|
return [
|
||||||
|
OutgoingMessage(
|
||||||
|
chat_id=event.chat_id,
|
||||||
|
text=(
|
||||||
|
"Сбросить контекст агента? Выбери:\n"
|
||||||
|
" !yes - сбросить\n"
|
||||||
|
" !save [имя] - сохранить и сбросить\n"
|
||||||
|
" !no - отмена"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
return handle_reset
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_reset_endpoint(agent_base_url: str, chat_id: str) -> list[OutgoingEvent]:
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(f"{agent_base_url}/reset", timeout=5.0)
|
||||||
|
except (httpx.ConnectError, httpx.TimeoutException) as exc:
|
||||||
|
logger.warning("reset_endpoint_unreachable", error=str(exc))
|
||||||
|
return [
|
||||||
|
OutgoingMessage(
|
||||||
|
chat_id=chat_id,
|
||||||
|
text="Reset endpoint недоступен. Обратитесь к администратору.",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
if response.status_code == 404:
|
||||||
|
return [
|
||||||
|
OutgoingMessage(
|
||||||
|
chat_id=chat_id,
|
||||||
|
text="Reset endpoint недоступен. Обратитесь к администратору.",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
return [OutgoingMessage(chat_id=chat_id, text="Контекст сброшен.")]
|
||||||
|
|
||||||
|
|
||||||
|
def make_handle_context(store: "StateStore", prototype_state: "PrototypeStateStore"):
|
||||||
|
async def handle_context(
|
||||||
|
event: IncomingCommand, auth_mgr, platform, chat_mgr, settings_mgr
|
||||||
|
) -> list[OutgoingEvent]:
|
||||||
|
current_session = await prototype_state.get_current_session(event.user_id)
|
||||||
|
tokens_used = await prototype_state.get_last_tokens_used(event.user_id)
|
||||||
|
sessions = await prototype_state.list_saved_sessions(event.user_id)
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
"Контекст:",
|
||||||
|
f" Сессия: {current_session or 'не загружена'}",
|
||||||
|
f" Токены (последний ответ): {tokens_used}",
|
||||||
|
f" Сохранения ({len(sessions)}):",
|
||||||
|
]
|
||||||
|
if sessions:
|
||||||
|
for session in sessions:
|
||||||
|
created = session.get("created_at", "")[:10]
|
||||||
|
lines.append(f" - {session['name']} ({created})")
|
||||||
|
else:
|
||||||
|
lines.append(" (нет)")
|
||||||
|
|
||||||
|
return [OutgoingMessage(chat_id=event.chat_id, text="\n".join(lines))]
|
||||||
|
|
||||||
|
return handle_context
|
||||||
|
|
@ -33,6 +33,7 @@ class PrototypeStateStore:
|
||||||
self._settings: dict[str, dict[str, Any]] = {}
|
self._settings: dict[str, dict[str, Any]] = {}
|
||||||
self._saved_sessions: dict[str, list[dict[str, str]]] = {}
|
self._saved_sessions: dict[str, list[dict[str, str]]] = {}
|
||||||
self._last_tokens_used: dict[str, int] = {}
|
self._last_tokens_used: dict[str, int] = {}
|
||||||
|
self._current_session: dict[str, str] = {}
|
||||||
|
|
||||||
async def get_or_create_user(
|
async def get_or_create_user(
|
||||||
self,
|
self,
|
||||||
|
|
@ -93,3 +94,12 @@ class PrototypeStateStore:
|
||||||
|
|
||||||
async def set_last_tokens_used(self, user_id: str, tokens: int) -> None:
|
async def set_last_tokens_used(self, user_id: str, tokens: int) -> None:
|
||||||
self._last_tokens_used[user_id] = tokens
|
self._last_tokens_used[user_id] = tokens
|
||||||
|
|
||||||
|
async def get_current_session(self, user_id: str) -> str | None:
|
||||||
|
return self._current_session.get(user_id)
|
||||||
|
|
||||||
|
async def set_current_session(self, user_id: str, name: str) -> None:
|
||||||
|
self._current_session[user_id] = name
|
||||||
|
|
||||||
|
async def clear_current_session(self, user_id: str) -> None:
|
||||||
|
self._current_session.pop(user_id, None)
|
||||||
|
|
|
||||||
51
sdk/real.py
51
sdk/real.py
|
|
@ -1,26 +1,27 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, AsyncIterator
|
from typing import AsyncIterator
|
||||||
|
|
||||||
from sdk.agent_session import build_thread_key
|
from sdk.agent_api_wrapper import AgentApiWrapper
|
||||||
from sdk.interface import Attachment, MessageChunk, MessageResponse, PlatformClient, User, UserSettings
|
from sdk.interface import Attachment, MessageChunk, MessageResponse, PlatformClient, User, UserSettings
|
||||||
from sdk.prototype_state import PrototypeStateStore
|
from sdk.prototype_state import PrototypeStateStore
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from sdk.agent_session import AgentSessionClient
|
|
||||||
|
|
||||||
|
|
||||||
class RealPlatformClient(PlatformClient):
|
class RealPlatformClient(PlatformClient):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
agent_sessions: AgentSessionClient,
|
agent_api: AgentApiWrapper,
|
||||||
prototype_state: PrototypeStateStore,
|
prototype_state: PrototypeStateStore,
|
||||||
platform: str = "matrix",
|
platform: str = "matrix",
|
||||||
) -> None:
|
) -> None:
|
||||||
self._agent_sessions = agent_sessions
|
self._agent_api = agent_api
|
||||||
self._prototype_state = prototype_state
|
self._prototype_state = prototype_state
|
||||||
self._platform = platform
|
self._platform = platform
|
||||||
|
|
||||||
|
@property
|
||||||
|
def agent_api(self) -> AgentApiWrapper:
|
||||||
|
return self._agent_api
|
||||||
|
|
||||||
async def get_or_create_user(
|
async def get_or_create_user(
|
||||||
self,
|
self,
|
||||||
external_id: str,
|
external_id: str,
|
||||||
|
|
@ -40,8 +41,23 @@ class RealPlatformClient(PlatformClient):
|
||||||
text: str,
|
text: str,
|
||||||
attachments: list[Attachment] | None = None,
|
attachments: list[Attachment] | None = None,
|
||||||
) -> MessageResponse:
|
) -> MessageResponse:
|
||||||
thread_key = build_thread_key(self._platform, user_id, chat_id)
|
response_parts: list[str] = []
|
||||||
return await self._agent_sessions.send_message(thread_key=thread_key, text=text)
|
tokens_used = 0
|
||||||
|
message_id = user_id
|
||||||
|
|
||||||
|
async for chunk in self.stream_message(user_id, chat_id, text, attachments=attachments):
|
||||||
|
message_id = chunk.message_id
|
||||||
|
if chunk.delta:
|
||||||
|
response_parts.append(chunk.delta)
|
||||||
|
if chunk.finished:
|
||||||
|
tokens_used = chunk.tokens_used
|
||||||
|
|
||||||
|
return MessageResponse(
|
||||||
|
message_id=message_id,
|
||||||
|
response="".join(response_parts),
|
||||||
|
tokens_used=tokens_used,
|
||||||
|
finished=True,
|
||||||
|
)
|
||||||
|
|
||||||
async def stream_message(
|
async def stream_message(
|
||||||
self,
|
self,
|
||||||
|
|
@ -50,9 +66,20 @@ class RealPlatformClient(PlatformClient):
|
||||||
text: str,
|
text: str,
|
||||||
attachments: list[Attachment] | None = None,
|
attachments: list[Attachment] | None = None,
|
||||||
) -> AsyncIterator[MessageChunk]:
|
) -> AsyncIterator[MessageChunk]:
|
||||||
thread_key = build_thread_key(self._platform, user_id, chat_id)
|
self._agent_api.last_tokens_used = 0
|
||||||
async for chunk in self._agent_sessions.stream_message(thread_key=thread_key, text=text):
|
async for event in self._agent_api.send_message(text):
|
||||||
yield chunk
|
yield MessageChunk(
|
||||||
|
message_id=user_id,
|
||||||
|
delta=event.text,
|
||||||
|
finished=False,
|
||||||
|
)
|
||||||
|
await self._prototype_state.set_last_tokens_used(user_id, self._agent_api.last_tokens_used)
|
||||||
|
yield MessageChunk(
|
||||||
|
message_id=user_id,
|
||||||
|
delta="",
|
||||||
|
finished=True,
|
||||||
|
tokens_used=self._agent_api.last_tokens_used,
|
||||||
|
)
|
||||||
|
|
||||||
async def get_settings(self, user_id: str) -> UserSettings:
|
async def get_settings(self, user_id: str) -> UserSettings:
|
||||||
return await self._prototype_state.get_settings(user_id)
|
return await self._prototype_state.get_settings(user_id)
|
||||||
|
|
|
||||||
237
tests/adapter/matrix/test_context_commands.py
Normal file
237
tests/adapter/matrix/test_context_commands.py
Normal file
|
|
@ -0,0 +1,237 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from adapter.matrix.bot import MatrixBot, build_runtime
|
||||||
|
from adapter.matrix.handlers.context_commands import (
|
||||||
|
make_handle_context,
|
||||||
|
make_handle_load,
|
||||||
|
make_handle_reset,
|
||||||
|
make_handle_save,
|
||||||
|
)
|
||||||
|
from adapter.matrix.store import get_load_pending, get_reset_pending, set_load_pending, set_reset_pending
|
||||||
|
from core.protocol import IncomingCommand, OutgoingMessage
|
||||||
|
from core.store import InMemoryStore
|
||||||
|
from sdk.interface import MessageResponse
|
||||||
|
from sdk.mock import MockPlatformClient
|
||||||
|
from sdk.prototype_state import PrototypeStateStore
|
||||||
|
|
||||||
|
|
||||||
|
class MatrixCommandPlatform(MockPlatformClient):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._prototype_state = PrototypeStateStore()
|
||||||
|
self._agent_api = object()
|
||||||
|
self.send_message = AsyncMock(
|
||||||
|
return_value=MessageResponse(
|
||||||
|
message_id="msg-1",
|
||||||
|
response="ok",
|
||||||
|
tokens_used=0,
|
||||||
|
finished=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_command_auto_name_records_session():
|
||||||
|
platform = MatrixCommandPlatform()
|
||||||
|
store = InMemoryStore()
|
||||||
|
handler = make_handle_save(
|
||||||
|
agent_api=platform._agent_api,
|
||||||
|
store=store,
|
||||||
|
prototype_state=platform._prototype_state,
|
||||||
|
)
|
||||||
|
event = IncomingCommand(
|
||||||
|
user_id="u1",
|
||||||
|
platform="matrix",
|
||||||
|
chat_id="!room:example.org",
|
||||||
|
command="save",
|
||||||
|
args=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await handler(event, None, platform, None, None)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert isinstance(result[0], OutgoingMessage)
|
||||||
|
assert "Сохранение запущено" in result[0].text
|
||||||
|
sessions = await platform._prototype_state.list_saved_sessions("u1")
|
||||||
|
assert len(sessions) == 1
|
||||||
|
assert sessions[0]["name"].startswith("context-")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_command_with_name_uses_given_name():
|
||||||
|
platform = MatrixCommandPlatform()
|
||||||
|
store = InMemoryStore()
|
||||||
|
handler = make_handle_save(
|
||||||
|
agent_api=platform._agent_api,
|
||||||
|
store=store,
|
||||||
|
prototype_state=platform._prototype_state,
|
||||||
|
)
|
||||||
|
event = IncomingCommand(
|
||||||
|
user_id="u1",
|
||||||
|
platform="matrix",
|
||||||
|
chat_id="!room:example.org",
|
||||||
|
command="save",
|
||||||
|
args=["my-session"],
|
||||||
|
)
|
||||||
|
|
||||||
|
await handler(event, None, platform, None, None)
|
||||||
|
|
||||||
|
sessions = await platform._prototype_state.list_saved_sessions("u1")
|
||||||
|
assert [session["name"] for session in sessions] == ["my-session"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_load_command_shows_numbered_list_and_sets_pending():
|
||||||
|
platform = MatrixCommandPlatform()
|
||||||
|
runtime = build_runtime(platform=platform)
|
||||||
|
await runtime.chat_mgr.get_or_create(
|
||||||
|
user_id="u1",
|
||||||
|
chat_id="C1",
|
||||||
|
platform="matrix",
|
||||||
|
surface_ref="!room:example.org",
|
||||||
|
name="Chat 1",
|
||||||
|
)
|
||||||
|
await platform._prototype_state.add_saved_session("u1", "session-a")
|
||||||
|
await platform._prototype_state.add_saved_session("u1", "session-b")
|
||||||
|
|
||||||
|
handler = make_handle_load(store=runtime.store, prototype_state=platform._prototype_state)
|
||||||
|
event = IncomingCommand(user_id="u1", platform="matrix", chat_id="C1", command="load", args=[])
|
||||||
|
|
||||||
|
result = await handler(event, runtime.auth_mgr, platform, runtime.chat_mgr, runtime.settings_mgr)
|
||||||
|
|
||||||
|
assert "1. session-a" in result[0].text
|
||||||
|
assert "2. session-b" in result[0].text
|
||||||
|
pending = await get_load_pending(runtime.store, "u1", "!room:example.org")
|
||||||
|
assert pending is not None
|
||||||
|
assert [session["name"] for session in pending["saves"]] == ["session-a", "session-b"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_load_command_without_saved_sessions_reports_empty():
|
||||||
|
platform = MatrixCommandPlatform()
|
||||||
|
store = InMemoryStore()
|
||||||
|
handler = make_handle_load(store=store, prototype_state=platform._prototype_state)
|
||||||
|
event = IncomingCommand(user_id="u1", platform="matrix", chat_id="C1", command="load", args=[])
|
||||||
|
|
||||||
|
result = await handler(event, None, platform, None, None)
|
||||||
|
|
||||||
|
assert "Нет сохранённых сессий" in result[0].text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reset_command_shows_dialog_and_sets_pending():
|
||||||
|
platform = MatrixCommandPlatform()
|
||||||
|
runtime = build_runtime(platform=platform)
|
||||||
|
await runtime.chat_mgr.get_or_create(
|
||||||
|
user_id="u1",
|
||||||
|
chat_id="C1",
|
||||||
|
platform="matrix",
|
||||||
|
surface_ref="!room:example.org",
|
||||||
|
name="Chat 1",
|
||||||
|
)
|
||||||
|
handler = make_handle_reset(store=runtime.store, agent_base_url="http://127.0.0.1:8000")
|
||||||
|
event = IncomingCommand(user_id="u1", platform="matrix", chat_id="C1", command="reset", args=[])
|
||||||
|
|
||||||
|
result = await handler(event, runtime.auth_mgr, platform, runtime.chat_mgr, runtime.settings_mgr)
|
||||||
|
|
||||||
|
assert "!yes" in result[0].text
|
||||||
|
assert "!save" in result[0].text
|
||||||
|
assert "!no" in result[0].text
|
||||||
|
assert await get_reset_pending(runtime.store, "u1", "!room:example.org") == {"active": True}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reset_endpoint_unavailable_reports_error():
|
||||||
|
with patch("adapter.matrix.handlers.context_commands.httpx.AsyncClient") as client_cls:
|
||||||
|
client = client_cls.return_value
|
||||||
|
client.__aenter__ = AsyncMock(return_value=client)
|
||||||
|
client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
client.post = AsyncMock(side_effect=httpx.ConnectError("refused"))
|
||||||
|
|
||||||
|
from adapter.matrix.handlers.context_commands import _call_reset_endpoint
|
||||||
|
|
||||||
|
result = await _call_reset_endpoint("http://127.0.0.1:8000", "!room:example.org")
|
||||||
|
|
||||||
|
assert "недоступен" in result[0].text.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_command_shows_current_snapshot():
|
||||||
|
platform = MatrixCommandPlatform()
|
||||||
|
store = InMemoryStore()
|
||||||
|
await platform._prototype_state.set_current_session("u1", "session-a")
|
||||||
|
await platform._prototype_state.set_last_tokens_used("u1", 99)
|
||||||
|
await platform._prototype_state.add_saved_session("u1", "session-a")
|
||||||
|
handler = make_handle_context(store=store, prototype_state=platform._prototype_state)
|
||||||
|
event = IncomingCommand(user_id="u1", platform="matrix", chat_id="C1", command="context", args=[])
|
||||||
|
|
||||||
|
result = await handler(event, None, platform, None, None)
|
||||||
|
|
||||||
|
assert "Сессия: session-a" in result[0].text
|
||||||
|
assert "Токены (последний ответ): 99" in result[0].text
|
||||||
|
assert "session-a" in result[0].text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bot_intercepts_numeric_load_selection():
|
||||||
|
platform = MatrixCommandPlatform()
|
||||||
|
runtime = build_runtime(platform=platform)
|
||||||
|
client = SimpleNamespace(
|
||||||
|
user_id="@bot:example.org",
|
||||||
|
room_send=AsyncMock(),
|
||||||
|
)
|
||||||
|
bot = MatrixBot(client, runtime)
|
||||||
|
await set_load_pending(
|
||||||
|
runtime.store,
|
||||||
|
"@alice:example.org",
|
||||||
|
"!room:example.org",
|
||||||
|
{"saves": [{"name": "session-a", "created_at": "2026-04-17T00:00:00+00:00"}]},
|
||||||
|
)
|
||||||
|
room = SimpleNamespace(room_id="!room:example.org")
|
||||||
|
event = SimpleNamespace(sender="@alice:example.org", body="1")
|
||||||
|
|
||||||
|
await bot.on_room_message(room, event)
|
||||||
|
|
||||||
|
platform.send_message.assert_awaited_once()
|
||||||
|
assert await platform._prototype_state.get_current_session("@alice:example.org") == "session-a"
|
||||||
|
client.room_send.assert_awaited_once_with(
|
||||||
|
"!room:example.org",
|
||||||
|
"m.room.message",
|
||||||
|
{"msgtype": "m.text", "body": "Загрузка: session-a"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bot_intercepts_reset_yes_before_dispatch():
|
||||||
|
platform = MatrixCommandPlatform()
|
||||||
|
runtime = build_runtime(platform=platform)
|
||||||
|
client = SimpleNamespace(
|
||||||
|
user_id="@bot:example.org",
|
||||||
|
room_send=AsyncMock(),
|
||||||
|
)
|
||||||
|
bot = MatrixBot(client, runtime)
|
||||||
|
runtime.dispatcher.dispatch = AsyncMock()
|
||||||
|
await set_reset_pending(runtime.store, "@alice:example.org", "!room:example.org", {"active": True})
|
||||||
|
room = SimpleNamespace(room_id="!room:example.org")
|
||||||
|
event = SimpleNamespace(sender="@alice:example.org", body="!yes")
|
||||||
|
|
||||||
|
with patch("adapter.matrix.handlers.context_commands.httpx.AsyncClient") as client_cls:
|
||||||
|
http_client = client_cls.return_value
|
||||||
|
http_client.__aenter__ = AsyncMock(return_value=http_client)
|
||||||
|
http_client.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
http_client.post = AsyncMock(return_value=SimpleNamespace(status_code=200))
|
||||||
|
|
||||||
|
await bot.on_room_message(room, event)
|
||||||
|
|
||||||
|
runtime.dispatcher.dispatch.assert_not_awaited()
|
||||||
|
client.room_send.assert_awaited_once_with(
|
||||||
|
"!room:example.org",
|
||||||
|
"m.room.message",
|
||||||
|
{"msgtype": "m.text", "body": "Контекст сброшен."},
|
||||||
|
)
|
||||||
|
|
@ -132,3 +132,14 @@ async def test_set_last_tokens_used_persists_value():
|
||||||
await store.set_last_tokens_used("usr-matrix-@alice:example.org", 321)
|
await store.set_last_tokens_used("usr-matrix-@alice:example.org", 321)
|
||||||
|
|
||||||
assert await store.get_last_tokens_used("usr-matrix-@alice:example.org") == 321
|
assert await store.get_last_tokens_used("usr-matrix-@alice:example.org") == 321
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_current_session_roundtrip():
|
||||||
|
store = PrototypeStateStore()
|
||||||
|
|
||||||
|
assert await store.get_current_session("usr-matrix-@alice:example.org") is None
|
||||||
|
|
||||||
|
await store.set_current_session("usr-matrix-@alice:example.org", "session-1")
|
||||||
|
|
||||||
|
assert await store.get_current_session("usr-matrix-@alice:example.org") == "session-1"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue