fix: eliminate 3x SQLite message duplication in gateway sessions (#860)
Three separate code paths all wrote to the same SQLite state.db with no deduplication, inflating session transcripts by 3-4x: 1. _log_msg_to_db() — wrote each message individually after append 2. _flush_messages_to_session_db() — re-wrote ALL new messages at every _persist_session() call (~18 exit points), with no tracking of what was already written 3. gateway append_to_transcript() — wrote everything a third time after the agent returned Since load_transcript() prefers SQLite over JSONL, the inflated data was loaded on every session resume, causing proportional token waste. Fix: - Remove _log_msg_to_db() and all 16 call sites (redundant with flush) - Add _last_flushed_db_idx tracking in _flush_messages_to_session_db() so repeated _persist_session() calls only write truly new messages - Reset flush cursor on compression (new session ID) - Add skip_db parameter to SessionStore.append_to_transcript() so the gateway skips SQLite writes when the agent already persisted them - Gateway now passes skip_db=True for agent-managed messages, still writes to JSONL as backup Verified: a 12-message CLI session with tool calls produces exactly 12 SQLite rows with zero duplicates (previously would be 36-48). Tests: 9 new tests covering flush deduplication, skip_db behavior, compression reset, and initialization. Full suite passes (2869 tests).
This commit is contained in:
parent
5fc751e543
commit
c1171fe666
5 changed files with 323 additions and 54 deletions
|
|
@ -1322,6 +1322,11 @@ class GatewayRunner:
|
||||||
{"role": "assistant", "content": response, "timestamp": ts}
|
{"role": "assistant", "content": response, "timestamp": ts}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# The agent already persisted these messages to SQLite via
|
||||||
|
# _flush_messages_to_session_db(), so skip the DB write here
|
||||||
|
# to prevent the duplicate-write bug (#860). We still write
|
||||||
|
# to JSONL for backward compatibility and as a backup.
|
||||||
|
agent_persisted = self._session_db is not None
|
||||||
for msg in new_messages:
|
for msg in new_messages:
|
||||||
# Skip system messages (they're rebuilt each run)
|
# Skip system messages (they're rebuilt each run)
|
||||||
if msg.get("role") == "system":
|
if msg.get("role") == "system":
|
||||||
|
|
@ -1329,7 +1334,8 @@ class GatewayRunner:
|
||||||
# Add timestamp to each message for debugging
|
# Add timestamp to each message for debugging
|
||||||
entry = {**msg, "timestamp": ts}
|
entry = {**msg, "timestamp": ts}
|
||||||
self.session_store.append_to_transcript(
|
self.session_store.append_to_transcript(
|
||||||
session_entry.session_id, entry
|
session_entry.session_id, entry,
|
||||||
|
skip_db=agent_persisted,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update session
|
# Update session
|
||||||
|
|
|
||||||
|
|
@ -677,10 +677,17 @@ class SessionStore:
|
||||||
"""Get the path to a session's legacy transcript file."""
|
"""Get the path to a session's legacy transcript file."""
|
||||||
return self.sessions_dir / f"{session_id}.jsonl"
|
return self.sessions_dir / f"{session_id}.jsonl"
|
||||||
|
|
||||||
def append_to_transcript(self, session_id: str, message: Dict[str, Any]) -> None:
|
def append_to_transcript(self, session_id: str, message: Dict[str, Any], skip_db: bool = False) -> None:
|
||||||
"""Append a message to a session's transcript (SQLite + legacy JSONL)."""
|
"""Append a message to a session's transcript (SQLite + legacy JSONL).
|
||||||
# Write to SQLite
|
|
||||||
if self._db:
|
Args:
|
||||||
|
skip_db: When True, only write to JSONL and skip the SQLite write.
|
||||||
|
Used when the agent already persisted messages to SQLite
|
||||||
|
via its own _flush_messages_to_session_db(), preventing
|
||||||
|
the duplicate-write bug (#860).
|
||||||
|
"""
|
||||||
|
# Write to SQLite (unless the agent already handled it)
|
||||||
|
if self._db and not skip_db:
|
||||||
try:
|
try:
|
||||||
self._db.append_message(
|
self._db.append_message(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
|
|
|
||||||
58
run_agent.py
58
run_agent.py
|
|
@ -497,6 +497,7 @@ class AIAgent:
|
||||||
|
|
||||||
# SQLite session store (optional -- provided by CLI or gateway)
|
# SQLite session store (optional -- provided by CLI or gateway)
|
||||||
self._session_db = session_db
|
self._session_db = session_db
|
||||||
|
self._last_flushed_db_idx = 0 # tracks DB-write cursor to prevent duplicate writes
|
||||||
if self._session_db:
|
if self._session_db:
|
||||||
try:
|
try:
|
||||||
self._session_db.create_session(
|
self._session_db.create_session(
|
||||||
|
|
@ -802,45 +803,19 @@ class AIAgent:
|
||||||
self._save_session_log(messages)
|
self._save_session_log(messages)
|
||||||
self._flush_messages_to_session_db(messages, conversation_history)
|
self._flush_messages_to_session_db(messages, conversation_history)
|
||||||
|
|
||||||
def _log_msg_to_db(self, msg: Dict):
|
|
||||||
"""Log a single message to SQLite immediately. Called after each messages.append()."""
|
|
||||||
if not self._session_db:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
role = msg.get("role", "unknown")
|
|
||||||
content = msg.get("content")
|
|
||||||
tool_calls_data = None
|
|
||||||
if hasattr(msg, "tool_calls") and msg.tool_calls:
|
|
||||||
tool_calls_data = [
|
|
||||||
{"name": tc.function.name, "arguments": tc.function.arguments}
|
|
||||||
for tc in msg.tool_calls
|
|
||||||
]
|
|
||||||
elif isinstance(msg.get("tool_calls"), list):
|
|
||||||
tool_calls_data = msg["tool_calls"]
|
|
||||||
self._session_db.append_message(
|
|
||||||
session_id=self.session_id,
|
|
||||||
role=role,
|
|
||||||
content=content,
|
|
||||||
tool_name=msg.get("tool_name"),
|
|
||||||
tool_calls=tool_calls_data,
|
|
||||||
tool_call_id=msg.get("tool_call_id"),
|
|
||||||
finish_reason=msg.get("finish_reason"),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug("Session DB log_msg failed: %s", e)
|
|
||||||
|
|
||||||
def _flush_messages_to_session_db(self, messages: List[Dict], conversation_history: List[Dict] = None):
|
def _flush_messages_to_session_db(self, messages: List[Dict], conversation_history: List[Dict] = None):
|
||||||
"""Persist any un-logged messages to the SQLite session store.
|
"""Persist any un-flushed messages to the SQLite session store.
|
||||||
|
|
||||||
Called both at the normal end of run_conversation and from every early-
|
Uses _last_flushed_db_idx to track which messages have already been
|
||||||
return path so that tool calls, tool responses, and assistant messages
|
written, so repeated calls (from multiple exit paths) only write
|
||||||
are never lost even when the conversation errors out.
|
truly new messages — preventing the duplicate-write bug (#860).
|
||||||
"""
|
"""
|
||||||
if not self._session_db:
|
if not self._session_db:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
start_idx = len(conversation_history) if conversation_history else 0
|
start_idx = len(conversation_history) if conversation_history else 0
|
||||||
for msg in messages[start_idx:]:
|
flush_from = max(start_idx, self._last_flushed_db_idx)
|
||||||
|
for msg in messages[flush_from:]:
|
||||||
role = msg.get("role", "unknown")
|
role = msg.get("role", "unknown")
|
||||||
content = msg.get("content")
|
content = msg.get("content")
|
||||||
tool_calls_data = None
|
tool_calls_data = None
|
||||||
|
|
@ -860,6 +835,7 @@ class AIAgent:
|
||||||
tool_call_id=msg.get("tool_call_id"),
|
tool_call_id=msg.get("tool_call_id"),
|
||||||
finish_reason=msg.get("finish_reason"),
|
finish_reason=msg.get("finish_reason"),
|
||||||
)
|
)
|
||||||
|
self._last_flushed_db_idx = len(messages)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Session DB append_message failed: %s", e)
|
logger.debug("Session DB append_message failed: %s", e)
|
||||||
|
|
||||||
|
|
@ -2689,6 +2665,8 @@ class AIAgent:
|
||||||
except (ValueError, Exception) as e:
|
except (ValueError, Exception) as e:
|
||||||
logger.debug("Could not propagate title on compression: %s", e)
|
logger.debug("Could not propagate title on compression: %s", e)
|
||||||
self._session_db.update_system_prompt(self.session_id, new_system_prompt)
|
self._session_db.update_system_prompt(self.session_id, new_system_prompt)
|
||||||
|
# Reset flush cursor — new session starts with no messages written
|
||||||
|
self._last_flushed_db_idx = 0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Session DB compression split failed: %s", e)
|
logger.debug("Session DB compression split failed: %s", e)
|
||||||
|
|
||||||
|
|
@ -2712,7 +2690,6 @@ class AIAgent:
|
||||||
"tool_call_id": skipped_tc.id,
|
"tool_call_id": skipped_tc.id,
|
||||||
}
|
}
|
||||||
messages.append(skip_msg)
|
messages.append(skip_msg)
|
||||||
self._log_msg_to_db(skip_msg)
|
|
||||||
break
|
break
|
||||||
|
|
||||||
function_name = tool_call.function.name
|
function_name = tool_call.function.name
|
||||||
|
|
@ -2921,7 +2898,6 @@ class AIAgent:
|
||||||
"tool_call_id": tool_call.id
|
"tool_call_id": tool_call.id
|
||||||
}
|
}
|
||||||
messages.append(tool_msg)
|
messages.append(tool_msg)
|
||||||
self._log_msg_to_db(tool_msg)
|
|
||||||
|
|
||||||
if not self.quiet_mode:
|
if not self.quiet_mode:
|
||||||
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
|
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
|
||||||
|
|
@ -2938,7 +2914,6 @@ class AIAgent:
|
||||||
"tool_call_id": skipped_tc.id
|
"tool_call_id": skipped_tc.id
|
||||||
}
|
}
|
||||||
messages.append(skip_msg)
|
messages.append(skip_msg)
|
||||||
self._log_msg_to_db(skip_msg)
|
|
||||||
break
|
break
|
||||||
|
|
||||||
if self.tool_delay > 0 and i < len(assistant_message.tool_calls):
|
if self.tool_delay > 0 and i < len(assistant_message.tool_calls):
|
||||||
|
|
@ -3169,7 +3144,6 @@ class AIAgent:
|
||||||
# Add user message
|
# Add user message
|
||||||
user_msg = {"role": "user", "content": user_message}
|
user_msg = {"role": "user", "content": user_message}
|
||||||
messages.append(user_msg)
|
messages.append(user_msg)
|
||||||
self._log_msg_to_db(user_msg)
|
|
||||||
|
|
||||||
if not self.quiet_mode:
|
if not self.quiet_mode:
|
||||||
print(f"💬 Starting conversation: '{user_message[:60]}{'...' if len(user_message) > 60 else ''}'")
|
print(f"💬 Starting conversation: '{user_message[:60]}{'...' if len(user_message) > 60 else ''}'")
|
||||||
|
|
@ -3572,7 +3546,6 @@ class AIAgent:
|
||||||
length_continue_retries += 1
|
length_continue_retries += 1
|
||||||
interim_msg = self._build_assistant_message(assistant_message, finish_reason)
|
interim_msg = self._build_assistant_message(assistant_message, finish_reason)
|
||||||
messages.append(interim_msg)
|
messages.append(interim_msg)
|
||||||
self._log_msg_to_db(interim_msg)
|
|
||||||
if assistant_message.content:
|
if assistant_message.content:
|
||||||
truncated_response_prefix += assistant_message.content
|
truncated_response_prefix += assistant_message.content
|
||||||
|
|
||||||
|
|
@ -3590,7 +3563,6 @@ class AIAgent:
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
messages.append(continue_msg)
|
messages.append(continue_msg)
|
||||||
self._log_msg_to_db(continue_msg)
|
|
||||||
self._session_messages = messages
|
self._session_messages = messages
|
||||||
self._save_session_log(messages)
|
self._save_session_log(messages)
|
||||||
restart_with_length_continuation = True
|
restart_with_length_continuation = True
|
||||||
|
|
@ -4063,7 +4035,6 @@ class AIAgent:
|
||||||
)
|
)
|
||||||
if not duplicate_interim:
|
if not duplicate_interim:
|
||||||
messages.append(interim_msg)
|
messages.append(interim_msg)
|
||||||
self._log_msg_to_db(interim_msg)
|
|
||||||
|
|
||||||
if self._codex_incomplete_retries < 3:
|
if self._codex_incomplete_retries < 3:
|
||||||
if not self.quiet_mode:
|
if not self.quiet_mode:
|
||||||
|
|
@ -4114,7 +4085,6 @@ class AIAgent:
|
||||||
print(f"{self.log_prefix}⚠️ Unknown tool '{invalid_preview}' — sending error to model for self-correction")
|
print(f"{self.log_prefix}⚠️ Unknown tool '{invalid_preview}' — sending error to model for self-correction")
|
||||||
assistant_msg = self._build_assistant_message(assistant_message, finish_reason)
|
assistant_msg = self._build_assistant_message(assistant_message, finish_reason)
|
||||||
messages.append(assistant_msg)
|
messages.append(assistant_msg)
|
||||||
self._log_msg_to_db(assistant_msg)
|
|
||||||
for tc in assistant_message.tool_calls:
|
for tc in assistant_message.tool_calls:
|
||||||
if tc.function.name not in self.valid_tool_names:
|
if tc.function.name not in self.valid_tool_names:
|
||||||
content = f"Tool '{tc.function.name}' does not exist. Available tools: {available}"
|
content = f"Tool '{tc.function.name}' does not exist. Available tools: {available}"
|
||||||
|
|
@ -4169,7 +4139,6 @@ class AIAgent:
|
||||||
)
|
)
|
||||||
recovery_dict = {"role": "user", "content": recovery_msg}
|
recovery_dict = {"role": "user", "content": recovery_msg}
|
||||||
messages.append(recovery_dict)
|
messages.append(recovery_dict)
|
||||||
self._log_msg_to_db(recovery_dict)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Reset retry counter on successful JSON validation
|
# Reset retry counter on successful JSON validation
|
||||||
|
|
@ -4191,7 +4160,6 @@ class AIAgent:
|
||||||
print(f" ┊ 💬 {clean}")
|
print(f" ┊ 💬 {clean}")
|
||||||
|
|
||||||
messages.append(assistant_msg)
|
messages.append(assistant_msg)
|
||||||
self._log_msg_to_db(assistant_msg)
|
|
||||||
|
|
||||||
self._execute_tool_calls(assistant_message, messages, effective_task_id)
|
self._execute_tool_calls(assistant_message, messages, effective_task_id)
|
||||||
|
|
||||||
|
|
@ -4292,7 +4260,6 @@ class AIAgent:
|
||||||
"finish_reason": finish_reason,
|
"finish_reason": finish_reason,
|
||||||
}
|
}
|
||||||
messages.append(empty_msg)
|
messages.append(empty_msg)
|
||||||
self._log_msg_to_db(empty_msg)
|
|
||||||
|
|
||||||
self._cleanup_task_resources(effective_task_id)
|
self._cleanup_task_resources(effective_task_id)
|
||||||
self._persist_session(messages, conversation_history)
|
self._persist_session(messages, conversation_history)
|
||||||
|
|
@ -4323,7 +4290,6 @@ class AIAgent:
|
||||||
codex_ack_continuations += 1
|
codex_ack_continuations += 1
|
||||||
interim_msg = self._build_assistant_message(assistant_message, "incomplete")
|
interim_msg = self._build_assistant_message(assistant_message, "incomplete")
|
||||||
messages.append(interim_msg)
|
messages.append(interim_msg)
|
||||||
self._log_msg_to_db(interim_msg)
|
|
||||||
|
|
||||||
continue_msg = {
|
continue_msg = {
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
|
@ -4333,7 +4299,6 @@ class AIAgent:
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
messages.append(continue_msg)
|
messages.append(continue_msg)
|
||||||
self._log_msg_to_db(continue_msg)
|
|
||||||
self._session_messages = messages
|
self._session_messages = messages
|
||||||
self._save_session_log(messages)
|
self._save_session_log(messages)
|
||||||
continue
|
continue
|
||||||
|
|
@ -4349,7 +4314,6 @@ class AIAgent:
|
||||||
final_msg = self._build_assistant_message(assistant_message, finish_reason)
|
final_msg = self._build_assistant_message(assistant_message, finish_reason)
|
||||||
|
|
||||||
messages.append(final_msg)
|
messages.append(final_msg)
|
||||||
self._log_msg_to_db(final_msg)
|
|
||||||
|
|
||||||
if not self.quiet_mode:
|
if not self.quiet_mode:
|
||||||
print(f"🎉 Conversation completed after {api_call_count} OpenAI-compatible API call(s)")
|
print(f"🎉 Conversation completed after {api_call_count} OpenAI-compatible API call(s)")
|
||||||
|
|
@ -4386,7 +4350,6 @@ class AIAgent:
|
||||||
"content": f"Error executing tool: {error_msg}",
|
"content": f"Error executing tool: {error_msg}",
|
||||||
}
|
}
|
||||||
messages.append(err_msg)
|
messages.append(err_msg)
|
||||||
self._log_msg_to_db(err_msg)
|
|
||||||
pending_handled = True
|
pending_handled = True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
@ -4399,7 +4362,6 @@ class AIAgent:
|
||||||
"content": f"[System error during processing: {error_msg}]",
|
"content": f"[System error during processing: {error_msg}]",
|
||||||
}
|
}
|
||||||
messages.append(sys_err_msg)
|
messages.append(sys_err_msg)
|
||||||
self._log_msg_to_db(sys_err_msg)
|
|
||||||
|
|
||||||
# If we're near the limit, break to avoid infinite loops
|
# If we're near the limit, break to avoid infinite loops
|
||||||
if api_call_count >= self.max_iterations - 1:
|
if api_call_count >= self.max_iterations - 1:
|
||||||
|
|
|
||||||
294
tests/test_860_dedup.py
Normal file
294
tests/test_860_dedup.py
Normal file
|
|
@ -0,0 +1,294 @@
|
||||||
|
"""Tests for issue #860 — SQLite session transcript deduplication.
|
||||||
|
|
||||||
|
Verifies that:
|
||||||
|
1. _flush_messages_to_session_db uses _last_flushed_db_idx to avoid re-writing
|
||||||
|
2. Multiple _persist_session calls don't duplicate messages
|
||||||
|
3. append_to_transcript(skip_db=True) skips SQLite but writes JSONL
|
||||||
|
4. The gateway doesn't double-write messages the agent already persisted
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Test: _flush_messages_to_session_db only writes new messages
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestFlushDeduplication:
|
||||||
|
"""Verify _flush_messages_to_session_db tracks what it already wrote."""
|
||||||
|
|
||||||
|
def _make_agent(self, session_db):
|
||||||
|
"""Create a minimal AIAgent with a real session DB."""
|
||||||
|
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
||||||
|
from run_agent import AIAgent
|
||||||
|
agent = AIAgent(
|
||||||
|
model="test/model",
|
||||||
|
quiet_mode=True,
|
||||||
|
session_db=session_db,
|
||||||
|
session_id="test-session-860",
|
||||||
|
skip_context_files=True,
|
||||||
|
skip_memory=True,
|
||||||
|
)
|
||||||
|
return agent
|
||||||
|
|
||||||
|
def test_flush_writes_only_new_messages(self):
|
||||||
|
"""First flush writes all new messages, second flush writes none."""
|
||||||
|
from hermes_state import SessionDB
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
db_path = Path(tmpdir) / "test.db"
|
||||||
|
db = SessionDB(db_path=db_path)
|
||||||
|
|
||||||
|
agent = self._make_agent(db)
|
||||||
|
|
||||||
|
conversation_history = [
|
||||||
|
{"role": "user", "content": "old message"},
|
||||||
|
]
|
||||||
|
messages = list(conversation_history) + [
|
||||||
|
{"role": "user", "content": "new question"},
|
||||||
|
{"role": "assistant", "content": "new answer"},
|
||||||
|
]
|
||||||
|
|
||||||
|
# First flush — should write 2 new messages
|
||||||
|
agent._flush_messages_to_session_db(messages, conversation_history)
|
||||||
|
|
||||||
|
rows = db.get_messages(agent.session_id)
|
||||||
|
assert len(rows) == 2, f"Expected 2 messages, got {len(rows)}"
|
||||||
|
|
||||||
|
# Second flush with SAME messages — should write 0 new messages
|
||||||
|
agent._flush_messages_to_session_db(messages, conversation_history)
|
||||||
|
|
||||||
|
rows = db.get_messages(agent.session_id)
|
||||||
|
assert len(rows) == 2, f"Expected still 2 messages after second flush, got {len(rows)}"
|
||||||
|
|
||||||
|
def test_flush_writes_incrementally(self):
|
||||||
|
"""Messages added between flushes are written exactly once."""
|
||||||
|
from hermes_state import SessionDB
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
db_path = Path(tmpdir) / "test.db"
|
||||||
|
db = SessionDB(db_path=db_path)
|
||||||
|
|
||||||
|
agent = self._make_agent(db)
|
||||||
|
|
||||||
|
conversation_history = []
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
]
|
||||||
|
|
||||||
|
# First flush — 1 message
|
||||||
|
agent._flush_messages_to_session_db(messages, conversation_history)
|
||||||
|
rows = db.get_messages(agent.session_id)
|
||||||
|
assert len(rows) == 1
|
||||||
|
|
||||||
|
# Add more messages
|
||||||
|
messages.append({"role": "assistant", "content": "hi there"})
|
||||||
|
messages.append({"role": "user", "content": "follow up"})
|
||||||
|
|
||||||
|
# Second flush — should write only 2 new messages
|
||||||
|
agent._flush_messages_to_session_db(messages, conversation_history)
|
||||||
|
rows = db.get_messages(agent.session_id)
|
||||||
|
assert len(rows) == 3, f"Expected 3 total messages, got {len(rows)}"
|
||||||
|
|
||||||
|
def test_persist_session_multiple_calls_no_duplication(self):
|
||||||
|
"""Multiple _persist_session calls don't duplicate DB entries."""
|
||||||
|
from hermes_state import SessionDB
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
db_path = Path(tmpdir) / "test.db"
|
||||||
|
db = SessionDB(db_path=db_path)
|
||||||
|
|
||||||
|
agent = self._make_agent(db)
|
||||||
|
# Stub out _save_session_log to avoid file I/O
|
||||||
|
agent._save_session_log = MagicMock()
|
||||||
|
|
||||||
|
conversation_history = [{"role": "user", "content": "old"}]
|
||||||
|
messages = list(conversation_history) + [
|
||||||
|
{"role": "user", "content": "q1"},
|
||||||
|
{"role": "assistant", "content": "a1"},
|
||||||
|
{"role": "user", "content": "q2"},
|
||||||
|
{"role": "assistant", "content": "a2"},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Simulate multiple persist calls (like the agent's many exit paths)
|
||||||
|
for _ in range(5):
|
||||||
|
agent._persist_session(messages, conversation_history)
|
||||||
|
|
||||||
|
rows = db.get_messages(agent.session_id)
|
||||||
|
assert len(rows) == 4, f"Expected 4 messages, got {len(rows)} (duplication bug!)"
|
||||||
|
|
||||||
|
def test_flush_reset_after_compression(self):
|
||||||
|
"""After compression creates a new session, flush index resets."""
|
||||||
|
from hermes_state import SessionDB
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
db_path = Path(tmpdir) / "test.db"
|
||||||
|
db = SessionDB(db_path=db_path)
|
||||||
|
|
||||||
|
agent = self._make_agent(db)
|
||||||
|
|
||||||
|
# Write some messages
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "msg1"},
|
||||||
|
{"role": "assistant", "content": "reply1"},
|
||||||
|
]
|
||||||
|
agent._flush_messages_to_session_db(messages, [])
|
||||||
|
|
||||||
|
old_session = agent.session_id
|
||||||
|
assert agent._last_flushed_db_idx == 2
|
||||||
|
|
||||||
|
# Simulate what _compress_context does: new session, reset idx
|
||||||
|
agent.session_id = "compressed-session-new"
|
||||||
|
db.create_session(session_id=agent.session_id, source="test")
|
||||||
|
agent._last_flushed_db_idx = 0
|
||||||
|
|
||||||
|
# Now flush compressed messages to new session
|
||||||
|
compressed_messages = [
|
||||||
|
{"role": "user", "content": "summary of conversation"},
|
||||||
|
]
|
||||||
|
agent._flush_messages_to_session_db(compressed_messages, [])
|
||||||
|
|
||||||
|
new_rows = db.get_messages(agent.session_id)
|
||||||
|
assert len(new_rows) == 1
|
||||||
|
|
||||||
|
# Old session should still have its 2 messages
|
||||||
|
old_rows = db.get_messages(old_session)
|
||||||
|
assert len(old_rows) == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Test: append_to_transcript skip_db parameter
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestAppendToTranscriptSkipDb:
|
||||||
|
"""Verify skip_db=True writes JSONL but not SQLite."""
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def store(self, tmp_path):
|
||||||
|
from gateway.config import GatewayConfig
|
||||||
|
from gateway.session import SessionStore
|
||||||
|
config = GatewayConfig()
|
||||||
|
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||||
|
s = SessionStore(sessions_dir=tmp_path, config=config)
|
||||||
|
s._db = None # no SQLite for these JSONL-focused tests
|
||||||
|
s._loaded = True
|
||||||
|
return s
|
||||||
|
|
||||||
|
def test_skip_db_writes_jsonl_only(self, store, tmp_path):
|
||||||
|
"""With skip_db=True, message appears in JSONL but not SQLite."""
|
||||||
|
session_id = "test-skip-db"
|
||||||
|
msg = {"role": "assistant", "content": "hello world"}
|
||||||
|
store.append_to_transcript(session_id, msg, skip_db=True)
|
||||||
|
|
||||||
|
# JSONL should have the message
|
||||||
|
jsonl_path = store.get_transcript_path(session_id)
|
||||||
|
assert jsonl_path.exists()
|
||||||
|
with open(jsonl_path) as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
assert len(lines) == 1
|
||||||
|
parsed = json.loads(lines[0])
|
||||||
|
assert parsed["content"] == "hello world"
|
||||||
|
|
||||||
|
def test_skip_db_prevents_sqlite_write(self, tmp_path):
|
||||||
|
"""With skip_db=True and a real DB, message does NOT appear in SQLite."""
|
||||||
|
from gateway.config import GatewayConfig
|
||||||
|
from gateway.session import SessionStore
|
||||||
|
from hermes_state import SessionDB
|
||||||
|
|
||||||
|
db_path = tmp_path / "test_skip.db"
|
||||||
|
db = SessionDB(db_path=db_path)
|
||||||
|
|
||||||
|
config = GatewayConfig()
|
||||||
|
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||||
|
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||||
|
store._db = db
|
||||||
|
store._loaded = True
|
||||||
|
|
||||||
|
session_id = "test-skip-db-real"
|
||||||
|
db.create_session(session_id=session_id, source="test")
|
||||||
|
|
||||||
|
msg = {"role": "assistant", "content": "hello world"}
|
||||||
|
store.append_to_transcript(session_id, msg, skip_db=True)
|
||||||
|
|
||||||
|
# SQLite should NOT have the message
|
||||||
|
rows = db.get_messages(session_id)
|
||||||
|
assert len(rows) == 0, f"Expected 0 DB rows with skip_db=True, got {len(rows)}"
|
||||||
|
|
||||||
|
# But JSONL should have it
|
||||||
|
jsonl_path = store.get_transcript_path(session_id)
|
||||||
|
with open(jsonl_path) as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
assert len(lines) == 1
|
||||||
|
|
||||||
|
def test_default_writes_both(self, tmp_path):
|
||||||
|
"""Without skip_db, message appears in both JSONL and SQLite."""
|
||||||
|
from gateway.config import GatewayConfig
|
||||||
|
from gateway.session import SessionStore
|
||||||
|
from hermes_state import SessionDB
|
||||||
|
|
||||||
|
db_path = tmp_path / "test_both.db"
|
||||||
|
db = SessionDB(db_path=db_path)
|
||||||
|
|
||||||
|
config = GatewayConfig()
|
||||||
|
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||||
|
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||||
|
store._db = db
|
||||||
|
store._loaded = True
|
||||||
|
|
||||||
|
session_id = "test-default-write"
|
||||||
|
db.create_session(session_id=session_id, source="test")
|
||||||
|
|
||||||
|
msg = {"role": "user", "content": "test message"}
|
||||||
|
store.append_to_transcript(session_id, msg)
|
||||||
|
|
||||||
|
# JSONL should have the message
|
||||||
|
jsonl_path = store.get_transcript_path(session_id)
|
||||||
|
with open(jsonl_path) as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
assert len(lines) == 1
|
||||||
|
|
||||||
|
# SQLite should also have the message
|
||||||
|
rows = db.get_messages(session_id)
|
||||||
|
assert len(rows) == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Test: _last_flushed_db_idx initialization
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestFlushIdxInit:
|
||||||
|
"""Verify _last_flushed_db_idx is properly initialized."""
|
||||||
|
|
||||||
|
def test_init_zero(self):
|
||||||
|
"""Agent starts with _last_flushed_db_idx = 0."""
|
||||||
|
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
||||||
|
from run_agent import AIAgent
|
||||||
|
agent = AIAgent(
|
||||||
|
model="test/model",
|
||||||
|
quiet_mode=True,
|
||||||
|
skip_context_files=True,
|
||||||
|
skip_memory=True,
|
||||||
|
)
|
||||||
|
assert agent._last_flushed_db_idx == 0
|
||||||
|
|
||||||
|
def test_no_session_db_noop(self):
|
||||||
|
"""Without session_db, flush is a no-op and doesn't crash."""
|
||||||
|
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
||||||
|
from run_agent import AIAgent
|
||||||
|
agent = AIAgent(
|
||||||
|
model="test/model",
|
||||||
|
quiet_mode=True,
|
||||||
|
skip_context_files=True,
|
||||||
|
skip_memory=True,
|
||||||
|
)
|
||||||
|
messages = [{"role": "user", "content": "test"}]
|
||||||
|
agent._flush_messages_to_session_db(messages, [])
|
||||||
|
# Should not crash, idx should remain 0
|
||||||
|
assert agent._last_flushed_db_idx == 0
|
||||||
|
|
@ -88,7 +88,7 @@ class TestPreToolCheck:
|
||||||
agent = MagicMock()
|
agent = MagicMock()
|
||||||
agent._interrupt_requested = True
|
agent._interrupt_requested = True
|
||||||
agent.log_prefix = ""
|
agent.log_prefix = ""
|
||||||
agent._log_msg_to_db = MagicMock()
|
agent._persist_session = MagicMock()
|
||||||
|
|
||||||
# Import and call the method
|
# Import and call the method
|
||||||
from run_agent import AIAgent
|
from run_agent import AIAgent
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue