- harden Matrix onboarding/chat lifecycle after manual QA - refresh README and Matrix docs to match current behavior - add local ignores for runtime artifacts and include current planning/report docs Closes #7 Closes #9 Closes #14
2667 lines
115 KiB
Python
Executable file
2667 lines
115 KiB
Python
Executable file
"""Matrix bot frontend.
|
||
|
||
Connects to a Matrix homeserver, listens for messages in rooms,
|
||
routes them through Claude CLI sessions. Same session layer as Telegram bot.
|
||
|
||
Commands:
|
||
!new [topic] — Create a new conversation room with optional topic name.
|
||
!claude-auth — Refresh Claude Code OAuth token (manual browser flow).
|
||
"""
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import os
|
||
import re
|
||
import time
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
|
||
import httpx
|
||
from nio import (
|
||
AsyncClient,
|
||
AsyncClientConfig,
|
||
MatrixRoom,
|
||
MegolmEvent,
|
||
RoomEncryptedAudio,
|
||
RoomEncryptedFile,
|
||
RoomEncryptedImage,
|
||
RoomMemberEvent,
|
||
RoomMessageAudio,
|
||
RoomMessageImage,
|
||
RoomMessageText,
|
||
RoomMessageFile,
|
||
RoomMessageUnknown,
|
||
SyncResponse,
|
||
UnknownEvent,
|
||
)
|
||
from nio.events.to_device import (
|
||
KeyVerificationCancel,
|
||
KeyVerificationKey,
|
||
KeyVerificationMac,
|
||
KeyVerificationStart,
|
||
)
|
||
|
||
from nio.crypto import decrypt_attachment
|
||
|
||
from core.asr import transcribe
|
||
from core.claude_session import send_message as claude_send
|
||
from core.config import Config
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
|
||
class SessionState:
|
||
"""Tracks an active Claude session for a room."""
|
||
cancel_event: asyncio.Event
|
||
user_event_id: str # original user message (thread root)
|
||
status_event_id: str | None = None # status message in thread
|
||
status_lines: list[str] = field(default_factory=list)
|
||
last_status_edit: float = 0.0
|
||
idle_timeout_ref: list = field(default_factory=lambda: [120])
|
||
start_time: float = field(default_factory=time.monotonic)
|
||
|
||
|
||
class MatrixBot:
|
||
def __init__(self, config: Config, homeserver: str, user_id: str, access_token: str,
             owner_mxid: str = "", users: dict[str, dict] | None = None,
             device_id: str = "AGENT_CORE", admin_mxid: str = ""):
    """Store configuration and build the encryption-enabled Matrix client.

    Args:
        config: Application config; ``config.data_dir`` is used for the
            crypto store and the sync-token file.
        homeserver: Homeserver URL handed to ``AsyncClient``.
        user_id: Bot MXID used for the client and the restored login.
        access_token: Access token passed to ``restore_login``.
        owner_mxid: Owner MXID; becomes the sole allowed user when no
            ``users`` map is supplied.
        users: Optional map of MXID to a per-user settings dict.
        device_id: Device ID used for the client and the restored login.
        admin_mxid: MXID for admin notifications (fallback, errors).
    """
    self.config = config
    self.owner_mxid = owner_mxid
    self.admin_mxid = admin_mxid  # For admin notifications (fallback, errors)

    # Single-owner mode: without a users map the owner is the only allowed user.
    allowed_users = users or {}
    if not allowed_users and owner_mxid:
        allowed_users = {owner_mxid: {}}
    self._users = allowed_users

    # E2E: crypto store for keys, auto-decrypt/encrypt
    crypto_store = str(config.data_dir / "crypto_store")
    Path(crypto_store).mkdir(parents=True, exist_ok=True)
    nio_config = AsyncClientConfig(encryption_enabled=True, store_sync_tokens=True)
    self.client = AsyncClient(
        homeserver,
        user_id,
        device_id=device_id,
        store_path=crypto_store,
        config=nio_config,
    )
    self.client.restore_login(user_id, device_id, access_token)

    self._synced = False
    self._default_room_prefix = "Bot: "
    self._pending_questions: dict[str, asyncio.Future] = {}
    self._active_sessions: dict[str, SessionState] = {}  # room_id -> session state
    # Persistent message queue removed — using queue.jsonl files instead
    self._auth_flows: dict[str, dict] = {}  # safe_id -> {tmux_session, started}
    self._collect_preambles: dict[str, str] = {}  # safe_id -> preamble for next Claude call
    self._processed_events: set[str] = set()
    self._room_verifications: dict[str, dict] = {}  # tx_id → state
    self._sync_token_path = config.data_dir / "matrix_sync_token.txt"
    self._avatar_mxc: str | None = None  # cached after upload
|
||
|
||
def _is_allowed_user(self, sender: str) -> bool:
|
||
return sender in self._users
|
||
|
||
def _get_user_workspace(self, sender: str) -> Path | None:
|
||
"""Get workspace directory for a user, or None."""
|
||
user_info = self._users.get(sender, {})
|
||
ws = user_info.get("workspace")
|
||
if ws:
|
||
path = Path(ws)
|
||
if path.is_dir():
|
||
return path
|
||
return None
|
||
|
||
def _get_user_profile(self, sender: str) -> str:
|
||
"""Load user.md content for a sender, or empty string."""
|
||
user_info = self._users.get(sender, {})
|
||
profile_file = user_info.get("profile")
|
||
if profile_file and self.config.workspace_dir:
|
||
path = self.config.workspace_dir / profile_file
|
||
if path.exists():
|
||
return path.read_text().strip()
|
||
# Fallback: single-user mode with user.md
|
||
if self.config.workspace_dir:
|
||
path = self.config.workspace_dir / "user.md"
|
||
if path.exists():
|
||
return path.read_text().strip()
|
||
return ""
|
||
|
||
def _is_group_room(self, room: MatrixRoom) -> bool:
|
||
"""Room has more than 2 members (joined + invited, not a 1:1 chat)."""
|
||
return (room.member_count + room.invited_count) > 2
|
||
|
||
def _text_mentions_bot(self, text: str) -> bool:
|
||
"""Check if text contains a bot mention (@user_id, localpart, or display name)."""
|
||
text = text.lower()
|
||
# Check user_id (@bot:your.homeserver.example)
|
||
if self.client.user_id.lower() in text:
|
||
return True
|
||
# Check localpart (bot)
|
||
local_name = self.client.user_id.split(":")[0].lstrip("@").lower()
|
||
if local_name in text:
|
||
return True
|
||
# Check display name from any room
|
||
for room in self.client.rooms.values():
|
||
me = room.users.get(self.client.user_id)
|
||
if me and me.display_name and me.display_name.lower() in text:
|
||
return True
|
||
return False
|
||
|
||
def _strip_mention_prefix(self, text: str) -> str:
|
||
"""Strip bot mention prefix from text (e.g. '@[bot-dev] !status' → '!status')."""
|
||
import re
|
||
local_name = self.client.user_id.split(":")[0].lstrip("@")
|
||
names = [re.escape(self.client.user_id), re.escape(local_name)]
|
||
for room in self.client.rooms.values():
|
||
me = room.users.get(self.client.user_id)
|
||
if me and me.display_name:
|
||
names.append(re.escape(me.display_name))
|
||
break
|
||
alts = "|".join(names)
|
||
# Match: @[name], @name, name: , name, — with optional @[] wrapping and trailing punctuation
|
||
pattern = r"^@?\[?(?:" + alts + r")\]?[\s:,]*"
|
||
return re.sub(pattern, "", text, flags=re.IGNORECASE)
|
||
|
||
def _is_bot_mentioned(self, event: RoomMessageText) -> bool:
|
||
"""Check if bot is mentioned in a message event."""
|
||
# Check structured mentions first (m.mentions in content)
|
||
mentions = event.source.get("content", {}).get("m.mentions", {})
|
||
user_ids = mentions.get("user_ids", [])
|
||
if self.client.user_id in user_ids:
|
||
return True
|
||
return self._text_mentions_bot(event.body)
|
||
|
||
def _room_dir(self, room_id: str) -> Path:
|
||
safe_id = room_id.replace(":", "_").replace("!", "")
|
||
d = self.config.data_dir / "rooms" / safe_id
|
||
d.mkdir(parents=True, exist_ok=True)
|
||
return d
|
||
|
||
def _topic_dir(self, safe_id: str) -> Path:
|
||
return self.config.data_dir / "topics" / safe_id
|
||
|
||
# --- Room history ---
|
||
|
||
def _save_room_message(self, room_id: str, sender: str, msg_type: str, text: str,
|
||
file_path: str | None = None) -> None:
|
||
"""Append a message to room history. Called for ALL messages in ALL rooms."""
|
||
history_file = self._room_dir(room_id) / "history.jsonl"
|
||
display = sender.split(":")[0].lstrip("@")
|
||
entry: dict = {
|
||
"ts": datetime.now(timezone.utc).isoformat(),
|
||
"sender": sender,
|
||
"name": display,
|
||
"type": msg_type,
|
||
"text": text,
|
||
}
|
||
if file_path:
|
||
entry["file"] = file_path
|
||
with open(history_file, "a") as f:
|
||
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||
|
||
def _get_room_context(self, room_id: str, limit: int = 50) -> str:
|
||
"""Read last N messages from history.jsonl and format as chat context."""
|
||
history_file = self._room_dir(room_id) / "history.jsonl"
|
||
if not history_file.exists():
|
||
return ""
|
||
lines = []
|
||
try:
|
||
with open(history_file) as f:
|
||
all_lines = f.readlines()
|
||
for line in all_lines[-limit:]:
|
||
line = line.strip()
|
||
if line:
|
||
lines.append(json.loads(line))
|
||
except Exception as e:
|
||
logger.warning("Failed to read room history: %s", e)
|
||
return ""
|
||
if not lines:
|
||
return ""
|
||
parts = []
|
||
for msg in lines:
|
||
name = msg.get("name", "?")
|
||
text = msg.get("text", "")
|
||
msg_type = msg.get("type", "text")
|
||
ts = msg.get("ts", "")[:16].replace("T", " ")
|
||
if msg_type == "image":
|
||
parts.append(f"[{ts}] {name}: [sent an image] {text}")
|
||
elif msg_type == "audio":
|
||
parts.append(f"[{ts}] {name}: [voice] {text}")
|
||
elif msg_type == "file":
|
||
parts.append(f"[{ts}] {name}: [sent a file] {text}")
|
||
else:
|
||
parts.append(f"[{ts}] {name}: {text}")
|
||
context = "\n".join(parts)
|
||
return (
|
||
"[Recent room history — you can see what participants discussed before mentioning you. "
|
||
"Use this context to understand the conversation. Do NOT repeat this history back.]\n\n"
|
||
+ context
|
||
)
|
||
|
||
# --- Room mode (quiet / context / full / collect) ---
|
||
|
||
# Accepted values for the "mode" key of a room's config.json.
ROOM_MODES = ("quiet", "context", "full", "collect")
|
||
|
||
def _get_room_mode(self, room_id: str) -> str:
|
||
"""Get room mode from config.json. Default: quiet for groups, full for 1:1."""
|
||
config_file = self._room_dir(room_id) / "config.json"
|
||
if config_file.exists():
|
||
try:
|
||
data = json.loads(config_file.read_text())
|
||
mode = data.get("mode", "")
|
||
if mode in self.ROOM_MODES:
|
||
return mode
|
||
except Exception:
|
||
pass
|
||
room = self.client.rooms.get(room_id)
|
||
if room and self._is_group_room(room):
|
||
return "quiet"
|
||
return "full"
|
||
|
||
def _set_room_mode(self, room_id: str, mode: str) -> None:
|
||
"""Save room mode to config.json."""
|
||
config_file = self._room_dir(room_id) / "config.json"
|
||
data = {}
|
||
if config_file.exists():
|
||
try:
|
||
data = json.loads(config_file.read_text())
|
||
except Exception:
|
||
pass
|
||
data["mode"] = mode
|
||
config_file.write_text(json.dumps(data, ensure_ascii=False, indent=2))
|
||
|
||
# --- Room security mode (strict / guarded / open) ---
|
||
|
||
# Accepted values for the "security" key of a room's config.json.
SECURITY_MODES = ("strict", "guarded", "open")
|
||
|
||
def _get_security_mode(self, room_id: str) -> str:
|
||
"""Get room security mode from config.json. Default: guarded."""
|
||
config_file = self._room_dir(room_id) / "config.json"
|
||
if config_file.exists():
|
||
try:
|
||
data = json.loads(config_file.read_text())
|
||
mode = data.get("security", "")
|
||
if mode in self.SECURITY_MODES:
|
||
return mode
|
||
except Exception:
|
||
pass
|
||
return "guarded"
|
||
|
||
def _set_security_mode(self, room_id: str, mode: str) -> None:
|
||
"""Save room security mode to config.json."""
|
||
config_file = self._room_dir(room_id) / "config.json"
|
||
data = {}
|
||
if config_file.exists():
|
||
try:
|
||
data = json.loads(config_file.read_text())
|
||
except Exception:
|
||
pass
|
||
data["security"] = mode
|
||
config_file.write_text(json.dumps(data, ensure_ascii=False, indent=2))
|
||
|
||
def _get_unverified_devices(self, room_id: str) -> dict[str, list[str]]:
|
||
"""Return {user_id: [device_id, ...]} for unverified devices in a room.
|
||
|
||
Only checks allowed users (room members known to the bot).
|
||
"""
|
||
if not self.client.olm:
|
||
return {}
|
||
room = self.client.rooms.get(room_id)
|
||
if not room:
|
||
return {}
|
||
unverified: dict[str, list[str]] = {}
|
||
for user_id in room.users:
|
||
if user_id == self.client.user_id:
|
||
continue
|
||
for device in self.client.device_store.active_user_devices(user_id):
|
||
if not device.verified:
|
||
unverified.setdefault(user_id, []).append(device.id)
|
||
return unverified
|
||
|
||
def _user_fully_verified(self, sender: str) -> bool:
|
||
"""Check if all of sender's devices are verified."""
|
||
if not self.client.olm:
|
||
return True # no E2E, no verification needed
|
||
for device in self.client.device_store.active_user_devices(sender):
|
||
if not device.verified:
|
||
return False
|
||
return True
|
||
|
||
def _format_unverified_warning(self, unverified: dict[str, list[str]]) -> str:
|
||
"""Format a warning string listing unverified devices."""
|
||
parts = []
|
||
for user_id, devices in unverified.items():
|
||
dev_str = ", ".join(f"`{d}`" for d in devices)
|
||
parts.append(f"{user_id}: {dev_str}")
|
||
return "\u26a0 Unverified devices in room: " + "; ".join(parts)
|
||
|
||
async def _check_security(self, room_id: str, sender: str) -> tuple[bool, str | None]:
|
||
"""Check room security policy for a sender.
|
||
|
||
Returns:
|
||
(allowed, warning_or_error):
|
||
- (True, None) — proceed, no warning
|
||
- (True, warning) — proceed, append warning to response
|
||
- (False, error) — refuse, send error message
|
||
"""
|
||
security = self._get_security_mode(room_id)
|
||
if security == "open":
|
||
unverified = self._get_unverified_devices(room_id)
|
||
if unverified:
|
||
return True, self._format_unverified_warning(unverified)
|
||
return True, None
|
||
|
||
unverified = self._get_unverified_devices(room_id)
|
||
if not unverified:
|
||
return True, None
|
||
|
||
if security == "strict":
|
||
return False, (
|
||
"Room has unverified devices — refusing to respond.\n"
|
||
+ self._format_unverified_warning(unverified)
|
||
+ "\n\nVerify devices or use `!security open` from a fully verified session."
|
||
)
|
||
|
||
# guarded: block only users with unverified devices
|
||
sender_unverified = unverified.get(sender)
|
||
if sender_unverified:
|
||
dev_str = ", ".join(f"`{d}`" for d in sender_unverified)
|
||
return False, (
|
||
f"You have unverified devices ({dev_str}) — not accepting commands.\n"
|
||
"Verify your devices or ask a verified user to `!security open`."
|
||
)
|
||
return True, None
|
||
|
||
def _log_interaction(self, room_id: str, user_msg: str, bot_msg: str) -> None:
|
||
log_file = self._room_dir(room_id) / "log.jsonl"
|
||
entry = {
|
||
"ts": datetime.now(timezone.utc).isoformat(),
|
||
"user": user_msg[:1000],
|
||
"bot": bot_msg[:2000],
|
||
}
|
||
with open(log_file, "a") as f:
|
||
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||
|
||
def _md_to_html(self, text: str) -> str:
    """Convert markdown to Matrix HTML, with tables as monospace <pre> blocks.

    Consecutive pipe-delimited table rows (separator rows included) are
    wrapped in a fenced code block before the text is handed to
    ``markdown`` with the ``fenced_code`` extension.

    Rows that already sit inside a fenced block are left alone: wrapping
    them again would insert a fence line that closes the surrounding block
    early and garbles the rest of the message.
    """
    import markdown

    table_row = re.compile(r"^\s*\|.*\|\s*$")
    result_lines: list[str] = []
    table_lines: list[str] = []
    in_fence = False

    def flush_table() -> None:
        # Emit the collected rows as one fenced block and reset the buffer.
        if table_lines:
            result_lines.append("```")
            result_lines.extend(table_lines)
            result_lines.append("```")
            table_lines.clear()

    for line in text.split("\n"):
        if line.lstrip().startswith(("```", "~~~")):
            # Simple toggle on fence markers; fence lengths are not matched up.
            flush_table()
            in_fence = not in_fence
            result_lines.append(line)
        elif not in_fence and table_row.match(line):
            table_lines.append(line)
        else:
            flush_table()
            result_lines.append(line)
    flush_table()

    return markdown.markdown("\n".join(result_lines), extensions=["fenced_code"])
|
||
|
||
# --- Avatar management ---
|
||
|
||
def _avatar_path(self) -> Path | None:
|
||
"""Return path to avatar.jpg in workspace, or None."""
|
||
if self.config.workspace_dir:
|
||
p = self.config.workspace_dir / "avatar.jpg"
|
||
if p.exists():
|
||
return p
|
||
return None
|
||
|
||
async def _set_bot_avatar(self) -> None:
|
||
"""Upload avatar.jpg and set as bot profile picture (only if not already set)."""
|
||
path = self._avatar_path()
|
||
if not path:
|
||
return
|
||
try:
|
||
async with httpx.AsyncClient() as http:
|
||
user_id = self.client.user_id
|
||
hs = self.client.homeserver
|
||
# Check if avatar already set
|
||
resp = await http.get(
|
||
f"{hs}/_matrix/client/v3/profile/{user_id}/avatar_url",
|
||
headers={"Authorization": f"Bearer {self.client.access_token}"},
|
||
timeout=10,
|
||
)
|
||
if resp.status_code == 200:
|
||
existing = resp.json().get("avatar_url", "")
|
||
if existing:
|
||
self._avatar_mxc = existing
|
||
logger.info("Bot avatar already set: %s", existing)
|
||
return
|
||
# Upload and set
|
||
data = path.read_bytes()
|
||
mxc = await self._upload_file(data, "image/jpeg", "avatar.jpg")
|
||
if not mxc:
|
||
return
|
||
self._avatar_mxc = mxc
|
||
resp = await http.put(
|
||
f"{hs}/_matrix/client/v3/profile/{user_id}/avatar_url",
|
||
json={"avatar_url": mxc},
|
||
headers={"Authorization": f"Bearer {self.client.access_token}"},
|
||
timeout=15,
|
||
)
|
||
if resp.status_code == 200:
|
||
logger.info("Set bot profile avatar: %s", mxc)
|
||
else:
|
||
logger.warning("Failed to set profile avatar (%d): %s",
|
||
resp.status_code, resp.text[:200])
|
||
except Exception as e:
|
||
logger.warning("Failed to set bot avatar: %s", e)
|
||
|
||
async def _set_room_avatar(self, room_id: str) -> None:
|
||
"""Set room avatar to bot's avatar if not already set. Uses HTTP API directly."""
|
||
if not self._avatar_mxc:
|
||
return
|
||
try:
|
||
from urllib.parse import quote
|
||
hs = self.client.homeserver
|
||
rid = quote(room_id, safe="")
|
||
async with httpx.AsyncClient() as http:
|
||
# Check if avatar already set
|
||
resp = await http.get(
|
||
f"{hs}/_matrix/client/v3/rooms/{rid}/state/m.room.avatar",
|
||
headers={"Authorization": f"Bearer {self.client.access_token}"},
|
||
timeout=10,
|
||
)
|
||
if resp.status_code == 200:
|
||
return # already has avatar
|
||
# Set avatar
|
||
resp = await http.put(
|
||
f"{hs}/_matrix/client/v3/rooms/{rid}/state/m.room.avatar",
|
||
json={"url": self._avatar_mxc},
|
||
headers={"Authorization": f"Bearer {self.client.access_token}"},
|
||
timeout=10,
|
||
)
|
||
if resp.status_code == 200:
|
||
logger.info("Set room avatar for %s", room_id)
|
||
else:
|
||
logger.warning("Failed to set room avatar for %s (%d): %s",
|
||
room_id, resp.status_code, resp.text[:200])
|
||
except Exception as e:
|
||
logger.warning("Failed to set room avatar for %s: %s", room_id, e)
|
||
|
||
# --- Room management ---
|
||
|
||
async def _generate_room_label(self, room_id: str, current_label: str = "") -> str | None:
|
||
"""Generate a short room label via local LLM based on conversation history.
|
||
|
||
Returns None if generation fails, or the new label string.
|
||
"""
|
||
# Build context from history
|
||
history_file = self._room_dir(room_id) / "history.jsonl"
|
||
chat_lines = []
|
||
if history_file.exists():
|
||
try:
|
||
with open(history_file) as f:
|
||
all_lines = f.readlines()
|
||
for line in all_lines[-15:]:
|
||
line = line.strip()
|
||
if line:
|
||
msg = json.loads(line)
|
||
name = msg.get("name", "?")
|
||
text = msg.get("text", "")[:150]
|
||
chat_lines.append(f"{name}: {text}")
|
||
except Exception:
|
||
pass
|
||
if not chat_lines:
|
||
return None
|
||
|
||
conversation = "\n".join(chat_lines)
|
||
user_content = conversation
|
||
if current_label:
|
||
user_content = f"Current name: {current_label}\n\n{conversation}"
|
||
|
||
api_base = os.environ.get("LOCAL_LLM_URL") or os.environ.get("OPENAI_API_BASE", "http://localhost:4000/v1")
|
||
api_key = os.environ.get("OPENAI_API_KEY", "")
|
||
model = os.environ.get("LOCAL_LLM_MODEL", "qwen3.5-122b")
|
||
llm_url = api_base.rstrip("/") + "/chat/completions"
|
||
headers = {}
|
||
if api_key:
|
||
headers["Authorization"] = f"Bearer {api_key}"
|
||
try:
|
||
async with httpx.AsyncClient() as http:
|
||
resp = await http.post(llm_url, json={
|
||
"model": model,
|
||
"messages": [
|
||
{"role": "system", "content": (
|
||
"You generate short chat room titles (3-5 words) based on what the user is asking about. "
|
||
"Rules: output ONLY the title. No quotes, no prefixes. Same language as the user. "
|
||
"Focus on the user's main question or task, ignore bot replies and minor tangents."
|
||
)},
|
||
{"role": "user", "content": user_content},
|
||
],
|
||
"max_tokens": 20,
|
||
"temperature": 0.3,
|
||
"chat_template_kwargs": {"enable_thinking": False},
|
||
}, headers=headers, timeout=15)
|
||
if resp.status_code == 200:
|
||
data = resp.json()
|
||
label = data["choices"][0]["message"]["content"].strip().strip('"\'')
|
||
return label[:80] if label else None
|
||
except Exception as e:
|
||
logger.warning("Failed to generate room label: %s", e)
|
||
return None
|
||
|
||
async def _rename_room(self, room_id: str, safe_id: str,
|
||
user_text: str = "", response: str = "") -> None:
|
||
"""Rename room if it still has the default 'Bot: ' prefix."""
|
||
room = self.client.rooms.get(room_id)
|
||
if not room:
|
||
return
|
||
current_name = room.name or ""
|
||
if not current_name.startswith(self._default_room_prefix):
|
||
return # user renamed it manually — don't touch
|
||
current_label = current_name[len(self._default_room_prefix):].strip()
|
||
label = await self._generate_room_label(room_id, current_label)
|
||
if not label:
|
||
return
|
||
new_name = f"{self._default_room_prefix}{label}"
|
||
if new_name == current_name:
|
||
return
|
||
try:
|
||
from nio.responses import RoomPutStateError
|
||
resp = await self.client.room_put_state(
|
||
room_id, "m.room.name", {"name": new_name[:255]},
|
||
)
|
||
if isinstance(resp, RoomPutStateError):
|
||
logger.warning("Cannot rename room %s: %s", room_id, resp.status_code)
|
||
return
|
||
logger.info("Renamed room %s to: %s", room_id, new_name)
|
||
await self._set_room_avatar(room_id)
|
||
except Exception as e:
|
||
logger.warning("Failed to rename room: %s", e)
|
||
|
||
async def _create_conversation_room(self, name: str, for_user: str | None = None) -> str | None:
    """Create a private encrypted room and invite the user.

    The room is created through the homeserver's ``createRoom`` endpoint
    with Megolm encryption enabled and, when an avatar URL is cached, the
    bot's avatar. If *for_user* is given they are invited and both they and
    the bot get power level 100.

    Returns:
        The new room ID, or None when the request fails (the failure is
        logged).
    """
    initial_state = [{
        "type": "m.room.encryption",
        "state_key": "",
        "content": {"algorithm": "m.megolm.v1.aes-sha2"},
    }]
    if self._avatar_mxc:
        initial_state.append({
            "type": "m.room.avatar",
            "state_key": "",
            "content": {"url": self._avatar_mxc},
        })

    body: dict = {
        "name": name,
        "visibility": "private",
        "preset": "trusted_private_chat",
        "invite": [for_user] if for_user else [],
    }
    # Give the target user admin power (matches Element-created rooms)
    if for_user:
        body["power_level_content_override"] = {
            "users": {self.client.user_id: 100, for_user: 100},
        }
    # The encryption entry is always present, so the list is never empty.
    body["initial_state"] = initial_state

    try:
        request_headers = {
            "Authorization": f"Bearer {self.client.access_token}",
            "Content-Type": "application/json",
        }
        async with httpx.AsyncClient() as http:
            resp = await http.post(
                f"{self.client.homeserver}/_matrix/client/v3/createRoom",
                headers=request_headers,
                json=body,
                timeout=15,
            )
            if resp.status_code != 200:
                logger.error("Failed to create room (%d): %s", resp.status_code, resp.text[:200])
            else:
                room_id = resp.json()["room_id"]
                logger.info("Created room %s: %s", room_id, name)
                return room_id
    except Exception as e:
        logger.error("Failed to create room: %s", e)
    return None
|
||
|
||
# --- Sending ---
|
||
|
||
async def _send_response(self, room_id: str, response: str,
|
||
ignore_unverified_devices: bool = True) -> None:
|
||
"""Send response with HTML formatting."""
|
||
html = self._md_to_html(response)
|
||
await self.client.room_send(
|
||
room_id, "m.room.message",
|
||
{
|
||
"msgtype": "m.text",
|
||
"body": response,
|
||
"format": "org.matrix.custom.html",
|
||
"formatted_body": html,
|
||
},
|
||
ignore_unverified_devices=ignore_unverified_devices,
|
||
)
|
||
|
||
async def _upload_file(self, data: bytes, content_type: str, filename: str) -> str | None:
    """Upload file to Matrix via HTTP API directly.

    Args:
        data: Raw file content to upload.
        content_type: MIME type sent as the ``Content-Type`` header.
        filename: File name sent as the ``filename`` query parameter.

    Returns:
        The ``content_uri`` from the response on HTTP 200, otherwise None
        (the failure is logged). Network errors raised by ``httpx`` are not
        caught here.

    The file name is passed through ``params`` so that it gets URL-encoded:
    names with spaces, ``&``, ``#`` or non-ASCII characters would otherwise
    corrupt the query string.
    """
    url = f"{self.client.homeserver}/_matrix/media/v3/upload"
    async with httpx.AsyncClient() as http:
        resp = await http.post(
            url,
            params={"filename": filename},
            content=data,
            headers={
                "Authorization": f"Bearer {self.client.access_token}",
                "Content-Type": content_type,
            },
            timeout=60,
        )
        if resp.status_code == 200:
            return resp.json().get("content_uri")
        logger.error("Matrix upload failed (%d): %s", resp.status_code, resp.text[:200])
        return None
|
||
|
||
async def _download_media(self, event) -> bytes | None:
|
||
"""Download media from Matrix, decrypting if E2E encrypted."""
|
||
resp = await self.client.download(event.url)
|
||
if not hasattr(resp, "body"):
|
||
logger.error("Failed to download media: %s", resp)
|
||
return None
|
||
data = resp.body
|
||
# Encrypted media (RoomEncryptedImage/Audio/File) has key/hashes/iv
|
||
if hasattr(event, "key") and hasattr(event, "hashes") and hasattr(event, "iv"):
|
||
try:
|
||
data = decrypt_attachment(
|
||
data, event.key["k"], event.hashes["sha256"], event.iv,
|
||
)
|
||
except Exception as e:
|
||
logger.error("Failed to decrypt attachment: %s", e)
|
||
return None
|
||
return data
|
||
|
||
async def _send_outbox(self, room_id: str, room_dir: Path) -> None:
|
||
"""Send files queued in outbox.jsonl by Claude via send-to-user tool."""
|
||
outbox = room_dir / "outbox.jsonl"
|
||
if not outbox.exists():
|
||
return
|
||
|
||
entries = []
|
||
try:
|
||
with open(outbox) as f:
|
||
for line in f:
|
||
line = line.strip()
|
||
if line:
|
||
entries.append(json.loads(line))
|
||
outbox.unlink()
|
||
except Exception as e:
|
||
logger.error("Failed to read outbox: %s", e)
|
||
return
|
||
|
||
mime_map = {
|
||
"jpg": "image/jpeg", "jpeg": "image/jpeg", "png": "image/png",
|
||
"webp": "image/webp", "gif": "image/gif", "bmp": "image/bmp",
|
||
"mp4": "video/mp4", "mov": "video/quicktime", "webm": "video/webm",
|
||
"ogg": "audio/ogg", "mp3": "audio/mpeg", "wav": "audio/wav", "m4a": "audio/mp4",
|
||
"pdf": "application/pdf", "doc": "application/msword",
|
||
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||
"xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||
"html": "text/html", "txt": "text/plain", "csv": "text/csv",
|
||
"zip": "application/zip", "json": "application/json",
|
||
}
|
||
|
||
for entry in entries:
|
||
fpath = Path(entry.get("path", ""))
|
||
ftype = entry.get("type", "document")
|
||
|
||
if not fpath.is_file():
|
||
logger.warning("Outbox file not found: %s", fpath)
|
||
continue
|
||
|
||
try:
|
||
data = fpath.read_bytes()
|
||
ext = fpath.suffix.lstrip(".").lower()
|
||
content_type = mime_map.get(ext, "application/octet-stream")
|
||
|
||
content_uri = await self._upload_file(data, content_type, fpath.name)
|
||
if not content_uri:
|
||
continue
|
||
|
||
if ftype == "image":
|
||
msgtype = "m.image"
|
||
elif ftype == "video":
|
||
msgtype = "m.video"
|
||
elif ftype == "audio":
|
||
msgtype = "m.audio"
|
||
else:
|
||
msgtype = "m.file"
|
||
|
||
await self.client.room_send(
|
||
room_id, "m.room.message",
|
||
{
|
||
"msgtype": msgtype,
|
||
"body": fpath.name,
|
||
"filename": fpath.name,
|
||
"url": content_uri,
|
||
"info": {"mimetype": content_type, "size": len(data)},
|
||
},
|
||
ignore_unverified_devices=True,
|
||
)
|
||
logger.info("Sent %s to Matrix: %s", ftype, fpath.name)
|
||
except Exception as e:
|
||
logger.error("Failed to send %s %s: %s", ftype, fpath.name, e)
|
||
|
||
def _sender_display_name(self, room: MatrixRoom, sender: str) -> str:
|
||
"""Get display name for a sender in a room, fallback to localpart."""
|
||
member = room.users.get(sender)
|
||
if member and member.display_name:
|
||
return member.display_name
|
||
return sender.split(":")[0].lstrip("@")
|
||
|
||
async def _fetch_recent_messages(self, room_id: str, limit: int = 5) -> list[dict]:
|
||
"""Fetch recent messages from a room for context mode."""
|
||
room = self.client.rooms.get(room_id)
|
||
if not room or not room.prev_batch:
|
||
return []
|
||
resp = await self.client.room_messages(room_id, start=room.prev_batch, limit=limit)
|
||
if not hasattr(resp, "chunk"):
|
||
return []
|
||
messages = []
|
||
for event in reversed(resp.chunk): # chronological order
|
||
if event.sender == self.client.user_id:
|
||
continue
|
||
body = getattr(event, "body", None)
|
||
if not body:
|
||
continue
|
||
name = self._sender_display_name(room, event.sender)
|
||
messages.append({"sender": name, "text": body})
|
||
return messages
|
||
|
||
# --- Thread status messaging ---
|
||
|
||
async def _send_thread_message(self, room_id: str, thread_root_event_id: str,
|
||
body: str) -> str | None:
|
||
"""Send a notice in a thread under the given event."""
|
||
content = {
|
||
"msgtype": "m.notice",
|
||
"body": body,
|
||
"m.relates_to": {
|
||
"rel_type": "m.thread",
|
||
"event_id": thread_root_event_id,
|
||
"is_falling_back": True,
|
||
"m.in_reply_to": {"event_id": thread_root_event_id},
|
||
},
|
||
}
|
||
resp = await self.client.room_send(
|
||
room_id, "m.room.message", content,
|
||
ignore_unverified_devices=True,
|
||
)
|
||
if hasattr(resp, "event_id"):
|
||
return resp.event_id
|
||
return None
|
||
|
||
async def _edit_message(self, room_id: str, event_id: str, new_body: str) -> None:
|
||
"""Edit an existing message using m.replace relation."""
|
||
content = {
|
||
"msgtype": "m.notice",
|
||
"body": f"* {new_body}",
|
||
"m.new_content": {
|
||
"msgtype": "m.notice",
|
||
"body": new_body,
|
||
},
|
||
"m.relates_to": {
|
||
"rel_type": "m.replace",
|
||
"event_id": event_id,
|
||
},
|
||
}
|
||
await self.client.room_send(
|
||
room_id, "m.room.message", content,
|
||
ignore_unverified_devices=True,
|
||
)
|
||
|
||
async def _run_claude_session(self, room: MatrixRoom, event, message: str,
                              security_msg: str | None = None,
                              on_question=None,
                              on_done=None,
                              **extra_kwargs) -> None:
    """Run a Claude session as a background task.

    Runs concurrently so the sync loop stays free to process !stop etc.
    on_done(response) is called after session completes (for logging, renaming).

    The room is registered in ``self._active_sessions`` for the duration of
    the run and is released on every exit path (setup failure, session
    failure, cleanup failure) so it can never stay "busy" forever.

    Args:
        room: Room the request came from.
        event: Triggering event; only ``event_id`` and ``sender`` are read here.
        message: Prompt text handed to ``_call_claude``.
        security_msg: Optional notice appended to the response shown to the user.
        on_question: Optional ``async (question) -> answer`` callback. When
            omitted, the question is posted to the room and the answer is
            awaited through ``self._pending_questions``.
        on_done: Optional ``async (response)`` callback awaited after the
            session; its failures are logged and ignored.
        **extra_kwargs: Forwarded unchanged to ``_call_claude``.
    """
    room_id = room.room_id
    safe_id = room_id.replace(":", "_").replace("!", "")

    cancel_event = asyncio.Event()
    idle_timeout_ref = [self.config.claude_idle_timeout]
    session = SessionState(
        cancel_event=cancel_event,
        user_event_id=event.event_id,
        idle_timeout_ref=idle_timeout_ref,
        start_time=time.monotonic(),
    )
    self._active_sessions[room_id] = session

    def _release_session() -> None:
        # Only drop our own entry: a follow-up session started from the queue
        # may already have replaced it.
        if self._active_sessions.get(room_id) is session:
            self._active_sessions.pop(room_id, None)

    # Futures created by the default on_question below, so cleanup can tell
    # "our" unanswered question apart from anything else stored under safe_id.
    own_questions: list = []

    try:
        status_event_id = await self._send_thread_message(
            room_id, event.event_id, "Working..."
        )
        session.status_event_id = status_event_id
        on_status = self._make_on_status(room_id, session)

        user_profile = self._get_user_profile(event.sender)
        workspace_dir = self._get_user_workspace(event.sender)
    except BaseException:
        # Setup failed before the background task exists: nobody else would
        # ever clear the busy flag.
        _release_session()
        raise

    # Default on_question: post to room, wait for user reply
    if on_question is None:
        async def on_question(question: str) -> str:
            await self.client.room_send(
                room_id, "m.room.message",
                {"msgtype": "m.text", "body": f"? {question}"},
                ignore_unverified_devices=True,
            )
            future = asyncio.get_running_loop().create_future()
            own_questions.append(future)
            self._pending_questions[safe_id] = future
            return await future

    # Run as background task so sync loop stays free to process !stop etc.
    async def _session_task():
        response = ""
        try:
            response = await self._call_claude(
                room_id, safe_id, message,
                on_status=on_status, cancel_event=cancel_event,
                idle_timeout_ref=idle_timeout_ref,
                on_question=on_question,
                user_profile=user_profile, sender=event.sender,
                workspace_dir=workspace_dir,
                **extra_kwargs,
            )
            display = (response + f"\n\n{security_msg}") if security_msg else response
            await self._send_response(room_id, display)
        except RuntimeError as e:
            if cancel_event.is_set():
                await self._send_response(room_id, "Stopped.")
                response = "[cancelled]"
            else:
                logger.error("Claude error in room %s: %s", room.display_name, e)
                await self._send_response(room_id, f"Error: {e}")
                response = f"[error] {e}"
        except Exception as e:
            # Nobody awaits this task, so an unexpected failure would otherwise
            # disappear without the user ever being told.
            logger.exception("Unexpected error in Claude session for room %s",
                             room.display_name)
            response = f"[error] {e}"
            try:
                await self._send_response(room_id, f"Error: {e}")
            except Exception as send_err:
                logger.warning("Failed to report session error to room %s: %s",
                               room_id, send_err)
        finally:
            try:
                # A question we asked that was never answered would make the
                # text handler consume the user's next message as its answer.
                pending = self._pending_questions.get(safe_id)
                if pending is not None and any(pending is q for q in own_questions):
                    self._pending_questions.pop(safe_id, None)

                elapsed = int(time.monotonic() - session.start_time)
                mins, secs = divmod(elapsed, 60)
                time_str = f"{mins}m {secs:02d}s" if mins else f"{secs}s"
                tools_used = len(session.status_lines)
                final_status = f"Done ({time_str}, {tools_used} tools)"
                if session.cancel_event.is_set():
                    final_status = f"Cancelled ({time_str})"
                try:
                    if session.status_event_id:
                        await self._edit_message(room_id, session.status_event_id, final_status)
                except Exception as e:
                    # Best-effort: the status line is purely cosmetic.
                    logger.debug("Failed to update status message: %s", e)

                await self._send_outbox(room_id, self._topic_dir(safe_id))

                # Auto-commit workspace changes
                if workspace_dir:
                    asyncio.create_task(self._auto_commit_workspace(workspace_dir, room))

                # Post-session callback (logging, renaming, etc.)
                if on_done:
                    try:
                        await on_done(response)
                    except Exception as e:
                        logger.warning("on_done callback failed: %s", e)
            except Exception:
                logger.exception("Post-session cleanup failed for room %s", room_id)
            finally:
                # Process queued messages — combine all into one prompt.
                # Drain BEFORE releasing the session so room stays "busy" and new
                # messages don't sneak in between drain and new session start.
                try:
                    queued, last_eid = self._drain_queue(room_id)
                    if queued and last_eid:
                        # _process_queued_messages calls _run_claude_session which
                        # overwrites _active_sessions[room_id] with a new session.
                        await self._process_queued_messages(room, queued, last_eid)
                    else:
                        _release_session()
                except Exception:
                    logger.exception("Failed to process queued messages for room %s",
                                     room_id)
                    _release_session()

    asyncio.create_task(_session_task())
|
||
|
||
async def _auto_commit_workspace(self, workspace_dir: Path, room: MatrixRoom) -> None:
    """Git commit workspace changes after a session, if any.

    Runs ``git status --porcelain``, and when it reports changes, stages
    everything and commits with the message ``auto: <room name>``.
    Failures are logged and never raised.

    Args:
        workspace_dir: Directory the git commands run in.
        room: Room whose display name (or id) is used in the commit message.
    """
    async def _git(*args: str) -> tuple[int, bytes, bytes]:
        # Run one git command in the workspace and return (exit code, stdout, stderr).
        proc = await asyncio.create_subprocess_exec(
            "git", *args,
            cwd=str(workspace_dir),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        out, err = await proc.communicate()
        return proc.returncode, out, err

    try:
        # Check for uncommitted changes
        code, stdout, stderr = await _git("status", "--porcelain")
        if code != 0:
            # Kept at debug level: fires on every session for a workspace
            # where `git status` does not succeed.
            logger.debug("Skipping auto-commit, git status failed in %s: %s",
                         workspace_dir, stderr.decode("utf-8", errors="replace").strip())
            return
        if not stdout.strip():
            return  # nothing changed

        # Stage all and commit
        code, _, stderr = await _git("add", "-A")
        if code != 0:
            logger.warning("Workspace auto-commit failed (git add) in %s: %s",
                           workspace_dir, stderr.decode("utf-8", errors="replace").strip())
            return

        room_name = room.display_name or room.room_id
        msg = f"auto: {room_name}"
        code, _, stderr = await _git("commit", "-m", msg, "--no-gpg-sign")
        if code != 0:
            # Previously a failed commit was still reported as a success.
            logger.warning("Workspace auto-commit failed (git commit) in %s: %s",
                           workspace_dir, stderr.decode("utf-8", errors="replace").strip())
            return
        logger.info("Auto-committed workspace changes: %s", workspace_dir)
    except Exception as e:
        logger.warning("Workspace auto-commit failed: %s", e)
|
||
|
||
def _is_room_busy(self, room_id: str) -> bool:
|
||
return room_id in self._active_sessions
|
||
|
||
def _enqueue_message(self, room_id: str, event_id: str, sender: str,
|
||
text: str, msg_type: str = "text",
|
||
file_path: str | None = None) -> None:
|
||
"""Queue a processed message to queue.jsonl for later delivery."""
|
||
queue_file = self._room_dir(room_id) / "queue.jsonl"
|
||
entry = {
|
||
"ts": datetime.now(timezone.utc).isoformat(),
|
||
"event_id": event_id,
|
||
"sender": sender,
|
||
"type": msg_type,
|
||
"text": text,
|
||
}
|
||
if file_path:
|
||
entry["file"] = file_path
|
||
with open(queue_file, "a") as f:
|
||
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||
count = sum(1 for _ in open(queue_file))
|
||
logger.info("Queued message for room %s (%d pending)", room_id, count)
|
||
|
||
def _drain_queue(self, room_id: str) -> tuple[list[dict], str | None]:
|
||
"""Read and clear queue.jsonl. Returns (messages, last_event_id)."""
|
||
queue_file = self._room_dir(room_id) / "queue.jsonl"
|
||
if not queue_file.exists():
|
||
return [], None
|
||
messages = []
|
||
try:
|
||
with open(queue_file) as f:
|
||
for line in f:
|
||
line = line.strip()
|
||
if line:
|
||
messages.append(json.loads(line))
|
||
queue_file.unlink()
|
||
except Exception as e:
|
||
logger.warning("Failed to drain queue for %s: %s", room_id, e)
|
||
last_event_id = messages[-1]["event_id"] if messages else None
|
||
return messages, last_event_id
|
||
|
||
async def _process_queued_messages(self, room: MatrixRoom,
                                   messages: list[dict], last_event_id: str) -> None:
    """Combine queued messages into one prompt and send to Claude.

    Every queued entry is rendered to one or two prompt lines depending on
    its ``type``; the lines are joined and, when more than one message was
    queued, prefixed with a short notice. In ``full`` room mode the queued
    messages are also saved to the room history and the room context is
    prepended to the prompt.

    Args:
        room: Room the queued messages belong to.
        messages: Entries as returned by ``_drain_queue`` (must be non-empty).
        last_event_id: Event id used as the thread root of the new session.
    """
    room_id = room.room_id
    safe_id = room_id.replace(":", "_").replace("!", "")

    def _render(entry: dict) -> list[str]:
        # Turn one queue entry into the prompt line(s) it contributes.
        kind = entry.get("type", "text")
        body = entry.get("text", "")
        path = entry.get("file", "")
        if kind == "image":
            rendered = [f"[User sent an image: {path}]"]
            if body:
                rendered.append(body)
            return rendered
        if kind == "audio":
            return [f"[voice message]: {body}"]
        if kind == "file":
            return [f"[User sent a file: {path}]"]
        return [body]

    # Build combined prompt
    prompt_lines = []
    for entry in messages:
        prompt_lines.extend(_render(entry))
    combined = "\n".join(prompt_lines)

    total = len(messages)
    if total > 1:
        combined = (f"[{total} messages arrived while you were busy. "
                    f"Process them all:]\n\n{combined}")

    # Minimal event-like object — covers all attributes accessed by
    # _run_claude_session and downstream code paths
    last_sender = messages[-1].get("sender", "")
    queued_event = type("QueuedEvent", (), {
        "event_id": last_event_id,
        "sender": last_sender,
        "body": combined[:100],
        "source": {"content": {}},  # empty — won't match thread checks
    })()

    mode = self._get_room_mode(room_id)
    full_mode = mode == "full"

    async def _on_done(response: str):
        # Persist the reply (full mode only), then rename the room and log.
        if full_mode:
            self._save_room_message(room_id, self.client.user_id, "text", response)
        await self._rename_room(room_id, safe_id)
        self._log_interaction(room_id, combined[:200], response)

    # Add full context if in full mode
    prompt = combined
    if full_mode:
        for entry in messages:
            self._save_room_message(room_id, entry.get("sender", ""),
                                    entry.get("type", "text"), entry.get("text", ""))
        context = self._get_room_context(room_id)
        if context:
            prompt = context + "\n\n---\n\n" + combined

    await self._run_claude_session(
        room, queued_event, prompt, on_done=_on_done,
    )
|
||
|
||
async def _handle_thread_command(self, room_id: str, user_text: str,
|
||
session: SessionState) -> bool:
|
||
"""Handle user commands in a session thread. Returns True if handled."""
|
||
cmd = user_text.strip().lower().lstrip("!")
|
||
if cmd in ("stop", "cancel", "abort"):
|
||
session.cancel_event.set()
|
||
await self._send_thread_message(room_id, session.user_event_id, "Stopping...")
|
||
return True
|
||
if cmd in ("more time", "+5m", "+5"):
|
||
session.idle_timeout_ref[0] += 300
|
||
mins = session.idle_timeout_ref[0] // 60
|
||
await self._send_thread_message(
|
||
room_id, session.user_event_id, f"Timeout extended to {mins}m")
|
||
return True
|
||
if cmd in ("+10m", "+10"):
|
||
session.idle_timeout_ref[0] += 600
|
||
mins = session.idle_timeout_ref[0] // 60
|
||
await self._send_thread_message(
|
||
room_id, session.user_event_id, f"Timeout extended to {mins}m")
|
||
return True
|
||
return False
|
||
|
||
def _make_on_status(self, room_id: str, session: SessionState):
|
||
"""Create an on_status callback that posts individual thread messages."""
|
||
async def on_status(status: dict):
|
||
event_type = status.get("event")
|
||
msg = None
|
||
|
||
if event_type == "tool_start":
|
||
tool = status.get("tool", "?")
|
||
preview = status.get("input_preview", "")
|
||
session.status_lines.append(tool) # count for final summary
|
||
if preview:
|
||
msg = f"`{tool}`: {preview}"
|
||
else:
|
||
msg = f"`{tool}`"
|
||
elif event_type == "tool_end":
|
||
pass # tool_start already posted, no need for end message
|
||
elif event_type == "agent_start":
|
||
desc = status.get("description", "subagent")
|
||
bg = " (bg)" if status.get("background") else ""
|
||
session.status_lines.append("Agent")
|
||
msg = f"`Agent{bg}`: {desc}"
|
||
elif event_type == "thinking":
|
||
text = status.get("text", "").strip()
|
||
if text:
|
||
msg = text
|
||
|
||
if msg and session.user_event_id:
|
||
try:
|
||
await self._send_thread_message(room_id, session.user_event_id, msg)
|
||
except Exception as e:
|
||
logger.debug("Failed to send thread status: %s", e)
|
||
|
||
return on_status
|
||
|
||
# --- Claude call wrapper ---
|
||
|
||
async def _notify_fallback_used(self, room_id: str, sender: str) -> None:
    """Send notification to admin when fallback provider was used.

    Nothing is sent when no admin is configured or the admin is the sender.
    The notice goes to a two-member room shared with the admin, chosen by
    priority: room named exactly like the bot's localpart, then a room whose
    name contains it, then any such room. If none exists, a private room is
    created and the admin invited.
    """
    admin = self.admin_mxid
    if not admin or sender == admin:
        return  # Don't notify if no admin or admin triggered it

    # Find DM room with admin — prefer room named exactly after the bot
    # Priority: exact bot name > "Bot: something" > any 1:1 room
    exact_match = None
    partial_match = None
    first_dm = None
    bot_name = self.client.user_id.split(":")[0].lstrip("@")
    for candidate in self.client.rooms.values():
        if not (len(candidate.users) == 2 and admin in candidate.users):
            continue
        room_name = (candidate.name or "").strip()
        if room_name.lower() == bot_name.lower():
            exact_match = candidate.room_id
            break
        if bot_name.lower() in room_name.lower() and not partial_match:
            partial_match = candidate.room_id
        if not first_dm:
            first_dm = candidate.room_id

    target_room = exact_match if exact_match else (partial_match or first_dm)

    if not target_room:
        # Create DM room with admin
        created = await self.client.room_create(
            visibility="private",
            preset="trusted_private_chat",
            invite=[admin],
        )
        if hasattr(created, "room_id"):
            target_room = created.room_id
            logger.info("Created DM room with admin: %s", target_room)

    if not target_room:
        return

    room_link = f"https://matrix.to/#/{room_id}"
    notice = {
        "msgtype": "m.notice",
        "body": f"⚠️ Fallback (z.ai) used for room {room_link} (sender: {sender})",
    }
    await self.client.room_send(
        target_room, "m.room.message",
        notice,
        ignore_unverified_devices=True,
    )
|
||
|
||
async def _call_claude(self, room_id: str, safe_id: str, message: str,
                       sender: str = "", on_status=None, cancel_event=None,
                       idle_timeout_ref=None, **kwargs) -> str:
    """Call Claude CLI with typing indicator and status updates.

    Turns the room's typing indicator on, forwards the request to
    ``claude_send`` and always turns the indicator off again. When the
    response contains the fallback marker and a sender is known, an admin
    notification is scheduled in the background.

    Returns:
        The response text from ``claude_send``.
    """
    await self.client.room_typing(room_id, typing_state=True, timeout=30000)
    try:
        reply = await claude_send(
            self.config, safe_id, message,
            on_status=on_status, cancel_event=cancel_event,
            idle_timeout_ref=idle_timeout_ref,
            **kwargs,
        )
        # Check if fallback was used and notify owner
        used_fallback = "(via z.ai fallback)" in reply
        if used_fallback and sender:
            asyncio.create_task(self._notify_fallback_used(room_id, sender))
        return reply
    finally:
        await self.client.room_typing(room_id, typing_state=False)
|
||
|
||
# --- Bot commands ---
|
||
|
||
async def _handle_status(self, room: MatrixRoom) -> None:
|
||
"""Handle !status: show room/session info."""
|
||
safe_id = room.room_id.replace(":", "_").replace("!", "")
|
||
topic_dir = self._topic_dir(safe_id)
|
||
is_busy = room.room_id in self._active_sessions
|
||
lines = [f"**Status: {'working' if is_busy else 'idle'}**", f"Room: `{safe_id}`"]
|
||
|
||
# Session info
|
||
session_file = topic_dir / "session.txt"
|
||
if session_file.exists():
|
||
sid = session_file.read_text().strip()
|
||
lines.append(f"Session: `{sid[:12]}...`")
|
||
else:
|
||
lines.append("Session: new")
|
||
|
||
# Topic dir size
|
||
if topic_dir.exists():
|
||
total = sum(f.stat().st_size for f in topic_dir.rglob("*") if f.is_file())
|
||
files = sum(1 for f in topic_dir.rglob("*") if f.is_file())
|
||
if total < 1024:
|
||
size_str = f"{total} B"
|
||
elif total < 1024 * 1024:
|
||
size_str = f"{total // 1024} KB"
|
||
else:
|
||
size_str = f"{total // (1024 * 1024)} MB"
|
||
lines.append(f"Dir: {files} files, {size_str}")
|
||
|
||
# Interaction count from log
|
||
log_file = self._room_dir(room.room_id) / "log.jsonl"
|
||
if log_file.exists():
|
||
count = sum(1 for _ in open(log_file))
|
||
lines.append(f"Interactions: {count}")
|
||
|
||
# Auth info
|
||
if os.environ.get("CLAUDE_CODE_OAUTH_TOKEN"):
|
||
lines.append("Auth: `CLAUDE_CODE_OAUTH_TOKEN` (long-lived)")
|
||
else:
|
||
lines.append("Auth: OAuth credentials (short-lived)")
|
||
|
||
await self._send_response(room.room_id, "\n".join(lines))
|
||
|
||
async def _handle_help(self, room: MatrixRoom) -> None:
|
||
"""Show available commands."""
|
||
room_id = room.room_id
|
||
mode = self._get_room_mode(room_id)
|
||
await self._send_response(room_id,
|
||
f"**Commands:**\n"
|
||
f"`!new [topic]` — new conversation room\n"
|
||
f"`!mode [mode]` — set room mode (current: `{mode}`)\n"
|
||
f" `quiet` — transcribe voice only\n"
|
||
f" `context` — include recent history\n"
|
||
f" `full` — persistent session with full history\n"
|
||
f" `collect` — accumulate notes/images/voice, no replies\n"
|
||
f"`!stop` — stop active Claude session\n"
|
||
f"`!status` — bot status and active sessions\n"
|
||
f"`!security [mode]` — room security level\n"
|
||
f"`!claude-auth` — refresh OAuth token (admin, 1:1 only)\n"
|
||
f"`!help` — this message")
|
||
|
||
async def _handle_mode_command(self, room: MatrixRoom, args: str) -> None:
|
||
"""Handle !mode [quiet|context|full]: set or show room mode."""
|
||
room_id = room.room_id
|
||
mode = args.strip().lower()
|
||
if not mode:
|
||
current = self._get_room_mode(room_id)
|
||
await self._send_response(room_id,
|
||
f"**Mode:** `{current}`\n"
|
||
f"Available: `quiet` (transcribe only), `context` (recent history), "
|
||
f"`full` (persistent session), `collect` (accumulate context, no replies)")
|
||
return
|
||
if mode not in self.ROOM_MODES:
|
||
await self._send_response(room_id,
|
||
f"Unknown mode `{mode}`. Use: quiet, context, full, collect")
|
||
return
|
||
prev_mode = self._get_room_mode(room_id)
|
||
self._set_room_mode(room_id, mode)
|
||
|
||
# When leaving collect mode, summarize what was accumulated
|
||
if prev_mode == "collect" and mode != "collect":
|
||
summary = self._collect_summary(room_id)
|
||
if summary:
|
||
await self._send_response(room_id,
|
||
f"Mode set to `{mode}`\n\n{summary}")
|
||
# Store preamble for next Claude call
|
||
safe_id = room_id.replace(":", "_").replace("!", "")
|
||
self._collect_preambles[safe_id] = summary
|
||
else:
|
||
await self._send_response(room_id, f"Mode set to `{mode}`")
|
||
else:
|
||
await self._send_response(room_id, f"Mode set to `{mode}`")
|
||
|
||
def _collect_summary(self, room_id: str) -> str:
|
||
"""Summarize what was accumulated in collect mode."""
|
||
history_file = self._room_dir(room_id) / "history.jsonl"
|
||
if not history_file.exists():
|
||
return ""
|
||
images, voice, texts, files = 0, 0, 0, 0
|
||
try:
|
||
with open(history_file) as f:
|
||
for line in f:
|
||
line = line.strip()
|
||
if not line:
|
||
continue
|
||
msg = json.loads(line)
|
||
mtype = msg.get("type", "text")
|
||
sender = msg.get("sender", "")
|
||
if sender == self.client.user_id:
|
||
continue # skip bot messages
|
||
if mtype == "image":
|
||
images += 1
|
||
elif mtype == "audio":
|
||
voice += 1
|
||
elif mtype == "file":
|
||
files += 1
|
||
else:
|
||
texts += 1
|
||
except Exception:
|
||
return ""
|
||
parts = []
|
||
if images:
|
||
parts.append(f"{images} image(s)")
|
||
if voice:
|
||
parts.append(f"{voice} voice note(s)")
|
||
if texts:
|
||
parts.append(f"{texts} text message(s)")
|
||
if files:
|
||
parts.append(f"{files} file(s)")
|
||
if not parts:
|
||
return ""
|
||
return f"Accumulated: {', '.join(parts)}"
|
||
|
||
async def _handle_security_command(self, room: MatrixRoom, sender: str, args: str) -> None:
|
||
"""Handle !security [strict|guarded|open]: set or show room security mode."""
|
||
room_id = room.room_id
|
||
mode = args.strip().lower()
|
||
if not mode:
|
||
current = self._get_security_mode(room_id)
|
||
unverified = self._get_unverified_devices(room_id)
|
||
lines = [
|
||
f"**Security:** `{current}`",
|
||
"Available: `strict` (block all if unverified), "
|
||
"`guarded` (block unverified users), `open` (allow all + warning)",
|
||
]
|
||
if unverified:
|
||
lines.append(self._format_unverified_warning(unverified))
|
||
else:
|
||
lines.append("All devices in room are verified.")
|
||
await self._send_response(room_id, "\n".join(lines))
|
||
return
|
||
if mode not in self.SECURITY_MODES:
|
||
await self._send_response(room_id,
|
||
f"Unknown security mode `{mode}`. Use: strict, guarded, open")
|
||
return
|
||
# Loosening security requires fully verified sender
|
||
current = self._get_security_mode(room_id)
|
||
mode_rank = {"strict": 2, "guarded": 1, "open": 0}
|
||
if mode_rank[mode] < mode_rank[current]:
|
||
if not self._user_fully_verified(sender):
|
||
await self._send_response(room_id,
|
||
"Only users with all devices verified can loosen security.")
|
||
return
|
||
self._set_security_mode(room_id, mode)
|
||
await self._send_response(room_id, f"Security set to `{mode}`")
|
||
|
||
async def _handle_claude_auth_command(self, room: MatrixRoom, sender: str, args: str) -> None:
    """Handle !claude-auth command: refresh Claude Code OAuth token.

    Restricted to admin (MATRIX_ADMIN_MXID) in 1:1 rooms only.

    Flow:
    1. !claude-auth -> runs `claude setup-token` in tmux, extracts URL
    2. User opens URL, authenticates, copies token
    3. User pastes token here -> bot feeds it to tmux via send-keys
    4. `claude setup-token` finishes and writes credentials itself

    Args:
        room: Room the command was issued in.
        sender: User who issued the command; must equal ``self.admin_mxid``.
        args: Unused when starting a flow. When a flow is already registered
            for the room, this is the pasted token that is typed into tmux.
    """
    room_id = room.room_id

    # Admin-only, 1:1 rooms only (token must not leak to group chat history)
    if not self.admin_mxid or sender != self.admin_mxid:
        await self._send_response(room_id, "This command is admin-only.")
        return
    if self._is_group_room(room):
        await self._send_response(room_id, "This command only works in 1:1 rooms (token security).")
        return

    safe_id = room_id.replace(":", "_").replace("!", "")

    def _strip_ansi(raw: str) -> str:
        # Remove CSI sequences first, then any remaining escape sequence.
        text = re.sub(r'\x1b\[[0-9;]*[a-zA-Z]', '', raw)
        return re.sub(r'\x1b[^a-zA-Z]*[a-zA-Z]', '', text)

    # Phase 2: user pasted the token — feed it to tmux
    if safe_id in self._auth_flows:
        token = args.strip()
        flow = self._auth_flows.get(safe_id, {})
        tmux_session = flow.get("tmux_session")

        if not tmux_session:
            self._auth_flows.pop(safe_id, None)
            await self._send_response(room_id, "Auth flow lost its tmux session. Run `!claude-auth` again.")
            return

        try:
            # Feed token to claude setup-token via tmux.
            # The token is sent with -l (literal) after "--" so that it is
            # neither parsed as a tmux option when it starts with "-" nor
            # looked up as a key name; Enter is sent as a separate key.
            for key_args in (("-l", "--", token), ("Enter",)):
                proc = await asyncio.create_subprocess_exec(
                    "tmux", "send-keys", "-t", tmux_session, *key_args,
                    stdout=asyncio.subprocess.DEVNULL,
                    stderr=asyncio.subprocess.PIPE
                )
                _, stderr = await proc.communicate()
                if proc.returncode != 0:
                    self._auth_flows.pop(safe_id, None)
                    await self._send_response(room_id,
                        f"Failed to send token to tmux: {stderr.decode().strip()}\nRun `!claude-auth` again.")
                    return

            # Wait for setup-token to process and exit
            await self._send_response(room_id, "Token sent to `claude setup-token`, waiting for it to finish...")

            success = False
            for _ in range(15):
                await asyncio.sleep(1)
                # Check if tmux session still exists
                check = await asyncio.create_subprocess_exec(
                    "tmux", "has-session", "-t", tmux_session,
                    stdout=asyncio.subprocess.DEVNULL,
                    stderr=asyncio.subprocess.DEVNULL
                )
                await check.wait()
                if check.returncode != 0:
                    # Session exited — setup-token finished
                    success = True
                    break

                # Also check pane output for success/error messages
                cap = await asyncio.create_subprocess_exec(
                    "tmux", "capture-pane", "-t", tmux_session, "-p",
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.DEVNULL
                )
                stdout, _ = await cap.communicate()
                output = stdout.decode('utf-8', errors='replace').lower()
                if 'success' in output or 'saved' in output or 'authenticated' in output:
                    success = True
                    break
                if 'error' in output or 'invalid' in output or 'failed' in output:
                    clean = re.sub(r'\x1b\[[0-9;]*[a-zA-Z]', '', stdout.decode('utf-8', errors='replace'))
                    self._auth_flows.pop(safe_id, None)
                    await self._kill_tmux(tmux_session)
                    await self._send_response(room_id,
                        f"`claude setup-token` reported an error:\n```\n{clean.strip()[-500:]}\n```")
                    return

            self._auth_flows.pop(safe_id, None)

            # Capture pane output BEFORE killing tmux — it contains the long-lived token
            final_output = ""
            if success:
                cap = await asyncio.create_subprocess_exec(
                    "tmux", "capture-pane", "-t", tmux_session, "-p", "-S", "-100",
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.DEVNULL
                )
                stdout, _ = await cap.communicate()
                final_output = stdout.decode('utf-8', errors='replace')

            await self._kill_tmux(tmux_session)

            if success:
                # Extract long-lived token from setup-token output
                clean_output = _strip_ansi(final_output)
                oauth_token = self._extract_oauth_token(clean_output)

                if oauth_token:
                    # Try to save to deploy .env
                    saved = self._save_oauth_token_to_env(oauth_token)
                    if saved:
                        msg = "Long-lived token saved to deploy `.env`. Restart bot to apply."
                    else:
                        msg = (f"Token extracted. Set in deploy `.env` and restart:\n"
                               f"`CLAUDE_CODE_OAUTH_TOKEN={oauth_token}`")
                else:
                    msg = "Auth completed but could not extract long-lived token from output."

                # Also verify with claude auth status
                status_proc = await asyncio.create_subprocess_exec(
                    "claude", "auth", "status",
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE
                )
                status_out, _ = await status_proc.communicate()
                status_text = status_out.decode('utf-8', errors='replace').strip()

                await self._send_response(room_id,
                    f"{msg}\n\n```\n{status_text[:500]}\n```")
                logger.info("Claude auth flow completed for room %s (token saved: %s)",
                            room_id, bool(oauth_token))
            else:
                await self._send_response(room_id,
                    "`claude setup-token` didn't finish within 15s. "
                    "Check manually with `claude auth status`.")

        except Exception as e:
            self._auth_flows.pop(safe_id, None)
            await self._kill_tmux(tmux_session)
            logger.error("Error feeding token to tmux: %s", e)
            await self._send_response(room_id, f"Error: {e}")
        return

    # Phase 1: start claude setup-token in tmux, extract URL
    await self._send_response(room_id, "Starting Claude Code OAuth flow...")

    tmux_session = f"claude-auth-{safe_id[:20]}"

    try:
        # Kill any leftover session
        await self._kill_tmux(tmux_session)
        await asyncio.sleep(0.3)

        # Start claude setup-token in tmux
        proc = await asyncio.create_subprocess_exec(
            "tmux", "new-session", "-d", "-s", tmux_session,
            "-x", "200", "-y", "50",
            "claude", "setup-token"
        )
        await proc.wait()

        # Poll for the OAuth URL to appear
        output = ""
        for _ in range(15):
            await asyncio.sleep(1)

            cap = await asyncio.create_subprocess_exec(
                "tmux", "capture-pane", "-t", tmux_session, "-p",
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.DEVNULL
            )
            stdout, _ = await cap.communicate()
            output = stdout.decode('utf-8', errors='replace')

            if 'oauth/authorize' in output.lower() or 'console.anthropic.com' in output.lower():
                break

        # Strip ANSI escapes
        clean_output = _strip_ansi(output)

        # tmux wraps long URLs across lines — join continuation lines
        # Remove newlines that break mid-URL (lines not starting with whitespace
        # after a line ending with a URL-safe char)
        lines = clean_output.split('\n')
        joined = lines[0] if lines else ''
        for line in lines[1:]:
            stripped = line.strip()
            # If prev line ends with URL-safe char and this line looks like URL continuation
            if stripped and not stripped.startswith(('$', '#', '>', ' ')) and re.match(r'^[a-zA-Z0-9%&=_.~:/?#\[\]@!$\'()*+,;-]', stripped):
                # Check if we're likely in a URL context
                if joined.rstrip().endswith(tuple('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789%&=_.-~:/?#[]@!$\'()*+,;')):
                    joined += stripped
                    continue
            joined += '\n' + line
        clean_output = joined

        # Extract URL
        url_match = re.search(r'(https://[^\s]*(?:oauth/authorize|console\.anthropic\.com)[^\s]*)', clean_output)

        if not url_match:
            await self._kill_tmux(tmux_session)
            await self._send_response(room_id,
                "Could not extract auth URL from `claude setup-token`.\n"
                f"```\n{clean_output.strip()[:500]}\n```")
            logger.warning("claude setup-token output: %s", clean_output)
            return

        auth_url = url_match.group(1)

        # Register auth flow
        flow_entry = {
            "tmux_session": tmux_session,
            "started": time.time()
        }
        self._auth_flows[safe_id] = flow_entry

        await self._send_response(room_id,
            "**Claude Code Authentication**\n\n"
            f"1. Open: {auth_url}\n\n"
            "2. Authenticate and copy the token from the page\n\n"
            "3. Paste it here\n\n"
            "Flow expires in 5 minutes."
        )

        # Timeout cleanup
        async def _auth_cleanup():
            await asyncio.sleep(300)
            # Expire only the flow this timer was started for. A flow that was
            # restarted in the meantime is registered under the same room key
            # and tmux session name, and must not be killed by the old timer.
            if self._auth_flows.get(safe_id) is flow_entry:
                self._auth_flows.pop(safe_id, None)
                await self._kill_tmux(flow_entry.get("tmux_session"))
                await self._send_response(room_id, "Auth flow expired. Run `!claude-auth` to restart.")

        asyncio.create_task(_auth_cleanup())

    except Exception as e:
        await self._kill_tmux(tmux_session)
        logger.error("Error starting claude setup-token: %s", e)
        await self._send_response(room_id, f"Error: {e}")
|
||
|
||
async def _kill_tmux(self, session: str | None) -> None:
|
||
"""Kill a tmux session if it exists."""
|
||
if not session:
|
||
return
|
||
proc = await asyncio.create_subprocess_exec(
|
||
"tmux", "kill-session", "-t", session,
|
||
stdout=asyncio.subprocess.DEVNULL,
|
||
stderr=asyncio.subprocess.DEVNULL
|
||
)
|
||
await proc.wait()
|
||
|
||
@staticmethod
|
||
def _extract_oauth_token(text: str) -> str | None:
|
||
"""Extract CLAUDE_CODE_OAUTH_TOKEN from setup-token output."""
|
||
# Look for the token after "export CLAUDE_CODE_OAUTH_TOKEN=" or similar
|
||
m = re.search(r'CLAUDE_CODE_OAUTH_TOKEN[=\s]+([a-zA-Z0-9_\-]+)', text)
|
||
if m:
|
||
return m.group(1)
|
||
# Fallback: look for sk-ant-oat pattern (setup-token format)
|
||
m = re.search(r'(sk-ant-oat[a-zA-Z0-9_\-]+)', text)
|
||
if m:
|
||
return m.group(1)
|
||
return None
|
||
|
||
def _save_oauth_token_to_env(self, token: str) -> bool:
|
||
"""Save CLAUDE_CODE_OAUTH_TOKEN to workspace .env file."""
|
||
if not self.config.workspace_dir:
|
||
return False
|
||
env_path = Path(self.config.workspace_dir) / ".env"
|
||
try:
|
||
content = env_path.read_text() if env_path.exists() else ""
|
||
if "CLAUDE_CODE_OAUTH_TOKEN=" in content:
|
||
content = re.sub(
|
||
r'CLAUDE_CODE_OAUTH_TOKEN=.*',
|
||
f'CLAUDE_CODE_OAUTH_TOKEN={token}',
|
||
content
|
||
)
|
||
else:
|
||
content = content.rstrip('\n') + f'\nCLAUDE_CODE_OAUTH_TOKEN={token}\n'
|
||
env_path.write_text(content)
|
||
os.chmod(env_path, 0o600)
|
||
logger.info("Saved CLAUDE_CODE_OAUTH_TOKEN to %s", env_path)
|
||
return True
|
||
except Exception as e:
|
||
logger.error("Failed to save token to %s: %s", env_path, e)
|
||
return False
|
||
|
||
async def _handle_new_command(self, room: MatrixRoom, event_sender: str, topic: str) -> None:
|
||
"""Handle !new command: create a new conversation room and invite user."""
|
||
room_id = room.room_id
|
||
name = topic.strip() if topic.strip() else f"{self._default_room_prefix}Новый чат"
|
||
|
||
new_room_id = await self._create_conversation_room(name, for_user=event_sender)
|
||
if not new_room_id:
|
||
await self._send_response(room_id, "Failed to create room.")
|
||
return
|
||
|
||
room_link = f"https://matrix.to/#/{new_room_id}"
|
||
display_name = name.removeprefix(self._default_room_prefix)
|
||
await self.client.room_send(
|
||
room_id, "m.room.message",
|
||
{
|
||
"msgtype": "m.text",
|
||
"body": f"{display_name}: {room_link}",
|
||
"format": "org.matrix.custom.html",
|
||
"formatted_body": f"<a href='{room_link}'>{display_name}</a>",
|
||
},
|
||
ignore_unverified_devices=True,
|
||
)
|
||
logger.info("Created /new room %s: %s", new_room_id, name)
|
||
|
||
# --- Message handlers ---
|
||
|
||
async def _handle_text(self, room: MatrixRoom, event: RoomMessageText) -> None:
    """Dispatch an incoming text message.

    Steps are tried in order and each returns early once it has consumed the
    message:

    1. access check (1:1 rooms accept only allowed users);
    2. thread replies / stop-and-extend commands for an active session;
    3. delivery of an answer to a pending question in this room;
    4. token or cancel input for a running auth flow (allowed users only);
    5. ``!`` bot commands (allowed users only);
    6. group-room mention gating, collect-mode capture, busy-room queueing
       and the security check;
    7. otherwise the text, with mode-dependent context, is sent to Claude.
    """
    is_group = self._is_group_room(room)
    sender_allowed = self._is_allowed_user(event.sender)

    # 1:1 rooms: only owner can use the bot
    # Group rooms: anyone can mention the bot
    if not is_group and not sender_allowed:
        return

    user_text = event.body
    room_id = room.room_id
    safe_id = room_id.replace(":", "_").replace("!", "")

    # Check if this is a session command — thread reply or !command while busy
    session = self._active_sessions.get(room_id)
    if session:
        relates_to = event.source.get("content", {}).get("m.relates_to", {})
        is_thread = relates_to.get("rel_type") == "m.thread"
        is_bang_cmd = user_text.strip().lower().lstrip("!") in (
            "stop", "cancel", "abort", "+5m", "+5", "+10m", "+10",
        )
        if is_thread or is_bang_cmd:
            if await self._handle_thread_command(room_id, user_text, session):
                return

    # Strip mention prefix (e.g. "Bot: !status" → "!status")
    command_text = self._strip_mention_prefix(user_text)

    # If Claude is waiting for an answer in this room, deliver it
    if safe_id in self._pending_questions:
        future = self._pending_questions.pop(safe_id)
        if not future.done():
            future.set_result(user_text)
        return

    # Check if we're in an auth flow for this room. Only allowed users may
    # feed it: the flow can only be started by an allowed user, so a token
    # or "cancel" from anyone else in a group room must not be intercepted.
    if safe_id in self._auth_flows and sender_allowed:
        # Only intercept if it looks like a token (long, no spaces, no command prefix)
        candidate = user_text.strip()
        if len(candidate) > 20 and ' ' not in candidate and not candidate.startswith('!'):
            # Redact the token message from chat history
            try:
                await self.client.room_redact(room_id, event.event_id, reason="auth token")
            except Exception as e:
                # best-effort, E2E rooms may not support redaction
                logger.debug("Could not redact auth token message in %s: %s", room_id, e)
            await self._handle_claude_auth_command(room, event.sender, user_text)
            return
        # If it looks like a command or normal message, check for !claude-auth cancel
        if candidate.lower() in ('!cancel', '!claude-auth cancel', 'cancel'):
            flow = self._auth_flows.pop(safe_id, {})
            await self._kill_tmux(flow.get("tmux_session"))
            await self._send_response(room_id, "Auth flow cancelled.")
            return
        # Fall through to normal message handling

    def _is_command(text: str, name: str) -> bool:
        # True for "<name>" alone or "<name> <args>", but not "<name>xyz":
        # a plain prefix test would treat "!newsletter" as "!new sletter".
        rest = text[len(name):]
        return text.startswith(name) and (not rest or rest[0].isspace())

    # Bot commands — only allowed users
    if sender_allowed:
        bare_command = command_text.strip()
        if bare_command in ("!help", "!commands", "!?"):
            await self._handle_help(room)
            return
        if _is_command(command_text, "!new"):
            topic = command_text[4:].strip()
            await self._handle_new_command(room, event.sender, topic)
            return
        if bare_command == "!status":
            await self._handle_status(room)
            return
        if _is_command(command_text, "!mode"):
            await self._handle_mode_command(room, command_text[5:])
            return
        if _is_command(command_text, "!security"):
            await self._handle_security_command(room, event.sender, command_text[9:])
            return
        if bare_command in ("!claude-auth", "!claudeauth"):
            await self._handle_claude_auth_command(room, event.sender, "")
            return

    mode = self._get_room_mode(room_id)

    # Group rooms: only respond when mentioned (quiet/context modes)
    if is_group and mode not in ("full", "collect"):
        logger.info("Group room %s (members=%d), checking mention", room_id, room.member_count)
        if not self._is_bot_mentioned(event):
            logger.info("Not mentioned in group room, skipping")
            return

    # Collect mode: save to history, acknowledge, no Claude
    if mode == "collect":
        self._save_room_message(room_id, event.sender, "text", user_text)
        return

    # Check if already processing in this room — queue if busy
    if self._is_room_busy(room_id):
        self._enqueue_message(room_id, event.event_id, event.sender, user_text)
        return

    # Security check — after mention check, before Claude interaction
    allowed, security_msg = await self._check_security(room_id, event.sender)
    if not allowed:
        await self._send_response(room_id, security_msg)
        return

    # In full mode, save every message to room history
    if mode == "full":
        self._save_room_message(room_id, event.sender, "text", user_text)

    # Build message for Claude
    message_for_claude = user_text
    if mode == "context":
        recent = await self._fetch_recent_messages(room_id, limit=10)
        if recent:
            context_lines = [f"{m['sender']}: {m['text']}" for m in recent]
            context_block = "\n".join(context_lines)
            message_for_claude = (
                "[Recent room messages for context]\n"
                f"{context_block}\n\n---\n\n{user_text}"
            )
    elif mode == "full":
        context = self._get_room_context(room_id)
        if context:
            message_for_claude = context + "\n\n---\n\n" + user_text

    # Inject collect mode preamble if switching from collect
    preamble = self._collect_preambles.pop(safe_id, "")
    if preamble:
        message_for_claude = (
            "[CONTEXT UPDATE: User just switched from COLLECT mode. "
            "New material was accumulated in this room's history — images, voice notes, "
            "and/or text that you haven't seen yet. Review the conversation history above carefully, "
            "especially entries with [image:] paths (use Read tool to view them) "
            "and voice transcriptions. Process all accumulated material before responding.]\n\n"
            + message_for_claude
        )

    async def _on_done(response: str):
        # Runs after the session finishes: drop any stale pending question,
        # record the reply in full mode, then rename the room and log.
        self._pending_questions.pop(safe_id, None)
        if mode == "full":
            self._save_room_message(room_id, self.client.user_id, "text", response)
        await self._rename_room(room_id, safe_id, user_text=user_text, response=response)
        self._log_interaction(room_id, user_text, response)

    await self._run_claude_session(
        room, event, message_for_claude,
        security_msg=security_msg, on_done=_on_done,
    )
|
||
|
||
async def _handle_image(self, room: MatrixRoom, event) -> None:
    """Store an incoming image and, unless collecting, pass it to Claude.

    Only allowed users are served. In group rooms images are handled only in
    ``full`` and ``collect`` modes. The image is always downloaded into the
    room's ``images`` directory; in collect mode it is only recorded in the
    room history, and when the room is busy it is queued instead of starting
    a new Claude session.
    """
    if not self._is_allowed_user(event.sender):
        return
    mode = self._get_room_mode(room.room_id)
    if self._is_group_room(room) and mode not in ("full", "collect"):
        return

    room_id = room.room_id
    safe_id = room_id.replace(":", "_").replace("!", "")

    # Download and save image regardless of mode
    images_dir = self._room_dir(room_id) / "images"
    images_dir.mkdir(parents=True, exist_ok=True)

    data = await self._download_media(event)
    if data is None:
        return

    ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    # event.body is chosen by the sender (and may be a free-text caption):
    # neutralise path separators and control characters so the name is
    # always a single, writable path component inside images_dir.
    clean_name = re.sub(r"[\\/\x00-\x1f]", "_", event.body or "image")
    filename = f"{ts}_{clean_name}"
    # Case-insensitive so "PHOTO.JPG" is not renamed to "PHOTO.JPG.jpg".
    if not filename.lower().endswith((".jpg", ".jpeg", ".png", ".webp", ".gif")):
        filename += ".jpg"
    filepath = images_dir / filename
    with open(filepath, "wb") as f:
        f.write(data)

    caption = event.body if event.body and event.body != "image" else ""

    # Collect mode: save to history, no Claude
    if mode == "collect":
        history_text = f"[image: {filepath}]"
        if caption:
            history_text += f" {caption}"
        self._save_room_message(room_id, event.sender, "image", history_text, file_path=str(filepath))
        return

    # Security check
    allowed, security_msg = await self._check_security(room_id, event.sender)
    if not allowed:
        await self._send_response(room_id, security_msg)
        return

    message = f"User sent an image: {filepath}"
    if caption:
        message += f"\nCaption: {caption}"

    if self._is_room_busy(room_id):
        history_text = f"[image: {filepath}]"
        if caption:
            history_text += f" {caption}"
        self._enqueue_message(room_id, event.event_id, event.sender,
                              history_text, msg_type="image", file_path=str(filepath))
        return

    async def _on_done(response: str):
        # After the session: rename the room and log the exchange.
        await self._rename_room(room_id, safe_id, user_text=message, response=response)
        self._log_interaction(room_id, f"[image] {event.body}", response)

    await self._run_claude_session(
        room, event, message, security_msg=security_msg, on_done=_on_done,
    )
|
||
|
||
async def _handle_audio(self, room: MatrixRoom, event) -> None:
    """Store and transcribe a voice/audio message, then optionally answer.

    In 1:1 rooms only allowed users are served; in group rooms any sender's
    audio is saved and transcribed. A successful transcription is posted back
    to the room as a notice. Claude is invoked in 1:1 rooms always, and in
    group rooms only when the transcription mentions the bot; collect mode
    never invokes Claude.
    """
    is_group = self._is_group_room(room)
    if not is_group and not self._is_allowed_user(event.sender):
        return

    room_id = room.room_id
    safe_id = room_id.replace(":", "_").replace("!", "")
    mode = self._get_room_mode(room_id)
    voice_dir = self._room_dir(room_id) / "voice"
    voice_dir.mkdir(parents=True, exist_ok=True)

    data = await self._download_media(event)
    if data is None:
        return

    ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    # event.body is sender-controlled: neutralise path separators and
    # control characters so the name stays a single component in voice_dir.
    clean_name = re.sub(r"[\\/\x00-\x1f]", "_", event.body or "voice.ogg")
    filename = f"{ts}_{clean_name}"
    filepath = voice_dir / filename
    with open(filepath, "wb") as f:
        f.write(data)

    # Transcribe
    transcribed_text = None
    engine_tag = ""
    if self.config.stt_url:
        try:
            transcribed_text, engine_tag = await transcribe(
                str(filepath), self.config.stt_url,
                whisper_url=os.environ.get("STT_SHORT_URL"),
            )
            # "or ''" guards the log line against an empty/None transcription.
            logger.info("Transcribed voice in room %s: %d chars [%s]",
                        room.display_name, len(transcribed_text or ""), engine_tag)
        except RuntimeError as e:
            logger.error("ASR failed for room %s: %s", room.display_name, e)

    # Post transcription with sender attribution + engine tag
    if transcribed_text:
        sender_name = self._sender_display_name(room, event.sender)
        notice = f"🎙 {sender_name}: {transcribed_text}"
        if engine_tag and os.environ.get("STT_SHORT_URL"):
            notice += f" // {engine_tag}"
        await self.client.room_send(
            room_id, "m.room.message",
            {"msgtype": "m.notice", "body": notice},
            ignore_unverified_devices=True,
        )

    # Save to history in full/collect modes
    if mode in ("full", "collect"):
        history_text = transcribed_text or f"[audio: {filepath}]"
        self._save_room_message(room_id, event.sender, "audio", history_text, file_path=str(filepath))

    # Collect mode: transcribe and save, no Claude
    if mode == "collect":
        return

    # Decide whether to respond via Claude
    should_respond = not is_group  # always respond in 1:1
    if is_group and transcribed_text and self._text_mentions_bot(transcribed_text):
        should_respond = True
    if not should_respond:
        return

    if self._is_room_busy(room_id):
        queue_text = transcribed_text or f"[audio: {filepath}]"
        self._enqueue_message(room_id, event.event_id, event.sender,
                              queue_text, msg_type="audio", file_path=str(filepath))
        return

    # Security check — before Claude interaction
    allowed, security_msg = await self._check_security(room_id, event.sender)
    if not allowed:
        await self._send_response(room_id, security_msg)
        return

    # Build message for Claude
    if transcribed_text:
        message = f"[voice message transcription]: {transcribed_text}"
    else:
        message = f"User sent a voice message: {filepath}"

    if mode == "context":
        recent = await self._fetch_recent_messages(room_id, limit=10)
        if recent:
            context_lines = [f"{m['sender']}: {m['text']}" for m in recent]
            context_block = "\n".join(context_lines)
            message = f"[Recent room messages for context]\n{context_block}\n\n---\n\n{message}"

    async def _on_done(response: str):
        # After the session: record the reply in full mode, rename, log.
        if mode == "full":
            self._save_room_message(room_id, self.client.user_id, "text", response)
        await self._rename_room(room_id, safe_id, user_text=message, response=response)
        self._log_interaction(room_id, message, response)

    await self._run_claude_session(
        room, event, message, security_msg=security_msg, on_done=_on_done,
    )
|
||
|
||
async def _handle_file(self, room: MatrixRoom, event) -> None:
    """Store an uploaded document and, unless collecting, pass it to Claude.

    Only allowed users are served. In group rooms files are handled only in
    ``full`` and ``collect`` modes. The file is always downloaded into the
    room's ``documents`` directory; in collect mode it is only recorded in
    the room history, and when the room is busy it is queued.
    """
    if not self._is_allowed_user(event.sender):
        return
    mode = self._get_room_mode(room.room_id)
    if self._is_group_room(room) and mode not in ("full", "collect"):
        return

    room_id = room.room_id
    safe_id = room_id.replace(":", "_").replace("!", "")

    # Download and save file regardless of mode
    docs_dir = self._room_dir(room_id) / "documents"
    docs_dir.mkdir(parents=True, exist_ok=True)

    data = await self._download_media(event)
    if data is None:
        return

    ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    orig_name = event.body or "document"
    # orig_name is sender-controlled: neutralise path separators and control
    # characters for the on-disk name only (messages keep the original name).
    clean_name = re.sub(r"[\\/\x00-\x1f]", "_", orig_name)
    filename = f"{ts}_{clean_name}"
    filepath = docs_dir / filename
    with open(filepath, "wb") as f:
        f.write(data)

    # Collect mode: save to history, no Claude
    if mode == "collect":
        self._save_room_message(room_id, event.sender, "file",
                                f"[file: {orig_name}]", file_path=str(filepath))
        return

    # Security check
    allowed, security_msg = await self._check_security(room_id, event.sender)
    if not allowed:
        await self._send_response(room_id, security_msg)
        return

    message = f"User sent a document: {filepath} (name: {orig_name}, size: {len(data)} bytes)"

    if self._is_room_busy(room_id):
        self._enqueue_message(room_id, event.event_id, event.sender,
                              f"[file: {orig_name}]", msg_type="file", file_path=str(filepath))
        return

    async def _on_done(response: str):
        # After the session: rename the room and log the exchange.
        await self._rename_room(room_id, safe_id, user_text=message, response=response)
        self._log_interaction(room_id, f"[document: {orig_name}]", response)

    await self._run_claude_session(
        room, event, message, security_msg=security_msg, on_done=_on_done,
    )
|
||
|
||
# --- E2E cross-signing & trust ---
|
||
|
||
async def _setup_cross_signing(self) -> None:
    """Generate cross-signing keys (or load existing) and self-sign device.

    Seeds for the master, self-signing and user-signing keys are loaded from
    ``crypto_store/cross_signing_seeds.json`` under the data directory, or
    generated and stored there. If the server has no master key for this
    user the three public keys are uploaded; afterwards the bot's own device
    keys are signed with the self-signing key unless that signature already
    exists. Does nothing when E2E (``self.client.olm``) is unavailable.
    HTTP and response-decoding errors are logged instead of propagating.
    """
    if not self.client.olm:
        return
    import base64
    import olm as _olm

    seeds_path = self.config.data_dir / "crypto_store" / "cross_signing_seeds.json"

    # Load or generate seeds
    if seeds_path.exists():
        seeds = json.loads(seeds_path.read_text())
        master_seed = base64.b64decode(seeds["master_seed"])
        self_signing_seed = base64.b64decode(seeds["self_signing_seed"])
        user_signing_seed = base64.b64decode(seeds["user_signing_seed"])
    else:
        master_seed = _olm.PkSigning.generate_seed()
        self_signing_seed = _olm.PkSigning.generate_seed()
        user_signing_seed = _olm.PkSigning.generate_seed()
        seeds_path.parent.mkdir(parents=True, exist_ok=True)
        # The seeds are private key material: create the file owner-only
        # before any secret bytes are written into it.
        seeds_path.touch(mode=0o600)
        seeds_path.write_text(json.dumps({
            "master_seed": base64.b64encode(master_seed).decode(),
            "self_signing_seed": base64.b64encode(self_signing_seed).decode(),
            "user_signing_seed": base64.b64encode(user_signing_seed).decode(),
        }))

    master = _olm.PkSigning(master_seed)
    self_signing = _olm.PkSigning(self_signing_seed)
    user_signing = _olm.PkSigning(user_signing_seed)  # also validates the seed

    def _canonical(obj):
        # Compact, key-sorted JSON used as the signing input.
        return json.dumps(obj, separators=(",", ":"), sort_keys=True, ensure_ascii=False)

    def _sign(obj, key_id, signing_key):
        # Sign everything except "signatures"/"unsigned" and attach the
        # result under signatures[<our user id>][key_id].
        to_sign = {k: v for k, v in obj.items() if k not in ("signatures", "unsigned")}
        sig = signing_key.sign(_canonical(to_sign))
        obj.setdefault("signatures", {}).setdefault(self.client.user_id, {})[key_id] = sig

    user_id = self.client.user_id
    hs = self.client.homeserver
    device_id = self.client.device_id
    headers = {"Authorization": f"Bearer {self.client.access_token}",
               "Content-Type": "application/json"}

    try:
        async with httpx.AsyncClient() as http:
            # Check if already uploaded
            resp = await http.post(f"{hs}/_matrix/client/v3/keys/query",
                                   headers=headers, json={"device_keys": {user_id: []}}, timeout=10)
            existing = resp.json().get("master_keys", {}).get(user_id)
            if existing:
                logger.info("Cross-signing keys already uploaded")
            else:
                # Build and upload cross-signing keys
                master_key = {"user_id": user_id, "usage": ["master"],
                              "keys": {f"ed25519:{master.public_key}": master.public_key}}
                self_signing_key = {"user_id": user_id, "usage": ["self_signing"],
                                    "keys": {f"ed25519:{self_signing.public_key}": self_signing.public_key}}
                user_signing_key_obj = {"user_id": user_id, "usage": ["user_signing"],
                                        "keys": {f"ed25519:{user_signing.public_key}":
                                                 user_signing.public_key}}
                _sign(self_signing_key, f"ed25519:{master.public_key}", master)
                _sign(user_signing_key_obj, f"ed25519:{master.public_key}", master)
                upload_body = {"master_key": master_key,
                               "self_signing_key": self_signing_key,
                               "user_signing_key": user_signing_key_obj}
                upload_url = f"{hs}/_matrix/client/v3/keys/device_signing/upload"
                resp = await http.post(upload_url, headers=headers, timeout=10,
                                       json=upload_body)
                if resp.status_code == 401:
                    # The server asked for interactive auth: retry once with
                    # the returned session and the dummy login stage.
                    session = resp.json().get("session", "")
                    resp = await http.post(
                        upload_url, headers=headers, timeout=10,
                        json={**upload_body,
                              "auth": {"type": "m.login.dummy", "session": session}})
                if resp.status_code == 200:
                    logger.info("Uploaded cross-signing keys")
                else:
                    logger.error("Failed to upload cross-signing keys (%d): %s",
                                 resp.status_code, resp.text[:200])
                    return

            # Self-sign our device with self-signing key
            resp = await http.post(f"{hs}/_matrix/client/v3/keys/query",
                                   headers=headers, json={"device_keys": {user_id: []}}, timeout=10)
            # .get() chain: an error body without "device_keys" must end in
            # the "not found" branch rather than a KeyError.
            device_keys = resp.json().get("device_keys", {}).get(user_id, {}).get(device_id)
            if not device_keys:
                logger.error("Own device keys not found on server")
                return

            # Check if already signed by self-signing key
            existing_sigs = device_keys.get("signatures", {}).get(user_id, {})
            ss_key_id = f"ed25519:{self_signing.public_key}"
            if ss_key_id in existing_sigs:
                logger.info("Device already self-signed")
                return

            to_sign = {k: v for k, v in device_keys.items() if k not in ("signatures", "unsigned")}
            sig = self_signing.sign(_canonical(to_sign))
            sig_body = {user_id: {device_id: {
                **to_sign,
                "signatures": {user_id: {ss_key_id: sig}},
            }}}
            resp = await http.post(f"{hs}/_matrix/client/v3/keys/signatures/upload",
                                   headers=headers, json=sig_body, timeout=10)
            if resp.status_code == 200:
                logger.info("Self-signed device %s", device_id)
            else:
                logger.error("Failed to self-sign device (%d): %s",
                             resp.status_code, resp.text[:200])
    except (httpx.HTTPError, ValueError) as e:
        # Network failure or a non-JSON response: log and carry on, so the
        # caller is not aborted by a cross-signing problem.
        logger.error("Cross-signing setup failed: %s", e)
|
||
|
||
async def _sync_cross_signing_trust(self) -> None:
    """Query server for cross-signing keys and trust devices signed by self-signing keys.

    This bridges the gap between server-side cross-signing verification
    (what Element shows as green/red) and nio's local device trust store.
    A device is considered verified if it's signed by its owner's self-signing key.

    NOTE(review): only the presence of a signature entry under the
    self-signing key id is checked; the signature itself is not verified
    cryptographically here.
    """
    if not self.client.olm:
        return
    hs = self.client.homeserver
    headers = {"Authorization": f"Bearer {self.client.access_token}",
               "Content-Type": "application/json"}

    # Collect all user IDs we care about
    user_ids = set(self._users.keys())
    if not user_ids:
        return

    try:
        async with httpx.AsyncClient() as http:
            resp = await http.post(
                f"{hs}/_matrix/client/v3/keys/query",
                headers=headers,
                json={"device_keys": {uid: [] for uid in user_ids}},
                timeout=10,
            )
            if resp.status_code != 200:
                logger.warning("Cross-signing trust sync failed (%d)", resp.status_code)
                return
            data = resp.json()
    except Exception as e:
        logger.warning("Cross-signing trust sync error: %s", e)
        return

    # For each user, find their self-signing key
    for user_id in user_ids:
        ss_key_obj = data.get("self_signing_keys", {}).get(user_id)
        if not ss_key_obj:
            continue
        # Signatures are indexed by key id (e.g. "ed25519:ABCDEF..."), so
        # the id — not the key value — is what gets looked up below.
        ss_key_id = next(
            (key_id for key_id in ss_key_obj.get("keys", {}) if key_id.startswith("ed25519:")),
            None,
        )
        if not ss_key_id:
            continue

        user_devices = data.get("device_keys", {}).get(user_id, {})
        if not user_devices:
            continue

        # Index nio's local store once per user instead of rescanning it
        # for every device returned by the server.
        local_devices = {}
        for d in self.client.device_store.active_user_devices(user_id):
            local_devices.setdefault(d.id, d)

        # Check each device: is it signed by the self-signing key?
        for device_id, dev_keys in user_devices.items():
            sigs = dev_keys.get("signatures", {}).get(user_id, {})
            is_cross_signed = ss_key_id in sigs

            nio_device = local_devices.get(device_id)
            if nio_device is None:
                continue

            if is_cross_signed and not nio_device.verified:
                self.client.verify_device(nio_device)
                logger.info("Trusted cross-signed device %s of %s", device_id, user_id)
            elif not is_cross_signed and nio_device.verified:
                # Device lost cross-signing — untrust it
                # (nio has no unverify, but we can note it)
                logger.warning("Device %s of %s no longer cross-signed", device_id, user_id)

    logger.info("Cross-signing trust sync complete")
|
||
|
||
# --- Auto-join and room locking ---
|
||
|
||
async def _auto_join_invites(self) -> None:
    """Try to join every room the bot is currently invited to.

    A failed join is logged and skipped so that one bad invite does not stop
    the remaining invites from being processed.
    """
    # Snapshot the keys: the invite mapping may change while we await.
    for room_id in list(self.client.invited_rooms):
        try:
            resp = await self.client.join(room_id)
        except Exception as e:
            logger.error("Failed to join room %s: %s", room_id, e)
            continue
        # Same error test as used for other client responses in this file:
        # a response carrying ``message`` is treated as an error.
        if hasattr(resp, "message"):
            logger.error("Failed to join room %s: %s", room_id, resp.message)
            continue
        logger.info("Accepted invite to room %s", room_id)
|
||
|
||
def _load_sync_token(self) -> str | None:
|
||
if self._sync_token_path.exists():
|
||
token = self._sync_token_path.read_text().strip()
|
||
return token if token else None
|
||
return None
|
||
|
||
def _save_sync_token(self, token: str) -> None:
|
||
self._sync_token_path.parent.mkdir(parents=True, exist_ok=True)
|
||
self._sync_token_path.write_text(token)
|
||
|
||
async def run(self) -> None:
    """Start the Matrix bot.

    Registers all event, response and to-device callbacks, performs an
    initial sync (resuming from a saved token when one exists), accepts
    pending invites, performs E2E setup when encryption is available, sets
    the avatar, marks the bot as synced and then syncs forever.
    """
    event_callbacks = (
        # Plain events
        (self._on_message, RoomMessageText),
        (self._on_image, RoomMessageImage),
        (self._on_audio, RoomMessageAudio),
        (self._on_file, RoomMessageFile),
        (self._on_member, RoomMemberEvent),
        # Encrypted events (nio auto-decrypts to RoomMessage* types above,
        # but encrypted media comes as RoomEncrypted* types)
        (self._on_image, RoomEncryptedImage),
        (self._on_audio, RoomEncryptedAudio),
        (self._on_file, RoomEncryptedFile),
        # Undecryptable events (missing keys)
        (self._on_megolm, MegolmEvent),
        # In-room verification events (Element X, FluffyChat)
        (self._on_room_verify_event, RoomMessageUnknown),
        (self._on_room_verify_event, UnknownEvent),
    )
    for callback, event_type in event_callbacks:
        self.client.add_event_callback(callback, event_type)

    self.client.add_response_callback(self._on_sync, SyncResponse)

    # SAS key verification (to-device events)
    to_device_callbacks = (
        (self._on_verify_start, KeyVerificationStart),
        (self._on_verify_key, KeyVerificationKey),
        (self._on_verify_mac, KeyVerificationMac),
        (self._on_verify_cancel, KeyVerificationCancel),
    )
    for callback, event_type in to_device_callbacks:
        self.client.add_to_device_callback(callback, event_type)

    logger.info("Matrix bot starting as %s", self.client.user_id)

    saved_token = self._load_sync_token()
    if saved_token:
        logger.info("Resuming from saved sync token")

    resp = await self.client.sync(timeout=10000, since=saved_token, full_state=True)
    if hasattr(resp, "next_batch") and resp.next_batch:
        self._save_sync_token(resp.next_batch)
    else:
        # Surface a failed initial sync instead of continuing silently.
        logger.warning("Initial sync returned no sync token: %s", resp)
    await self._auto_join_invites()
    # E2E setup: upload our keys, then fetch and trust other users' devices
    if self.client.olm:
        if self.client.should_upload_keys:
            await self.client.keys_upload()
            logger.info("Uploaded device keys to server")
        try:
            await self.client.keys_query()
        except Exception as e:
            # Expected when there are no keys to query yet (fresh user, no
            # rooms); logged so that other failures are not invisible.
            logger.debug("Skipping initial keys query: %s", e)
        # Note: we intentionally do NOT auto-trust all user devices here.
        # The security model (strict/guarded/open) handles unverified devices
        # per room. Devices are verified via in-room verification or cross-signing.
        await self._sync_cross_signing_trust()
        await self._setup_cross_signing()
    await self._set_bot_avatar()
    self._synced = True
    logger.info("Initial sync complete, E2E=%s, listening for new messages",
                "enabled" if self.client.olm else "disabled")

    await self.client.sync_forever(timeout=30000)
|
||
|
||
def _should_process(self, event, room: MatrixRoom | None = None) -> bool:
|
||
"""Check if event should be processed (not own, not old, not duplicate, after sync)."""
|
||
eid = event.event_id
|
||
room_id = room.room_id if room else "?"
|
||
logger.info("_should_process: eid=%s sender=%s room=%s ts=%s body=%s",
|
||
eid, event.sender, room_id, event.server_timestamp,
|
||
getattr(event, 'body', '')[:50])
|
||
if not self._synced:
|
||
return False
|
||
if event.sender == self.client.user_id:
|
||
return False
|
||
if eid in self._processed_events:
|
||
logger.warning("Duplicate event %s, skipping", eid)
|
||
return False
|
||
self._processed_events.add(eid)
|
||
# Keep set bounded
|
||
if len(self._processed_events) > 1000:
|
||
self._processed_events = set(list(self._processed_events)[-500:])
|
||
return True
|
||
|
||
async def _on_message(self, room: MatrixRoom, event: RoomMessageText) -> None:
    """Text-message callback: hand the event to ``_handle_text`` if it passes ``_should_process``."""
    if self._should_process(event, room):
        await self._handle_text(room, event)
|
||
|
||
async def _on_image(self, room: MatrixRoom, event) -> None:
    """Image callback: hand the event to ``_handle_image`` if it passes ``_should_process``."""
    if self._should_process(event, room):
        await self._handle_image(room, event)
|
||
|
||
async def _on_audio(self, room: MatrixRoom, event) -> None:
    """Audio callback: hand the event to ``_handle_audio`` if it passes ``_should_process``."""
    if self._should_process(event, room):
        await self._handle_audio(room, event)
|
||
|
||
async def _on_file(self, room: MatrixRoom, event) -> None:
    """File callback: hand the event to ``_handle_file`` if it passes ``_should_process``."""
    if self._should_process(event, room):
        await self._handle_file(room, event)
|
||
|
||
async def _on_megolm(self, room: MatrixRoom, event: MegolmEvent) -> None:
    """Log a warning for an event that could not be decrypted.

    Nothing is logged until the initial sync has completed.
    """
    if self._synced:
        logger.warning(
            "Could not decrypt event %s in %s from %s (session %s)",
            event.event_id,
            room.room_id,
            event.sender,
            event.session_id,
        )
|
||
|
||
# --- SAS key verification (auto-accept for allowed users) ---
|
||
|
||
async def _on_verify_start(self, event: KeyVerificationStart) -> None:
    """Accept an incoming SAS verification when it comes from an allowed user.

    Requests from anyone else are logged and ignored.
    """
    requester = event.sender
    if not self._is_allowed_user(requester):
        logger.warning("Verification from non-allowed user %s, ignoring", requester)
        return

    tx_id = event.transaction_id
    logger.info("Verification request from %s (tx=%s), auto-accepting",
                requester, tx_id)
    result = await self.client.accept_key_verification(tx_id)
    # A result carrying ``message`` is treated as an error response.
    if hasattr(result, "message"):
        logger.error("Failed to accept verification: %s", result.message)
|
||
|
||
async def _on_verify_key(self, event: KeyVerificationKey) -> None:
    """Key exchange done — emojis available. Auto-confirm (bot trusts allowed users).

    The short authentication string is confirmed automatically, so this
    handler applies the same allowed-user gate as ``_on_verify_start``:
    key events from any other sender are logged and ignored.
    """
    if not self._is_allowed_user(event.sender):
        logger.warning("Verification key from non-allowed user %s, ignoring", event.sender)
        return
    sas = self.client.key_verifications.get(event.transaction_id)
    if not sas:
        return
    emojis = sas.get_emoji()
    # Each entry is indexed as (symbol, description) for the log line.
    emoji_str = " ".join(f"{e[0]} ({e[1]})" for e in emojis)
    logger.info("Verification emojis for %s: %s", sas.other_olm_device.user_id, emoji_str)
    resp = await self.client.confirm_short_auth_string(event.transaction_id)
    # A response carrying ``message`` is treated as an error response.
    if hasattr(resp, "message"):
        logger.error("Failed to confirm SAS: %s", resp.message)
|
||
|
||
async def _on_verify_mac(self, event: KeyVerificationMac) -> None:
    """Log the outcome of a SAS verification once the MAC event arrives."""
    tx_id = event.transaction_id
    sas = self.client.key_verifications.get(tx_id)
    if not sas:
        return
    if not sas.verified:
        logger.warning("SAS verification failed for %s", tx_id)
        return
    peer = sas.other_olm_device
    logger.info("Device %s of %s verified via SAS", peer.id, peer.user_id)
|
||
|
||
async def _on_verify_cancel(self, event: KeyVerificationCancel) -> None:
    """Log that a verification was cancelled, with who cancelled it and why."""
    logger.info(
        "Verification %s canceled by %s: %s",
        event.transaction_id,
        event.sender,
        event.reason,
    )
|
||
|
||
# --- In-room verification (used by Element X, FluffyChat) ---
|
||
|
||
async def _on_room_verify_event(self, room: MatrixRoom, event) -> None:
    """Handle in-room verification events (m.key.verification.*).

    Runs an auto-confirming SAS (m.sas.v1) exchange with an allowed user:

    * request -> store a fresh olm SAS object, reply ``ready`` and then ``start``
    * accept / start -> send our SAS public key
    * key -> derive the SAS bytes (emojis are only logged) and send our MAC
    * mac -> send ``done`` and cross-sign the user's master key
    * cancel / done -> drop the stored transaction state

    Events are ignored until the initial sync has finished, when they come
    from the bot itself, or when the sender is not an allowed user.
    Follow-up events are only honoured when they come from the same user and
    room that opened the transaction.

    Args:
        room: Room the event was received in.
        event: Event object; only its ``source`` dict is inspected.
    """
    if not self._synced:
        return

    source = getattr(event, "source", {})
    content = source.get("content", {})
    event_type = source.get("type", "")
    sender = source.get("sender", "")
    event_id = source.get("event_id", "")
    logger.debug("Room event: type=%s sender=%s eid=%s keys=%s",
                 event_type, sender, event_id, list(content.keys()))

    # A verification request arrives as an m.room.message whose msgtype is
    # m.key.verification.request; every other room message is not ours.
    if event_type == "m.room.message":
        if content.get("msgtype", "") != "m.key.verification.request":
            return
        event_type = "m.key.verification.request"

    if not event_type.startswith("m.key.verification."):
        return
    if sender == self.client.user_id:
        return
    if not self._is_allowed_user(sender):
        return

    # Follow-up events reference the request via m.relates_to; for the
    # request itself the transaction id is its own event_id.
    tx_id = content.get("m.relates_to", {}).get("event_id", "")
    room_id = room.room_id
    logger.info("In-room verification: %s from %s (tx=%s)", event_type, sender, tx_id or event_id)

    async def send(kind: str, body: dict, tx: str) -> None:
        """Send one verification event that references transaction *tx*."""
        payload = dict(body)
        payload["m.relates_to"] = {"rel_type": "m.reference", "event_id": tx}
        await self.client.room_send(room_id, kind, payload, ignore_unverified_devices=True)

    if event_type == "m.key.verification.request":
        # The request names its recipient in "to"; do not hijack a
        # verification that two other room members run with each other.
        target = content.get("to")
        if target and target != self.client.user_id:
            logger.debug("Ignoring verification request addressed to %s", target)
            return

        tx_id = event_id  # the request event_id IS the transaction_id
        import olm as _olm

        self._room_verifications[tx_id] = {
            "sas": _olm.Sas(),
            "room_id": room_id,
            "sender": sender,
            "from_device": content.get("from_device", ""),
        }
        await send("m.key.verification.ready", {
            "from_device": self.client.device_id,
            "methods": ["m.sas.v1"],
        }, tx_id)
        logger.info("Sent verification ready for tx=%s", tx_id)

        # Send start immediately (bot always initiates SAS after ready)
        try:
            await send("m.key.verification.start", {
                "from_device": self.client.device_id,
                "method": "m.sas.v1",
                "key_agreement_protocols": ["curve25519-hkdf-sha256"],
                "hashes": ["sha256"],
                "message_authentication_codes": ["hkdf-hmac-sha256.v2"],
                "short_authentication_string": ["decimal", "emoji"],
            }, tx_id)
            logger.info("Sent verification start for tx=%s", tx_id)
        except Exception as e:
            logger.error("Failed to send verification start: %s", e)
        return

    state = self._room_verifications.get(tx_id)
    # Only the user (and room) that opened a transaction may advance, cancel
    # or finish it; otherwise another allowed user could complete someone
    # else's verification and trigger cross-signing.
    if state is not None and (state["sender"] != sender or state["room_id"] != room_id):
        logger.warning("Ignoring %s for tx=%s from %s: sender/room does not match the request",
                       event_type, tx_id, sender)
        return

    if event_type == "m.key.verification.cancel":
        logger.info("In-room verification %s canceled: %s", tx_id, content.get("reason", ""))
        self._room_verifications.pop(tx_id, None)
        return

    if event_type == "m.key.verification.done":
        logger.info("In-room verification %s done by %s", tx_id, sender)
        self._room_verifications.pop(tx_id, None)
        return

    if state is None:
        return
    sas = state["sas"]

    if event_type == "m.key.verification.accept":
        # NOTE(review): the commitment is stored but never checked against
        # the key received later — confirm whether that is intended.
        state["their_commitment"] = content.get("commitment", "")
        state["mac_method"] = content.get("message_authentication_code", "hkdf-hmac-sha256.v2")
        await send("m.key.verification.key", {"key": sas.pubkey}, tx_id)
        logger.info("Sent verification key for tx=%s", tx_id)

    elif event_type == "m.key.verification.start":
        await send("m.key.verification.key", {"key": sas.pubkey}, tx_id)
        logger.info("Sent verification key for tx=%s", tx_id)

    elif event_type == "m.key.verification.key":
        their_key = content.get("key", "")
        if not their_key:
            logger.warning("Verification key event for tx=%s carries no key, ignoring", tx_id)
            return
        sas.set_their_pubkey(their_key)

        # SAS info for curve25519-hkdf-sha256: fields are "|"-separated,
        # starter (the bot, which sent start) first, then the other side.
        sas_info = "|".join((
            "MATRIX_KEY_VERIFICATION_SAS",
            self.client.user_id, self.client.device_id, sas.pubkey,
            state["sender"], state["from_device"], their_key,
            tx_id,
        ))
        sas_bytes = sas.generate_bytes(sas_info, 6)
        state["sas_bytes"] = sas_bytes
        emojis = self._sas_to_emojis(sas_bytes)
        logger.info("Verification emojis for %s: %s", state["sender"],
                    " ".join(f"{e[0]}({e[1]})" for e in emojis))

        # Auto-confirm: calculate and send MAC for device key + master key.
        # The MAC info string is plain concatenation (no separators).
        mac_info_base = (
            "MATRIX_KEY_VERIFICATION_MAC"
            f"{self.client.user_id}{self.client.device_id}"
            f"{state['sender']}{state['from_device']}{tx_id}"
        )
        own_device_key_id = f"ed25519:{self.client.device_id}"
        own_ed25519 = self.client.olm.account.identity_keys["ed25519"]
        mac_dict = {
            own_device_key_id: sas.calculate_mac_fixed_base64(
                own_ed25519, mac_info_base + own_device_key_id),
        }
        key_ids = [own_device_key_id]

        # MAC master key (so other side can cross-sign our identity)
        seeds_path = self.config.data_dir / "crypto_store" / "cross_signing_seeds.json"
        if seeds_path.exists():
            import base64

            import olm as _olm

            seeds = json.loads(seeds_path.read_text())
            master_pubkey = _olm.PkSigning(base64.b64decode(seeds["master_seed"])).public_key
            master_key_id = f"ed25519:{master_pubkey}"
            mac_dict[master_key_id] = sas.calculate_mac_fixed_base64(
                master_pubkey, mac_info_base + master_key_id)
            key_ids.append(master_key_id)

        # KEY_IDS mac covers sorted comma-separated key ids
        keys_mac = sas.calculate_mac_fixed_base64(
            ",".join(sorted(key_ids)), mac_info_base + "KEY_IDS")
        await send("m.key.verification.mac", {"keys": keys_mac, "mac": mac_dict}, tx_id)
        logger.info("Sent verification MAC for tx=%s", tx_id)

    elif event_type == "m.key.verification.mac":
        # NOTE(review): the peer's MAC is accepted without being verified
        # before the user gets cross-signed — confirm this is acceptable.
        try:
            await send("m.key.verification.done", {}, tx_id)
            # Cross-sign the user's master key with our user-signing key
            await self._cross_sign_user(state["sender"])
            logger.info("Verification complete for tx=%s with %s", tx_id, state["sender"])
        finally:
            # Drop the transaction even when sending/cross-signing raised,
            # so failed attempts do not accumulate in memory.
            self._room_verifications.pop(tx_id, None)
|
||
|
||
async def _cross_sign_user(self, user_id: str) -> None:
    """Sign user's master key with our user-signing key after successful verification.

    Reads the user-signing seed from ``crypto_store/cross_signing_seeds.json``
    under the configured data dir, fetches the user's master key through
    ``/keys/query``, signs its canonical JSON and uploads the signature via
    ``/keys/signatures/upload``.

    The operation is best-effort: a missing seed file, a missing master key
    and HTTP/JSON failures are logged and the method returns without raising.

    Args:
        user_id: Matrix ID of the user whose master key should be signed.
    """
    import base64

    import olm as _olm

    seeds_path = self.config.data_dir / "crypto_store" / "cross_signing_seeds.json"
    if not seeds_path.exists():
        logger.warning("No cross-signing seeds, cannot cross-sign user")
        return

    seeds = json.loads(seeds_path.read_text())
    user_signing = _olm.PkSigning(base64.b64decode(seeds["user_signing_seed"]))

    hs = self.client.homeserver
    headers = {"Authorization": f"Bearer {self.client.access_token}",
               "Content-Type": "application/json"}

    try:
        async with httpx.AsyncClient() as http:
            # Get user's master key
            resp = await http.post(f"{hs}/_matrix/client/v3/keys/query",
                                   headers=headers,
                                   json={"device_keys": {user_id: []}}, timeout=10)
            if resp.status_code != 200:
                logger.error("Key query for %s failed (%d): %s",
                             user_id, resp.status_code, resp.text[:200])
                return
            master_key_obj = resp.json().get("master_keys", {}).get(user_id)
            if not master_key_obj:
                logger.warning("No master key found for %s", user_id)
                return

            master_keys = master_key_obj.get("keys") or {}
            if not master_keys:
                logger.warning("Master key of %s lists no keys", user_id)
                return
            # Key ids look like "ed25519:<public key>"; the upload body is
            # keyed by the part after the algorithm prefix.
            _, _, master_pub = next(iter(master_keys)).partition(":")

            # Sign the master key with our user-signing key. Signatures
            # cover the object without "signatures"/"unsigned".
            to_sign = {k: v for k, v in master_key_obj.items()
                       if k not in ("signatures", "unsigned")}
            canonical = json.dumps(to_sign, separators=(",", ":"),
                                   sort_keys=True, ensure_ascii=False)
            sig = user_signing.sign(canonical)
            us_key_id = f"ed25519:{user_signing.public_key}"

            sig_body = {user_id: {
                master_pub: {
                    **to_sign,
                    "signatures": {self.client.user_id: {us_key_id: sig}},
                }
            }}
            resp = await http.post(f"{hs}/_matrix/client/v3/keys/signatures/upload",
                                   headers=headers, json=sig_body, timeout=10)
            if resp.status_code != 200:
                logger.error("Failed to cross-sign %s (%d): %s",
                             user_id, resp.status_code, resp.text[:200])
                return
            # The endpoint answers 200 even when individual signatures were
            # rejected; those are reported under "failures".
            failures = resp.json().get("failures")
            if failures:
                logger.error("Failed to cross-sign %s: %s", user_id, failures)
                return
            logger.info("Cross-signed master key of %s", user_id)
    except (httpx.HTTPError, ValueError) as exc:
        # Network problems or a non-JSON body must not take down the
        # event callback that triggered the cross-signing.
        logger.error("Cross-signing %s failed: %s", user_id, exc)
|
||
|
||
@staticmethod
|
||
def _sas_to_emojis(sas_bytes: bytes) -> list[tuple[str, str]]:
|
||
"""Convert 6 SAS bytes to 7 emojis (per Matrix spec)."""
|
||
emoji_list = [
|
||
("🐶","Dog"),("🐱","Cat"),("🦁","Lion"),("🐴","Horse"),("🦄","Unicorn"),
|
||
("🐷","Pig"),("🐘","Elephant"),("🐰","Rabbit"),("🐼","Panda"),("🐔","Rooster"),
|
||
("🐧","Penguin"),("🐢","Turtle"),("🐟","Fish"),("🐙","Octopus"),("🦋","Butterfly"),
|
||
("🌷","Flower"),("🌳","Tree"),("🌵","Cactus"),("🍄","Mushroom"),("🌏","Globe"),
|
||
("🌙","Moon"),("☁️","Cloud"),("🔥","Fire"),("🍌","Banana"),("🍎","Apple"),
|
||
("🍓","Strawberry"),("🌽","Corn"),("🍕","Pizza"),("🎂","Cake"),("❤️","Heart"),
|
||
("😀","Smiley"),("🤖","Robot"),("🎩","Hat"),("👓","Glasses"),("🔧","Wrench"),
|
||
("🎅","Santa"),("👍","Thumbs Up"),("☂️","Umbrella"),("⌛","Hourglass"),("⏰","Clock"),
|
||
("🎁","Gift"),("💡","Light Bulb"),("📕","Book"),("✏️","Pencil"),("📎","Paperclip"),
|
||
("✂️","Scissors"),("🔒","Lock"),("🔑","Key"),("🔨","Hammer"),("☎️","Telephone"),
|
||
("🏁","Flag"),("🚂","Train"),("🚲","Bicycle"),("✈️","Airplane"),("🚀","Rocket"),
|
||
("🏆","Trophy"),("⚽","Ball"),("🎸","Guitar"),("🎺","Trumpet"),("🔔","Bell"),
|
||
("⚓","Anchor"),("🎧","Headphones"),("📁","Folder"),("📌","Pin"),
|
||
]
|
||
# 6 bytes → 42 bits → 7 × 6-bit indices
|
||
val = int.from_bytes(sas_bytes, "big")
|
||
result = []
|
||
for i in range(6, -1, -1):
|
||
idx = (val >> (i * 6)) & 0x3F
|
||
result.append(emoji_list[idx])
|
||
return result
|
||
|
||
async def _on_member(self, room: MatrixRoom, event: RoomMemberEvent) -> None:
    """Handle member events (joins, leaves).

    Ignored until the initial sync has finished and for events sent by the
    bot itself. When another user joins and encryption is available, device
    keys are queried so the new member's devices are known.

    Args:
        room: Room the membership event belongs to.
        event: The membership event.
    """
    if not self._synced:
        return
    if event.sender == self.client.user_id:
        return
    # Query keys for new members so we know their devices
    if event.membership == "join" and self.client.olm:
        try:
            await self.client.keys_query()
        except Exception as exc:
            # Best effort: a failed query must not break event handling,
            # but leave a trace instead of swallowing it silently.
            logger.debug("Key query after join in %s failed: %s", room.room_id, exc)
|
||
|
||
async def _on_sync(self, response: SyncResponse) -> None:
    """Post-process a sync response.

    Persists the ``next_batch`` token, auto-joins pending invites once the
    initial sync has completed, and refreshes device keys plus cross-signing
    trust whenever the response reports changed device lists.

    Args:
        response: The sync response delivered by the client.
    """
    if response.next_batch:
        self._save_sync_token(response.next_batch)
    if self._synced:
        await self._auto_join_invites()
    # Query keys and re-sync cross-signing trust when device lists change
    if self.client.olm and response.device_list.changed:
        try:
            await self.client.keys_query()
            await self._sync_cross_signing_trust()
        except Exception as exc:
            # Best effort: never let a key refresh failure break the sync
            # loop, but record it instead of dropping it silently.
            logger.debug("Key refresh after device list change failed: %s", exc)
|
||
|
||
async def close(self) -> None:
|
||
await self.client.close()
|