fix(gateway): persist transcript changes in /retry, /undo and fix /reset
/retry and /undo set session_entry.conversation_history which does not exist on SessionEntry. The truncated history was never written to disk, so the next message reload picked up the full unmodified transcript. Added SessionStore.rewrite_transcript() that persists changes to both the JSONL file and SQLite database, and updated both commands to use it. /reset accessed self.session_store._sessions which does not exist on SessionStore (the correct attribute is _entries). Also replaced the hand-coded session key with _generate_session_key() to fix WhatsApp DM sessions using the wrong key format. Closes #210
This commit is contained in:
parent
6366177118
commit
b7f8a17c24
3 changed files with 95 additions and 5 deletions
|
|
@ -901,13 +901,12 @@ class GatewayRunner:
|
||||||
source = event.source
|
source = event.source
|
||||||
|
|
||||||
# Get existing session key
|
# Get existing session key
|
||||||
session_key = f"agent:main:{source.platform.value}:" + \
|
session_key = self.session_store._generate_session_key(source)
|
||||||
(f"dm" if source.chat_type == "dm" else f"{source.chat_type}:{source.chat_id}")
|
|
||||||
|
|
||||||
# Memory flush before reset: load the old transcript and let a
|
# Memory flush before reset: load the old transcript and let a
|
||||||
# temporary agent save memories before the session is wiped.
|
# temporary agent save memories before the session is wiped.
|
||||||
try:
|
try:
|
||||||
old_entry = self.session_store._sessions.get(session_key)
|
old_entry = self.session_store._entries.get(session_key)
|
||||||
if old_entry:
|
if old_entry:
|
||||||
old_history = self.session_store.load_transcript(old_entry.session_id)
|
old_history = self.session_store.load_transcript(old_entry.session_id)
|
||||||
if old_history:
|
if old_history:
|
||||||
|
|
@ -1135,7 +1134,7 @@ class GatewayRunner:
|
||||||
|
|
||||||
# Truncate history to before the last user message
|
# Truncate history to before the last user message
|
||||||
truncated = history[:last_user_idx]
|
truncated = history[:last_user_idx]
|
||||||
session_entry.conversation_history = truncated
|
self.session_store.rewrite_transcript(session_entry.session_id, truncated)
|
||||||
|
|
||||||
# Re-send by creating a fake text event with the old message
|
# Re-send by creating a fake text event with the old message
|
||||||
retry_event = MessageEvent(
|
retry_event = MessageEvent(
|
||||||
|
|
@ -1167,7 +1166,7 @@ class GatewayRunner:
|
||||||
|
|
||||||
removed_msg = history[last_user_idx].get("content", "")
|
removed_msg = history[last_user_idx].get("content", "")
|
||||||
removed_count = len(history) - last_user_idx
|
removed_count = len(history) - last_user_idx
|
||||||
session_entry.conversation_history = history[:last_user_idx]
|
self.session_store.rewrite_transcript(session_entry.session_id, history[:last_user_idx])
|
||||||
|
|
||||||
preview = removed_msg[:40] + "..." if len(removed_msg) > 40 else removed_msg
|
preview = removed_msg[:40] + "..." if len(removed_msg) > 40 else removed_msg
|
||||||
return f"↩️ Undid {removed_count} message(s).\nRemoved: \"{preview}\""
|
return f"↩️ Undid {removed_count} message(s).\nRemoved: \"{preview}\""
|
||||||
|
|
|
||||||
|
|
@ -567,6 +567,37 @@ class SessionStore:
|
||||||
with open(transcript_path, "a") as f:
|
with open(transcript_path, "a") as f:
|
||||||
f.write(json.dumps(message, ensure_ascii=False) + "\n")
|
f.write(json.dumps(message, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
|
def rewrite_transcript(self, session_id: str, messages: List[Dict[str, Any]]) -> None:
|
||||||
|
"""Replace a session's transcript with the given messages."""
|
||||||
|
# Rewrite SQLite
|
||||||
|
if self._db:
|
||||||
|
try:
|
||||||
|
self._db._conn.execute(
|
||||||
|
"DELETE FROM messages WHERE session_id = ?", (session_id,)
|
||||||
|
)
|
||||||
|
self._db._conn.execute(
|
||||||
|
"UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?",
|
||||||
|
(session_id,),
|
||||||
|
)
|
||||||
|
self._db._conn.commit()
|
||||||
|
for msg in messages:
|
||||||
|
self._db.append_message(
|
||||||
|
session_id=session_id,
|
||||||
|
role=msg.get("role", "unknown"),
|
||||||
|
content=msg.get("content"),
|
||||||
|
tool_name=msg.get("tool_name"),
|
||||||
|
tool_calls=msg.get("tool_calls"),
|
||||||
|
tool_call_id=msg.get("tool_call_id"),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Session DB rewrite failed: %s", e)
|
||||||
|
|
||||||
|
# Rewrite legacy JSONL
|
||||||
|
transcript_path = self.get_transcript_path(session_id)
|
||||||
|
with open(transcript_path, "w") as f:
|
||||||
|
for msg in messages:
|
||||||
|
f.write(json.dumps(msg, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
def load_transcript(self, session_id: str) -> List[Dict[str, Any]]:
|
def load_transcript(self, session_id: str) -> List[Dict[str, Any]]:
|
||||||
"""Load all messages from a session's transcript."""
|
"""Load all messages from a session's transcript."""
|
||||||
# Try SQLite first
|
# Try SQLite first
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,13 @@
|
||||||
"""Tests for gateway session management."""
|
"""Tests for gateway session management."""
|
||||||
|
|
||||||
|
import json
|
||||||
import pytest
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
from gateway.config import Platform, HomeChannel, GatewayConfig, PlatformConfig
|
from gateway.config import Platform, HomeChannel, GatewayConfig, PlatformConfig
|
||||||
from gateway.session import (
|
from gateway.session import (
|
||||||
SessionSource,
|
SessionSource,
|
||||||
|
SessionStore,
|
||||||
build_session_context,
|
build_session_context,
|
||||||
build_session_context_prompt,
|
build_session_context_prompt,
|
||||||
)
|
)
|
||||||
|
|
@ -199,3 +203,59 @@ class TestBuildSessionContextPrompt:
|
||||||
prompt = build_session_context_prompt(ctx)
|
prompt = build_session_context_prompt(ctx)
|
||||||
|
|
||||||
assert "WhatsApp" in prompt or "whatsapp" in prompt.lower()
|
assert "WhatsApp" in prompt or "whatsapp" in prompt.lower()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSessionStoreRewriteTranscript:
|
||||||
|
"""Regression: /retry and /undo must persist truncated history to disk."""
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def store(self, tmp_path):
|
||||||
|
config = GatewayConfig()
|
||||||
|
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||||
|
s = SessionStore(sessions_dir=tmp_path, config=config)
|
||||||
|
s._db = None # no SQLite for these tests
|
||||||
|
s._loaded = True
|
||||||
|
return s
|
||||||
|
|
||||||
|
def test_rewrite_replaces_jsonl(self, store, tmp_path):
|
||||||
|
session_id = "test_session_1"
|
||||||
|
# Write initial transcript
|
||||||
|
for msg in [
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
{"role": "assistant", "content": "hi"},
|
||||||
|
{"role": "user", "content": "undo this"},
|
||||||
|
{"role": "assistant", "content": "ok"},
|
||||||
|
]:
|
||||||
|
store.append_to_transcript(session_id, msg)
|
||||||
|
|
||||||
|
# Rewrite with truncated history
|
||||||
|
store.rewrite_transcript(session_id, [
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
{"role": "assistant", "content": "hi"},
|
||||||
|
])
|
||||||
|
|
||||||
|
reloaded = store.load_transcript(session_id)
|
||||||
|
assert len(reloaded) == 2
|
||||||
|
assert reloaded[0]["content"] == "hello"
|
||||||
|
assert reloaded[1]["content"] == "hi"
|
||||||
|
|
||||||
|
def test_rewrite_with_empty_list(self, store):
|
||||||
|
session_id = "test_session_2"
|
||||||
|
store.append_to_transcript(session_id, {"role": "user", "content": "hi"})
|
||||||
|
|
||||||
|
store.rewrite_transcript(session_id, [])
|
||||||
|
|
||||||
|
reloaded = store.load_transcript(session_id)
|
||||||
|
assert reloaded == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestSessionStoreEntriesAttribute:
|
||||||
|
"""Regression: /reset must access _entries, not _sessions."""
|
||||||
|
|
||||||
|
def test_entries_attribute_exists(self):
|
||||||
|
config = GatewayConfig()
|
||||||
|
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||||
|
store = SessionStore(sessions_dir=Path("/tmp"), config=config)
|
||||||
|
store._loaded = True
|
||||||
|
assert hasattr(store, "_entries")
|
||||||
|
assert not hasattr(store, "_sessions")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue