Merge pull request #1704 from NousResearch/fix/hermes-state-thread-locks
fix(state): add missing thread locks to 4 SessionDB methods
This commit is contained in:
commit
ed3bcae8bd
1 changed files with 48 additions and 44 deletions
|
|
@ -809,17 +809,18 @@ class SessionDB:
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""List sessions, optionally filtered by source."""
|
"""List sessions, optionally filtered by source."""
|
||||||
if source:
|
with self._lock:
|
||||||
cursor = self._conn.execute(
|
if source:
|
||||||
"SELECT * FROM sessions WHERE source = ? ORDER BY started_at DESC LIMIT ? OFFSET ?",
|
cursor = self._conn.execute(
|
||||||
(source, limit, offset),
|
"SELECT * FROM sessions WHERE source = ? ORDER BY started_at DESC LIMIT ? OFFSET ?",
|
||||||
)
|
(source, limit, offset),
|
||||||
else:
|
)
|
||||||
cursor = self._conn.execute(
|
else:
|
||||||
"SELECT * FROM sessions ORDER BY started_at DESC LIMIT ? OFFSET ?",
|
cursor = self._conn.execute(
|
||||||
(limit, offset),
|
"SELECT * FROM sessions ORDER BY started_at DESC LIMIT ? OFFSET ?",
|
||||||
)
|
(limit, offset),
|
||||||
return [dict(row) for row in cursor.fetchall()]
|
)
|
||||||
|
return [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# Utility
|
# Utility
|
||||||
|
|
@ -871,26 +872,28 @@ class SessionDB:
|
||||||
|
|
||||||
def clear_messages(self, session_id: str) -> None:
|
def clear_messages(self, session_id: str) -> None:
|
||||||
"""Delete all messages for a session and reset its counters."""
|
"""Delete all messages for a session and reset its counters."""
|
||||||
self._conn.execute(
|
with self._lock:
|
||||||
"DELETE FROM messages WHERE session_id = ?", (session_id,)
|
self._conn.execute(
|
||||||
)
|
"DELETE FROM messages WHERE session_id = ?", (session_id,)
|
||||||
self._conn.execute(
|
)
|
||||||
"UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?",
|
self._conn.execute(
|
||||||
(session_id,),
|
"UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?",
|
||||||
)
|
(session_id,),
|
||||||
self._conn.commit()
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
|
||||||
def delete_session(self, session_id: str) -> bool:
|
def delete_session(self, session_id: str) -> bool:
|
||||||
"""Delete a session and all its messages. Returns True if found."""
|
"""Delete a session and all its messages. Returns True if found."""
|
||||||
cursor = self._conn.execute(
|
with self._lock:
|
||||||
"SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,)
|
cursor = self._conn.execute(
|
||||||
)
|
"SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,)
|
||||||
if cursor.fetchone()[0] == 0:
|
)
|
||||||
return False
|
if cursor.fetchone()[0] == 0:
|
||||||
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
return False
|
||||||
self._conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
|
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||||
self._conn.commit()
|
self._conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
|
||||||
return True
|
self._conn.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
def prune_sessions(self, older_than_days: int = 90, source: str = None) -> int:
|
def prune_sessions(self, older_than_days: int = 90, source: str = None) -> int:
|
||||||
"""
|
"""
|
||||||
|
|
@ -900,22 +903,23 @@ class SessionDB:
|
||||||
import time as _time
|
import time as _time
|
||||||
cutoff = _time.time() - (older_than_days * 86400)
|
cutoff = _time.time() - (older_than_days * 86400)
|
||||||
|
|
||||||
if source:
|
with self._lock:
|
||||||
cursor = self._conn.execute(
|
if source:
|
||||||
"""SELECT id FROM sessions
|
cursor = self._conn.execute(
|
||||||
WHERE started_at < ? AND ended_at IS NOT NULL AND source = ?""",
|
"""SELECT id FROM sessions
|
||||||
(cutoff, source),
|
WHERE started_at < ? AND ended_at IS NOT NULL AND source = ?""",
|
||||||
)
|
(cutoff, source),
|
||||||
else:
|
)
|
||||||
cursor = self._conn.execute(
|
else:
|
||||||
"SELECT id FROM sessions WHERE started_at < ? AND ended_at IS NOT NULL",
|
cursor = self._conn.execute(
|
||||||
(cutoff,),
|
"SELECT id FROM sessions WHERE started_at < ? AND ended_at IS NOT NULL",
|
||||||
)
|
(cutoff,),
|
||||||
session_ids = [row["id"] for row in cursor.fetchall()]
|
)
|
||||||
|
session_ids = [row["id"] for row in cursor.fetchall()]
|
||||||
|
|
||||||
for sid in session_ids:
|
for sid in session_ids:
|
||||||
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,))
|
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,))
|
||||||
self._conn.execute("DELETE FROM sessions WHERE id = ?", (sid,))
|
self._conn.execute("DELETE FROM sessions WHERE id = ?", (sid,))
|
||||||
|
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
return len(session_ids)
|
return len(session_ids)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue