- commands.py: try/except TelegramBadRequest around all Bot API calls (#2); /new handles "topics limit" with user-friendly message (#4) - start.py: isolate _check_and_prune_stale_topics with try/except Exception (#3) - message.py: asyncio.timeout(30) around stream_message; handle TimeoutError (#6) - db.py: add idx_chats_user_id index in init_db() (#7) - settings.py: remove dead active_chat_id variable (#8) - tests: add test_message.py (stream error/success); add 2 tests in test_commands.py (topics limit, /archive in General topic)
103 lines
3 KiB
Python
103 lines
3 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import sqlite3
|
|
from contextlib import contextmanager
|
|
|
|
# Path of the SQLite database file. Taken from the DB_PATH environment
# variable when set, otherwise falls back to "lambda_bot.db".
DB_PATH = os.environ.get("DB_PATH", "lambda_bot.db")
|
|
|
|
|
|
@contextmanager
def _conn():
    """Yield a SQLite connection to ``DB_PATH`` for one unit of work.

    Rows are returned as ``sqlite3.Row`` objects. When the ``with`` body
    finishes without raising, the transaction is committed; the connection
    is closed in every case. If the body raises, ``commit()`` is skipped
    and the exception propagates after the connection has been closed.
    """
    connection = sqlite3.connect(DB_PATH)
    connection.row_factory = sqlite3.Row
    try:
        yield connection
        # Only reached when the caller's block completed normally.
        connection.commit()
    finally:
        connection.close()
|
|
|
|
|
|
def init_db() -> None:
    """Create the ``chats`` table and its ``user_id`` index if missing.

    Both statements use ``IF NOT EXISTS``, so the function can be called
    on every start-up without affecting an existing database.
    """
    schema = """
        CREATE TABLE IF NOT EXISTS chats (
            user_id     INTEGER NOT NULL,
            thread_id   INTEGER NOT NULL,
            chat_name   TEXT    NOT NULL DEFAULT 'Чат #1',
            archived_at DATETIME,
            created_at  DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
            PRIMARY KEY (user_id, thread_id)
        );
        CREATE INDEX IF NOT EXISTS idx_chats_user_id ON chats(user_id);
    """
    with _conn() as db:
        db.executescript(schema)
|
|
|
|
|
|
def create_chat(user_id: int, thread_id: int, chat_name: str) -> None:
    """Insert a chat row for ``(user_id, thread_id)`` with the given name.

    The statement is ``INSERT OR IGNORE``: if a row with the same primary
    key already exists, nothing is changed and no error is raised.
    """
    statement = (
        "INSERT OR IGNORE INTO chats (user_id, thread_id, chat_name) VALUES (?, ?, ?)"
    )
    values = (user_id, thread_id, chat_name)
    with _conn() as db:
        db.execute(statement, values)
|
|
|
|
|
|
def get_chat(user_id: int, thread_id: int) -> dict | None:
    """Fetch one chat by its ``(user_id, thread_id)`` key.

    Returns:
        The row as a plain ``dict`` keyed by column name, or ``None`` when
        no such chat exists.
    """
    query = "SELECT * FROM chats WHERE user_id = ? AND thread_id = ?"
    with _conn() as db:
        record = db.execute(query, (user_id, thread_id)).fetchone()
    if not record:
        return None
    return dict(record)
|
|
|
|
|
|
def get_active_chats(user_id: int) -> list[dict]:
    """Return the user's non-archived chats, oldest first.

    Args:
        user_id: Owner of the chats.

    Returns:
        A list of row dicts (one per chat whose ``archived_at`` is NULL),
        ordered by ``created_at`` ascending. The list is empty when the
        user has no active chats.
    """
    # created_at is filled from CURRENT_TIMESTAMP, which only has
    # one-second resolution, so two chats created in the same second tie.
    # SQLite leaves the order of tied rows unspecified; rowid (assigned on
    # insert) is added as a secondary key to keep the listing stable.
    query = (
        "SELECT * FROM chats WHERE user_id = ? AND archived_at IS NULL "
        "ORDER BY created_at ASC, rowid ASC"
    )
    with _conn() as con:
        rows = con.execute(query, (user_id,)).fetchall()
    return [dict(r) for r in rows]
|
|
|
|
|
|
def count_active_chats(user_id: int) -> int:
    """Return how many chats of ``user_id`` have ``archived_at`` set to NULL."""
    query = "SELECT COUNT(*) FROM chats WHERE user_id = ? AND archived_at IS NULL"
    with _conn() as db:
        cursor = db.execute(query, (user_id,))
        # COUNT(*) always yields exactly one row with a single column.
        total = cursor.fetchone()[0]
    return total
|
|
|
|
|
|
def archive_chat(user_id: int, thread_id: int) -> None:
    """Stamp the chat ``(user_id, thread_id)`` with the current time as archived.

    Sets ``archived_at`` to ``CURRENT_TIMESTAMP``. If no row matches the
    key, the UPDATE affects nothing and no error is raised.
    """
    statement = (
        "UPDATE chats SET archived_at = CURRENT_TIMESTAMP "
        "WHERE user_id = ? AND thread_id = ?"
    )
    key = (user_id, thread_id)
    with _conn() as db:
        db.execute(statement, key)
|
|
|
|
|
|
def rename_chat(user_id: int, thread_id: int, new_name: str) -> None:
    """Set ``chat_name`` of the chat ``(user_id, thread_id)`` to ``new_name``.

    If no row matches the key, the UPDATE affects nothing and no error is
    raised.
    """
    statement = "UPDATE chats SET chat_name = ? WHERE user_id = ? AND thread_id = ?"
    values = (new_name, user_id, thread_id)
    with _conn() as db:
        db.execute(statement, values)
|
|
|
|
|
|
def get_display_number(user_id: int, thread_id: int) -> int:
    """Return the 1-based display number of a chat, by creation order.

    The number is the chat's position among ALL chats of ``user_id``
    (archived ones included), ordered by ``created_at``.

    Args:
        user_id: Owner of the chat.
        thread_id: Thread identifying the chat within the user's chats.

    Returns:
        The chat's 1-based position, or ``1`` when the chat is not found.
    """
    # created_at comes from CURRENT_TIMESTAMP (one-second resolution), so
    # chats created within the same second tie. Without a secondary key
    # their relative numbering is unspecified and may differ between
    # calls; rowid (assigned on insert) makes the numbering deterministic.
    query = """
        SELECT rn FROM (
            SELECT thread_id,
                   ROW_NUMBER() OVER (
                       PARTITION BY user_id ORDER BY created_at, rowid
                   ) AS rn
            FROM chats
            WHERE user_id = ?
        ) WHERE thread_id = ?
    """
    with _conn() as con:
        row = con.execute(query, (user_id, thread_id)).fetchone()
    return row[0] if row else 1
|