diff --git a/adapter/matrix/bot.py b/adapter/matrix/bot.py index 6ccebef..9c35d74 100644 --- a/adapter/matrix/bot.py +++ b/adapter/matrix/bot.py @@ -58,12 +58,14 @@ def build_event_dispatcher(platform: MockPlatformClient, store: StateStore) -> E platform=platform, chat_mgr=chat_mgr, auth_mgr=auth_mgr, settings_mgr=settings_mgr ) register_all(dispatcher) - register_matrix_handlers(dispatcher) + register_matrix_handlers(dispatcher, store=store) return dispatcher def build_runtime( - platform: MockPlatformClient | None = None, store: StateStore | None = None + platform: MockPlatformClient | None = None, + store: StateStore | None = None, + client: AsyncClient | None = None, ) -> MatrixRuntime: platform = platform or MockPlatformClient() store = store or InMemoryStore() @@ -74,7 +76,7 @@ def build_runtime( platform=platform, chat_mgr=chat_mgr, auth_mgr=auth_mgr, settings_mgr=settings_mgr ) register_all(dispatcher) - register_matrix_handlers(dispatcher) + register_matrix_handlers(dispatcher, client=client, store=store) return MatrixRuntime( platform=platform, store=store, @@ -187,13 +189,13 @@ async def main() -> None: if not homeserver or not user_id: raise RuntimeError("MATRIX_HOMESERVER and MATRIX_USER_ID are required") - runtime = build_runtime(store=SQLiteStore(db_path)) client = AsyncClient( homeserver, user=user_id, device_id=device_id, store_path=os.environ.get("MATRIX_STORE_PATH"), ) + runtime = build_runtime(store=SQLiteStore(db_path), client=client) if token: client.access_token = token elif password: diff --git a/adapter/matrix/handlers/__init__.py b/adapter/matrix/handlers/__init__.py index 61964e2..d03cba7 100644 --- a/adapter/matrix/handlers/__init__.py +++ b/adapter/matrix/handlers/__init__.py @@ -3,7 +3,7 @@ from __future__ import annotations from adapter.matrix.handlers.chat import ( handle_archive, handle_list_chats, - handle_new_chat, + make_handle_new_chat, handle_rename, ) from adapter.matrix.handlers.confirm import handle_cancel, handle_confirm @@ -22,8 +22,8 @@ from core.handler import EventDispatcher from core.protocol import IncomingCallback, IncomingCommand -def register_matrix_handlers(dispatcher: EventDispatcher) -> None: - dispatcher.register(IncomingCommand, "new", handle_new_chat) +def register_matrix_handlers(dispatcher: EventDispatcher, client=None, store=None) -> None: + dispatcher.register(IncomingCommand, "new", make_handle_new_chat(client, store)) dispatcher.register(IncomingCommand, "chats", handle_list_chats) dispatcher.register(IncomingCommand, "rename", handle_rename) dispatcher.register(IncomingCommand, "archive", handle_archive) diff --git a/adapter/matrix/handlers/chat.py b/adapter/matrix/handlers/chat.py index 700b881..9d20088 100644 --- a/adapter/matrix/handlers/chat.py +++ b/adapter/matrix/handlers/chat.py @@ -1,9 +1,12 @@ from __future__ import annotations +from typing import Any, Awaitable, Callable + +from adapter.matrix.store import set_room_meta from core.protocol import IncomingCommand, OutgoingMessage -async def handle_new_chat( +async def _fallback_new_chat( event: IncomingCommand, auth_mgr, platform, chat_mgr, settings_mgr ) -> list: if not await auth_mgr.is_authenticated(event.user_id): @@ -26,6 +29,60 @@ async def handle_new_chat( ] +def make_handle_new_chat( + client: Any | None, + store: Any | None, +) -> Callable[..., Awaitable[list]]: + async def handle_new_chat( + event: IncomingCommand, auth_mgr, platform, chat_mgr, settings_mgr + ) -> list: + if client is None or store is None: + return await _fallback_new_chat(event, auth_mgr, platform, chat_mgr, settings_mgr) + + if not await auth_mgr.is_authenticated(event.user_id): + return [OutgoingMessage(chat_id=event.chat_id, text="Введите !start чтобы начать.")] + + name = " ".join(event.args).strip() if event.args else "" + chats = await chat_mgr.list_active(event.user_id) + chat_id = f"C{len(chats) + 1}" + room_name = name or f"Чат {chat_id}" + + response = await client.room_create( + name=room_name, + invite=[event.user_id], + is_direct=False, + ) + room_id = getattr(response, "room_id", None) + if not room_id: + return [OutgoingMessage(chat_id=event.chat_id, text="Не удалось создать комнату.")] + + await set_room_meta( + store, + room_id, + { + "room_type": "chat", + "chat_id": chat_id, + "display_name": room_name, + "matrix_user_id": event.user_id, + }, + ) + ctx = await chat_mgr.get_or_create( + user_id=event.user_id, + chat_id=chat_id, + platform=event.platform, + surface_ref=room_id, + name=room_name, + ) + return [ + OutgoingMessage( + chat_id=event.chat_id, + text=f"Создан чат: {ctx.display_name} ({ctx.chat_id})\nКомната: {room_id}", + ) + ] + + return handle_new_chat + + async def handle_list_chats( event: IncomingCommand, auth_mgr, platform, chat_mgr, settings_mgr ) -> list: diff --git a/tests/adapter/matrix/test_dispatcher.py b/tests/adapter/matrix/test_dispatcher.py index 7b9b605..d8bfa69 100644 --- a/tests/adapter/matrix/test_dispatcher.py +++ b/tests/adapter/matrix/test_dispatcher.py @@ -49,6 +49,28 @@ async def test_matrix_dispatcher_registers_custom_handlers(): assert any(isinstance(r, OutgoingMessage) and "fetch-url" in r.text for r in result) +async def test_new_chat_creates_real_matrix_room_when_client_available(): + client = SimpleNamespace(room_create=AsyncMock(return_value=SimpleNamespace(room_id="!r2:example"))) + runtime = build_runtime(platform=MockPlatformClient(), client=client) + + start = IncomingCommand(user_id="u1", platform="matrix", chat_id="C1", command="start") + await runtime.dispatcher.dispatch(start) + + new = IncomingCommand( + user_id="u1", + platform="matrix", + chat_id="C1", + command="new", + args=["Research"], + ) + result = await runtime.dispatcher.dispatch(new) + + client.room_create.assert_awaited_once_with(name="Research", invite=["u1"], is_direct=False) + chats = await runtime.chat_mgr.list_active("u1") + assert [c.surface_ref for c in chats] == ["!r2:example"] + assert any(isinstance(r, OutgoingMessage) and "!r2:example" in r.text for r in result) + + async def test_invite_event_creates_dm_room_and_sends_welcome(): runtime = build_runtime(platform=MockPlatformClient()) client = SimpleNamespace(join=AsyncMock(), room_send=AsyncMock())