167 lines
4.9 KiB
Python
167 lines
4.9 KiB
Python
# adapter/telegram/db.py
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import sqlite3
|
|
import uuid
|
|
from contextlib import contextmanager
|
|
|
|
# Path of the SQLite database file; taken from the DB_PATH environment
# variable, falling back to "lambda_bot.db" when it is not set.
DB_PATH = os.environ.get("DB_PATH", "lambda_bot.db")
|
|
|
|
|
|
@contextmanager
def _conn():
    """Yield a SQLite connection to ``DB_PATH`` for the duration of a block.

    Rows are returned as ``sqlite3.Row``. The connection is committed when
    the block finishes without raising and is always closed afterwards.
    """
    connection = sqlite3.connect(DB_PATH)
    connection.row_factory = sqlite3.Row
    try:
        yield connection
        # Only reached when the caller's block did not raise.
        connection.commit()
    finally:
        connection.close()
|
|
|
|
|
|
def init_db() -> None:
    """Create the ``tg_users`` and ``chats`` tables and migrate old databases.

    Tables are created only if they do not exist yet. Afterwards the
    ``forum_group_id`` and ``forum_thread_id`` columns are added to databases
    whose tables were created before those columns were part of the schema.

    Raises:
        sqlite3.Error: on any database failure other than a migration column
            that is already present.
    """
    with _conn() as con:
        con.executescript("""
            CREATE TABLE IF NOT EXISTS tg_users (
                tg_user_id INTEGER PRIMARY KEY,
                platform_user_id TEXT NOT NULL,
                display_name TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                forum_group_id INTEGER
            );

            CREATE TABLE IF NOT EXISTS chats (
                chat_id TEXT PRIMARY KEY,
                tg_user_id INTEGER NOT NULL,
                name TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                archived_at TIMESTAMP,
                forum_thread_id INTEGER,
                FOREIGN KEY(tg_user_id) REFERENCES tg_users(tg_user_id)
            );
        """)
        # Migration for existing databases: CREATE TABLE IF NOT EXISTS leaves
        # an already-present table untouched, so newer columns are added here.
        _add_column_if_missing(con, "tg_users", "forum_group_id", "INTEGER")
        _add_column_if_missing(con, "chats", "forum_thread_id", "INTEGER")


def _add_column_if_missing(
    con: sqlite3.Connection, table: str, column: str, decl: str
) -> None:
    """Run ``ALTER TABLE ... ADD COLUMN``, tolerating an existing column.

    Only SQLite's "duplicate column name" error is suppressed; every other
    failure propagates instead of being silently swallowed.
    """
    try:
        con.execute(f"ALTER TABLE {table} ADD COLUMN {column} {decl}")
    except sqlite3.OperationalError as exc:
        if "duplicate column name" not in str(exc).lower():
            raise
|
|
|
|
|
|
def get_or_create_tg_user(
    tg_user_id: int,
    platform_user_id: str,
    display_name: str | None,
) -> dict:
    """Return the ``tg_users`` row for ``tg_user_id``, inserting it if absent.

    An existing row is returned unchanged; ``platform_user_id`` and
    ``display_name`` are only used when a new row has to be inserted.

    Args:
        tg_user_id: Primary key of the user row.
        platform_user_id: Value stored in ``platform_user_id`` on insert.
        display_name: Value stored in ``display_name`` on insert; may be None.

    Returns:
        The full row as a dict. A freshly inserted user carries the same keys
        as an existing one, including ``created_at`` and ``forum_group_id``.
    """
    select_sql = "SELECT * FROM tg_users WHERE tg_user_id = ?"
    with _conn() as con:
        row = con.execute(select_sql, (tg_user_id,)).fetchone()
        if row is None:
            con.execute(
                "INSERT INTO tg_users (tg_user_id, platform_user_id, display_name) VALUES (?, ?, ?)",
                (tg_user_id, platform_user_id, display_name),
            )
            # Read the row back so column defaults are part of the result and
            # both code paths return a dict of the same shape.
            row = con.execute(select_sql, (tg_user_id,)).fetchone()
        return dict(row)
|
|
|
|
|
|
def create_chat(tg_user_id: int, name: str) -> str:
    """Insert a new chat for ``tg_user_id`` and return its generated id.

    The id is a random UUID4 rendered as a string.
    """
    new_chat_id = str(uuid.uuid4())
    insert_sql = "INSERT INTO chats (chat_id, tg_user_id, name) VALUES (?, ?, ?)"
    params = (new_chat_id, tg_user_id, name)
    with _conn() as db:
        db.execute(insert_sql, params)
    return new_chat_id
|
|
|
|
|
|
def get_last_chat(tg_user_id: int) -> dict | None:
    """Return the most recently created non-archived chat of a user.

    Args:
        tg_user_id: Owner of the chats to search.

    Returns:
        The chat row as a dict, or None when the user has no active chat.
    """
    with _conn() as con:
        row = con.execute(
            "SELECT * FROM chats WHERE tg_user_id = ? AND archived_at IS NULL "
            # created_at comes from CURRENT_TIMESTAMP, which has one-second
            # resolution; rowid breaks ties between chats created within the
            # same second so the result is deterministic.
            "ORDER BY created_at DESC, rowid DESC LIMIT 1",
            (tg_user_id,),
        ).fetchone()
        return dict(row) if row else None
|
|
|
|
|
|
def get_user_chats(tg_user_id: int) -> list[dict]:
    """Return all non-archived chats of a user, oldest first.

    Args:
        tg_user_id: Owner of the chats to list.

    Returns:
        A list of chat rows as dicts; empty when the user has no active chats.
    """
    with _conn() as con:
        rows = con.execute(
            "SELECT * FROM chats WHERE tg_user_id = ? AND archived_at IS NULL "
            # rowid is a tiebreaker for chats whose created_at falls within
            # the same second, keeping the listing order stable.
            "ORDER BY created_at ASC, rowid ASC",
            (tg_user_id,),
        ).fetchall()
        return [dict(r) for r in rows]
|
|
|
|
|
|
def count_chats(tg_user_id: int) -> int:
    """Return how many non-archived chats belong to ``tg_user_id``."""
    count_sql = "SELECT COUNT(*) FROM chats WHERE tg_user_id = ? AND archived_at IS NULL"
    with _conn() as db:
        cursor = db.execute(count_sql, (tg_user_id,))
        # COUNT(*) always yields exactly one row with one column.
        return cursor.fetchone()[0]
|
|
|
|
|
|
def get_chat_by_id(chat_id: str) -> dict | None:
    """Look up a chat by its id, archived or not.

    Returns the row as a dict, or None when no chat has that id.
    """
    with _conn() as db:
        cursor = db.execute("SELECT * FROM chats WHERE chat_id = ?", (chat_id,))
        found = cursor.fetchone()
        if found:
            return dict(found)
        return None
|
|
|
|
|
|
def rename_chat(chat_id: str, new_name: str) -> None:
    """Set the ``name`` of the chat identified by ``chat_id``."""
    update_sql = "UPDATE chats SET name = ? WHERE chat_id = ?"
    with _conn() as db:
        db.execute(update_sql, (new_name, chat_id))
|
|
|
|
|
|
def archive_chat(chat_id: str) -> None:
    """Mark a chat as archived by stamping ``archived_at`` with the current time."""
    update_sql = "UPDATE chats SET archived_at = CURRENT_TIMESTAMP WHERE chat_id = ?"
    with _conn() as db:
        db.execute(update_sql, (chat_id,))
|
|
|
|
|
|
def set_forum_group(tg_user_id: int, group_id: int) -> None:
    """Store ``group_id`` as the ``forum_group_id`` of the given user row."""
    update_sql = "UPDATE tg_users SET forum_group_id = ? WHERE tg_user_id = ?"
    with _conn() as db:
        db.execute(update_sql, (group_id, tg_user_id))
|
|
|
|
|
|
def get_forum_group(tg_user_id: int) -> int | None:
    """Return the user's stored ``forum_group_id``.

    The result is None when the user row does not exist or the column is NULL.
    """
    select_sql = "SELECT forum_group_id FROM tg_users WHERE tg_user_id = ?"
    with _conn() as db:
        found = db.execute(select_sql, (tg_user_id,)).fetchone()
        if not found:
            return None
        return found["forum_group_id"]
|
|
|
|
|
|
def set_forum_thread(chat_id: str, thread_id: int) -> None:
    """Store ``thread_id`` as the ``forum_thread_id`` of the given chat row."""
    update_sql = "UPDATE chats SET forum_thread_id = ? WHERE chat_id = ?"
    with _conn() as db:
        db.execute(update_sql, (thread_id, chat_id))
|
|
|
|
|
|
def get_chat_by_thread(tg_user_id: int, thread_id: int) -> dict | None:
    """Find a user's non-archived chat bound to a forum thread id.

    Returns the row as a dict, or None when no active chat matches.
    """
    select_sql = (
        "SELECT * FROM chats WHERE tg_user_id = ? AND forum_thread_id = ? "
        "AND archived_at IS NULL"
    )
    with _conn() as db:
        found = db.execute(select_sql, (tg_user_id, thread_id)).fetchone()
        if found:
            return dict(found)
        return None
|