fix(gateway): persist transcript changes in /retry, /undo and fix /reset

/retry and /undo set session_entry.conversation_history which does not
exist on SessionEntry. The truncated history was never written to disk,
so the next message reload picked up the full unmodified transcript.

Added SessionStore.rewrite_transcript() that persists changes to both
the JSONL file and SQLite database, and updated both commands to use it.

/reset accessed self.session_store._sessions which does not exist on
SessionStore (the correct attribute is _entries). Also replaced the
hand-coded session key with _generate_session_key() to fix WhatsApp DM
sessions using the wrong key format.

Closes #210
This commit is contained in:
Farukest 2026-03-01 01:12:58 +03:00
parent 6366177118
commit b7f8a17c24
No known key found for this signature in database
GPG key ID: 73E2756B3FFF5241
3 changed files with 95 additions and 5 deletions

View file

@ -901,13 +901,12 @@ class GatewayRunner:
source = event.source source = event.source
# Get existing session key # Get existing session key
session_key = f"agent:main:{source.platform.value}:" + \ session_key = self.session_store._generate_session_key(source)
(f"dm" if source.chat_type == "dm" else f"{source.chat_type}:{source.chat_id}")
# Memory flush before reset: load the old transcript and let a # Memory flush before reset: load the old transcript and let a
# temporary agent save memories before the session is wiped. # temporary agent save memories before the session is wiped.
try: try:
old_entry = self.session_store._sessions.get(session_key) old_entry = self.session_store._entries.get(session_key)
if old_entry: if old_entry:
old_history = self.session_store.load_transcript(old_entry.session_id) old_history = self.session_store.load_transcript(old_entry.session_id)
if old_history: if old_history:
@ -1135,7 +1134,7 @@ class GatewayRunner:
# Truncate history to before the last user message # Truncate history to before the last user message
truncated = history[:last_user_idx] truncated = history[:last_user_idx]
session_entry.conversation_history = truncated self.session_store.rewrite_transcript(session_entry.session_id, truncated)
# Re-send by creating a fake text event with the old message # Re-send by creating a fake text event with the old message
retry_event = MessageEvent( retry_event = MessageEvent(
@ -1167,7 +1166,7 @@ class GatewayRunner:
removed_msg = history[last_user_idx].get("content", "") removed_msg = history[last_user_idx].get("content", "")
removed_count = len(history) - last_user_idx removed_count = len(history) - last_user_idx
session_entry.conversation_history = history[:last_user_idx] self.session_store.rewrite_transcript(session_entry.session_id, history[:last_user_idx])
preview = removed_msg[:40] + "..." if len(removed_msg) > 40 else removed_msg preview = removed_msg[:40] + "..." if len(removed_msg) > 40 else removed_msg
return f"↩️ Undid {removed_count} message(s).\nRemoved: \"{preview}\"" return f"↩️ Undid {removed_count} message(s).\nRemoved: \"{preview}\""

View file

@ -567,6 +567,37 @@ class SessionStore:
with open(transcript_path, "a") as f: with open(transcript_path, "a") as f:
f.write(json.dumps(message, ensure_ascii=False) + "\n") f.write(json.dumps(message, ensure_ascii=False) + "\n")
def rewrite_transcript(self, session_id: str, messages: List[Dict[str, Any]]) -> None:
"""Replace a session's transcript with the given messages."""
# Rewrite SQLite
if self._db:
try:
self._db._conn.execute(
"DELETE FROM messages WHERE session_id = ?", (session_id,)
)
self._db._conn.execute(
"UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?",
(session_id,),
)
self._db._conn.commit()
for msg in messages:
self._db.append_message(
session_id=session_id,
role=msg.get("role", "unknown"),
content=msg.get("content"),
tool_name=msg.get("tool_name"),
tool_calls=msg.get("tool_calls"),
tool_call_id=msg.get("tool_call_id"),
)
except Exception as e:
logger.debug("Session DB rewrite failed: %s", e)
# Rewrite legacy JSONL
transcript_path = self.get_transcript_path(session_id)
with open(transcript_path, "w") as f:
for msg in messages:
f.write(json.dumps(msg, ensure_ascii=False) + "\n")
def load_transcript(self, session_id: str) -> List[Dict[str, Any]]: def load_transcript(self, session_id: str) -> List[Dict[str, Any]]:
"""Load all messages from a session's transcript.""" """Load all messages from a session's transcript."""
# Try SQLite first # Try SQLite first

View file

@ -1,9 +1,13 @@
"""Tests for gateway session management.""" """Tests for gateway session management."""
import json
import pytest import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock
from gateway.config import Platform, HomeChannel, GatewayConfig, PlatformConfig from gateway.config import Platform, HomeChannel, GatewayConfig, PlatformConfig
from gateway.session import ( from gateway.session import (
SessionSource, SessionSource,
SessionStore,
build_session_context, build_session_context,
build_session_context_prompt, build_session_context_prompt,
) )
@ -199,3 +203,59 @@ class TestBuildSessionContextPrompt:
prompt = build_session_context_prompt(ctx) prompt = build_session_context_prompt(ctx)
assert "WhatsApp" in prompt or "whatsapp" in prompt.lower() assert "WhatsApp" in prompt or "whatsapp" in prompt.lower()
class TestSessionStoreRewriteTranscript:
"""Regression: /retry and /undo must persist truncated history to disk."""
@pytest.fixture()
def store(self, tmp_path):
config = GatewayConfig()
with patch("gateway.session.SessionStore._ensure_loaded"):
s = SessionStore(sessions_dir=tmp_path, config=config)
s._db = None # no SQLite for these tests
s._loaded = True
return s
def test_rewrite_replaces_jsonl(self, store, tmp_path):
session_id = "test_session_1"
# Write initial transcript
for msg in [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi"},
{"role": "user", "content": "undo this"},
{"role": "assistant", "content": "ok"},
]:
store.append_to_transcript(session_id, msg)
# Rewrite with truncated history
store.rewrite_transcript(session_id, [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi"},
])
reloaded = store.load_transcript(session_id)
assert len(reloaded) == 2
assert reloaded[0]["content"] == "hello"
assert reloaded[1]["content"] == "hi"
def test_rewrite_with_empty_list(self, store):
session_id = "test_session_2"
store.append_to_transcript(session_id, {"role": "user", "content": "hi"})
store.rewrite_transcript(session_id, [])
reloaded = store.load_transcript(session_id)
assert reloaded == []
class TestSessionStoreEntriesAttribute:
"""Regression: /reset must access _entries, not _sessions."""
def test_entries_attribute_exists(self):
config = GatewayConfig()
with patch("gateway.session.SessionStore._ensure_loaded"):
store = SessionStore(sessions_dir=Path("/tmp"), config=config)
store._loaded = True
assert hasattr(store, "_entries")
assert not hasattr(store, "_sessions")