fix(state): add missing thread locks to 4 SessionDB methods

search_sessions(), clear_messages(), delete_session(), and
prune_sessions() all accessed self._conn without acquiring self._lock.
Every other method in the class uses the lock. In multi-threaded
contexts (gateway serving concurrent platform messages), these
unprotected methods can cause sqlite3.ProgrammingError from concurrent
cursor operations on the same connection.
This commit is contained in:
teknium1 2026-03-17 03:50:06 -07:00
parent ce7418e274
commit efa778a0ef

View file

@ -809,17 +809,18 @@ class SessionDB:
offset: int = 0, offset: int = 0,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""List sessions, optionally filtered by source.""" """List sessions, optionally filtered by source."""
if source: with self._lock:
cursor = self._conn.execute( if source:
"SELECT * FROM sessions WHERE source = ? ORDER BY started_at DESC LIMIT ? OFFSET ?", cursor = self._conn.execute(
(source, limit, offset), "SELECT * FROM sessions WHERE source = ? ORDER BY started_at DESC LIMIT ? OFFSET ?",
) (source, limit, offset),
else: )
cursor = self._conn.execute( else:
"SELECT * FROM sessions ORDER BY started_at DESC LIMIT ? OFFSET ?", cursor = self._conn.execute(
(limit, offset), "SELECT * FROM sessions ORDER BY started_at DESC LIMIT ? OFFSET ?",
) (limit, offset),
return [dict(row) for row in cursor.fetchall()] )
return [dict(row) for row in cursor.fetchall()]
# ========================================================================= # =========================================================================
# Utility # Utility
@ -871,26 +872,28 @@ class SessionDB:
def clear_messages(self, session_id: str) -> None: def clear_messages(self, session_id: str) -> None:
"""Delete all messages for a session and reset its counters.""" """Delete all messages for a session and reset its counters."""
self._conn.execute( with self._lock:
"DELETE FROM messages WHERE session_id = ?", (session_id,) self._conn.execute(
) "DELETE FROM messages WHERE session_id = ?", (session_id,)
self._conn.execute( )
"UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?", self._conn.execute(
(session_id,), "UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?",
) (session_id,),
self._conn.commit() )
self._conn.commit()
def delete_session(self, session_id: str) -> bool: def delete_session(self, session_id: str) -> bool:
"""Delete a session and all its messages. Returns True if found.""" """Delete a session and all its messages. Returns True if found."""
cursor = self._conn.execute( with self._lock:
"SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,) cursor = self._conn.execute(
) "SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,)
if cursor.fetchone()[0] == 0: )
return False if cursor.fetchone()[0] == 0:
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,)) return False
self._conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,)) self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
self._conn.commit() self._conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
return True self._conn.commit()
return True
def prune_sessions(self, older_than_days: int = 90, source: str = None) -> int: def prune_sessions(self, older_than_days: int = 90, source: str = None) -> int:
""" """
@ -900,22 +903,23 @@ class SessionDB:
import time as _time import time as _time
cutoff = _time.time() - (older_than_days * 86400) cutoff = _time.time() - (older_than_days * 86400)
if source: with self._lock:
cursor = self._conn.execute( if source:
"""SELECT id FROM sessions cursor = self._conn.execute(
WHERE started_at < ? AND ended_at IS NOT NULL AND source = ?""", """SELECT id FROM sessions
(cutoff, source), WHERE started_at < ? AND ended_at IS NOT NULL AND source = ?""",
) (cutoff, source),
else: )
cursor = self._conn.execute( else:
"SELECT id FROM sessions WHERE started_at < ? AND ended_at IS NOT NULL", cursor = self._conn.execute(
(cutoff,), "SELECT id FROM sessions WHERE started_at < ? AND ended_at IS NOT NULL",
) (cutoff,),
session_ids = [row["id"] for row in cursor.fetchall()] )
session_ids = [row["id"] for row in cursor.fetchall()]
for sid in session_ids: for sid in session_ids:
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,)) self._conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,))
self._conn.execute("DELETE FROM sessions WHERE id = ?", (sid,)) self._conn.execute("DELETE FROM sessions WHERE id = ?", (sid,))
self._conn.commit() self._conn.commit()
return len(session_ids) return len(session_ids)