"""Tests for hermes_state.py — SessionDB SQLite CRUD, FTS5 search, export.""" import time import pytest from pathlib import Path from hermes_state import SessionDB @pytest.fixture() def db(tmp_path): """Create a SessionDB with a temp database file.""" db_path = tmp_path / "test_state.db" session_db = SessionDB(db_path=db_path) yield session_db session_db.close() # ========================================================================= # Session lifecycle # ========================================================================= class TestSessionLifecycle: def test_create_and_get_session(self, db): sid = db.create_session( session_id="s1", source="cli", model="test-model", ) assert sid == "s1" session = db.get_session("s1") assert session is not None assert session["source"] == "cli" assert session["model"] == "test-model" assert session["ended_at"] is None def test_get_nonexistent_session(self, db): assert db.get_session("nonexistent") is None def test_end_session(self, db): db.create_session(session_id="s1", source="cli") db.end_session("s1", end_reason="user_exit") session = db.get_session("s1") assert isinstance(session["ended_at"], float) assert session["end_reason"] == "user_exit" def test_update_system_prompt(self, db): db.create_session(session_id="s1", source="cli") db.update_system_prompt("s1", "You are a helpful assistant.") session = db.get_session("s1") assert session["system_prompt"] == "You are a helpful assistant." def test_update_token_counts(self, db): db.create_session(session_id="s1", source="cli") db.update_token_counts("s1", input_tokens=100, output_tokens=50) db.update_token_counts("s1", input_tokens=200, output_tokens=100) session = db.get_session("s1") assert session["input_tokens"] == 300 assert session["output_tokens"] == 150 def test_parent_session(self, db): db.create_session(session_id="parent", source="cli") db.create_session(session_id="child", source="cli", parent_session_id="parent") child = db.get_session("child") assert child["parent_session_id"] == "parent" # ========================================================================= # Message storage # ========================================================================= class TestMessageStorage: def test_append_and_get_messages(self, db): db.create_session(session_id="s1", source="cli") db.append_message("s1", role="user", content="Hello") db.append_message("s1", role="assistant", content="Hi there!") messages = db.get_messages("s1") assert len(messages) == 2 assert messages[0]["role"] == "user" assert messages[0]["content"] == "Hello" assert messages[1]["role"] == "assistant" def test_message_increments_session_count(self, db): db.create_session(session_id="s1", source="cli") db.append_message("s1", role="user", content="Hello") db.append_message("s1", role="assistant", content="Hi") session = db.get_session("s1") assert session["message_count"] == 2 def test_tool_message_increments_tool_count(self, db): db.create_session(session_id="s1", source="cli") db.append_message("s1", role="tool", content="result", tool_name="web_search") session = db.get_session("s1") assert session["tool_call_count"] == 1 def test_tool_calls_serialization(self, db): db.create_session(session_id="s1", source="cli") tool_calls = [{"id": "call_1", "function": {"name": "web_search", "arguments": "{}"}}] db.append_message("s1", role="assistant", tool_calls=tool_calls) messages = db.get_messages("s1") assert messages[0]["tool_calls"] == tool_calls def test_get_messages_as_conversation(self, db): db.create_session(session_id="s1", source="cli") db.append_message("s1", role="user", content="Hello") db.append_message("s1", role="assistant", content="Hi!") conv = db.get_messages_as_conversation("s1") assert len(conv) == 2 assert conv[0] == {"role": "user", "content": "Hello"} assert conv[1] == {"role": "assistant", "content": "Hi!"} def test_finish_reason_stored(self, db): db.create_session(session_id="s1", source="cli") db.append_message("s1", role="assistant", content="Done", finish_reason="stop") messages = db.get_messages("s1") assert messages[0]["finish_reason"] == "stop" # ========================================================================= # FTS5 search # ========================================================================= class TestFTS5Search: def test_search_finds_content(self, db): db.create_session(session_id="s1", source="cli") db.append_message("s1", role="user", content="How do I deploy with Docker?") db.append_message("s1", role="assistant", content="Use docker compose up.") results = db.search_messages("docker") assert len(results) == 2 # At least one result should mention docker snippets = [r.get("snippet", "") for r in results] assert any("docker" in s.lower() or "Docker" in s for s in snippets) def test_search_empty_query(self, db): assert db.search_messages("") == [] assert db.search_messages(" ") == [] def test_search_with_source_filter(self, db): db.create_session(session_id="s1", source="cli") db.append_message("s1", role="user", content="CLI question about Python") db.create_session(session_id="s2", source="telegram") db.append_message("s2", role="user", content="Telegram question about Python") results = db.search_messages("Python", source_filter=["telegram"]) # Should only find the telegram message sources = [r["source"] for r in results] assert all(s == "telegram" for s in sources) def test_search_with_role_filter(self, db): db.create_session(session_id="s1", source="cli") db.append_message("s1", role="user", content="What is FastAPI?") db.append_message("s1", role="assistant", content="FastAPI is a web framework.") results = db.search_messages("FastAPI", role_filter=["assistant"]) roles = [r["role"] for r in results] assert all(r == "assistant" for r in roles) def test_search_returns_context(self, db): db.create_session(session_id="s1", source="cli") db.append_message("s1", role="user", content="Tell me about Kubernetes") db.append_message("s1", role="assistant", content="Kubernetes is an orchestrator.") results = db.search_messages("Kubernetes") assert len(results) == 2 assert "context" in results[0] assert isinstance(results[0]["context"], list) assert len(results[0]["context"]) > 0 def test_search_special_chars_do_not_crash(self, db): """FTS5 special characters in queries must not raise OperationalError.""" db.create_session(session_id="s1", source="cli") db.append_message("s1", role="user", content="How do I use C++ templates?") # Each of these previously caused sqlite3.OperationalError dangerous_queries = [ 'C++', # + is FTS5 column filter '"unterminated', # unbalanced double-quote '(problem', # unbalanced parenthesis 'hello AND', # dangling boolean operator '***', # repeated wildcard '{test}', # curly braces (column reference) 'OR hello', # leading boolean operator 'a AND OR b', # adjacent operators ] for query in dangerous_queries: # Must not raise — should return list (possibly empty) results = db.search_messages(query) assert isinstance(results, list), f"Query {query!r} did not return a list" def test_search_sanitized_query_still_finds_content(self, db): """Sanitization must not break normal keyword search.""" db.create_session(session_id="s1", source="cli") db.append_message("s1", role="user", content="Learning C++ templates today") # "C++" sanitized to "C" should still match "C++" results = db.search_messages("C++") # The word "C" appears in the content, so FTS5 should find it assert isinstance(results, list) def test_sanitize_fts5_query_strips_dangerous_chars(self): """Unit test for _sanitize_fts5_query static method.""" from hermes_state import SessionDB s = SessionDB._sanitize_fts5_query assert s('hello world') == 'hello world' assert '+' not in s('C++') assert '"' not in s('"unterminated') assert '(' not in s('(problem') assert '{' not in s('{test}') # Dangling operators removed assert s('hello AND') == 'hello' assert s('OR world') == 'world' # Leading bare * removed assert s('***') == '' # Valid prefix kept assert s('deploy*') == 'deploy*' # ========================================================================= # Session search and listing # ========================================================================= class TestSearchSessions: def test_list_all_sessions(self, db): db.create_session(session_id="s1", source="cli") db.create_session(session_id="s2", source="telegram") sessions = db.search_sessions() assert len(sessions) == 2 def test_filter_by_source(self, db): db.create_session(session_id="s1", source="cli") db.create_session(session_id="s2", source="telegram") sessions = db.search_sessions(source="cli") assert len(sessions) == 1 assert sessions[0]["source"] == "cli" def test_pagination(self, db): for i in range(5): db.create_session(session_id=f"s{i}", source="cli") page1 = db.search_sessions(limit=2) page2 = db.search_sessions(limit=2, offset=2) assert len(page1) == 2 assert len(page2) == 2 assert page1[0]["id"] != page2[0]["id"] # ========================================================================= # Counts # ========================================================================= class TestCounts: def test_session_count(self, db): assert db.session_count() == 0 db.create_session(session_id="s1", source="cli") db.create_session(session_id="s2", source="telegram") assert db.session_count() == 2 def test_session_count_by_source(self, db): db.create_session(session_id="s1", source="cli") db.create_session(session_id="s2", source="telegram") db.create_session(session_id="s3", source="cli") assert db.session_count(source="cli") == 2 assert db.session_count(source="telegram") == 1 def test_message_count_total(self, db): assert db.message_count() == 0 db.create_session(session_id="s1", source="cli") db.append_message("s1", role="user", content="Hello") db.append_message("s1", role="assistant", content="Hi") assert db.message_count() == 2 def test_message_count_per_session(self, db): db.create_session(session_id="s1", source="cli") db.create_session(session_id="s2", source="cli") db.append_message("s1", role="user", content="A") db.append_message("s2", role="user", content="B") db.append_message("s2", role="user", content="C") assert db.message_count(session_id="s1") == 1 assert db.message_count(session_id="s2") == 2 # ========================================================================= # Delete and export # ========================================================================= class TestDeleteAndExport: def test_delete_session(self, db): db.create_session(session_id="s1", source="cli") db.append_message("s1", role="user", content="Hello") assert db.delete_session("s1") is True assert db.get_session("s1") is None assert db.message_count(session_id="s1") == 0 def test_delete_nonexistent(self, db): assert db.delete_session("nope") is False def test_export_session(self, db): db.create_session(session_id="s1", source="cli", model="test") db.append_message("s1", role="user", content="Hello") db.append_message("s1", role="assistant", content="Hi") export = db.export_session("s1") assert isinstance(export, dict) assert export["source"] == "cli" assert len(export["messages"]) == 2 def test_export_nonexistent(self, db): assert db.export_session("nope") is None def test_export_all(self, db): db.create_session(session_id="s1", source="cli") db.create_session(session_id="s2", source="telegram") db.append_message("s1", role="user", content="A") exports = db.export_all() assert len(exports) == 2 def test_export_all_with_source(self, db): db.create_session(session_id="s1", source="cli") db.create_session(session_id="s2", source="telegram") exports = db.export_all(source="cli") assert len(exports) == 1 assert exports[0]["source"] == "cli" # ========================================================================= # Prune # ========================================================================= class TestPruneSessions: def test_prune_old_ended_sessions(self, db): # Create and end an "old" session db.create_session(session_id="old", source="cli") db.end_session("old", end_reason="done") # Manually backdate started_at db._conn.execute( "UPDATE sessions SET started_at = ? WHERE id = ?", (time.time() - 100 * 86400, "old"), ) db._conn.commit() # Create a recent session db.create_session(session_id="new", source="cli") pruned = db.prune_sessions(older_than_days=90) assert pruned == 1 assert db.get_session("old") is None session = db.get_session("new") assert session is not None assert session["id"] == "new" def test_prune_skips_active_sessions(self, db): db.create_session(session_id="active", source="cli") # Backdate but don't end db._conn.execute( "UPDATE sessions SET started_at = ? WHERE id = ?", (time.time() - 200 * 86400, "active"), ) db._conn.commit() pruned = db.prune_sessions(older_than_days=90) assert pruned == 0 assert db.get_session("active") is not None def test_prune_with_source_filter(self, db): for sid, src in [("old_cli", "cli"), ("old_tg", "telegram")]: db.create_session(session_id=sid, source=src) db.end_session(sid, end_reason="done") db._conn.execute( "UPDATE sessions SET started_at = ? WHERE id = ?", (time.time() - 200 * 86400, sid), ) db._conn.commit() pruned = db.prune_sessions(older_than_days=90, source="cli") assert pruned == 1 assert db.get_session("old_cli") is None assert db.get_session("old_tg") is not None # ========================================================================= # Schema and WAL mode # ========================================================================= class TestSchemaInit: def test_wal_mode(self, db): cursor = db._conn.execute("PRAGMA journal_mode") mode = cursor.fetchone()[0] assert mode == "wal" def test_foreign_keys_enabled(self, db): cursor = db._conn.execute("PRAGMA foreign_keys") assert cursor.fetchone()[0] == 1 def test_tables_exist(self, db): cursor = db._conn.execute( "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name" ) tables = {row[0] for row in cursor.fetchall()} assert "sessions" in tables assert "messages" in tables assert "schema_version" in tables def test_schema_version(self, db): cursor = db._conn.execute("SELECT version FROM schema_version") version = cursor.fetchone()[0] assert version == 2