surfaces/adapter/telegram/db.py

102 lines
2.9 KiB
Python

from __future__ import annotations
import os
import sqlite3
from contextlib import contextmanager
# Path of the SQLite database file; can be overridden via the DB_PATH env var.
DB_PATH = os.getenv("DB_PATH", "lambda_bot.db")
@contextmanager
def _conn():
con = sqlite3.connect(DB_PATH)
con.row_factory = sqlite3.Row
try:
yield con
con.commit()
finally:
con.close()
def init_db() -> None:
with _conn() as con:
con.executescript("""
CREATE TABLE IF NOT EXISTS chats (
user_id INTEGER NOT NULL,
thread_id INTEGER NOT NULL,
chat_name TEXT NOT NULL DEFAULT 'Чат #1',
archived_at DATETIME,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (user_id, thread_id)
);
""")
def create_chat(user_id: int, thread_id: int, chat_name: str) -> None:
with _conn() as con:
con.execute(
"INSERT OR IGNORE INTO chats (user_id, thread_id, chat_name) VALUES (?, ?, ?)",
(user_id, thread_id, chat_name),
)
def get_chat(user_id: int, thread_id: int) -> dict | None:
with _conn() as con:
row = con.execute(
"SELECT * FROM chats WHERE user_id = ? AND thread_id = ?",
(user_id, thread_id),
).fetchone()
return dict(row) if row else None
def get_active_chats(user_id: int) -> list[dict]:
with _conn() as con:
rows = con.execute(
"SELECT * FROM chats WHERE user_id = ? AND archived_at IS NULL "
"ORDER BY created_at ASC",
(user_id,),
).fetchall()
return [dict(r) for r in rows]
def count_active_chats(user_id: int) -> int:
with _conn() as con:
row = con.execute(
"SELECT COUNT(*) FROM chats WHERE user_id = ? AND archived_at IS NULL",
(user_id,),
).fetchone()
return row[0]
def archive_chat(user_id: int, thread_id: int) -> None:
with _conn() as con:
con.execute(
"UPDATE chats SET archived_at = CURRENT_TIMESTAMP "
"WHERE user_id = ? AND thread_id = ?",
(user_id, thread_id),
)
def rename_chat(user_id: int, thread_id: int, new_name: str) -> None:
with _conn() as con:
con.execute(
"UPDATE chats SET chat_name = ? WHERE user_id = ? AND thread_id = ?",
(new_name, user_id, thread_id),
)
def get_display_number(user_id: int, thread_id: int) -> int:
"""Return 1-based display number for a chat (by creation order)."""
with _conn() as con:
row = con.execute(
"""
SELECT rn FROM (
SELECT thread_id,
ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY created_at) AS rn
FROM chats
WHERE user_id = ?
) WHERE thread_id = ?
""",
(user_id, thread_id),
).fetchone()
return row[0] if row else 1