feat(gateway): proactive async memory flush on session expiry
Previously, when a session expired (idle/daily reset), the memory flush ran synchronously inside get_or_create_session — blocking the user's message for 10-60s while an LLM call saved memories. Now a background watcher task (_session_expiry_watcher) runs every 5 min, detects expired sessions, and flushes memories proactively in a thread pool. By the time the user sends their next message, memories are already saved and the response is immediate. Changes: - Add _is_session_expired(entry) to SessionStore — works from entry alone without needing a SessionSource - Add _pre_flushed_sessions set to track already-flushed sessions - Remove sync _on_auto_reset callback from get_or_create_session - Refactor flush into _flush_memories_for_session (sync worker) + _async_flush_memories (thread pool wrapper) - Add _session_expiry_watcher background task, started in start() - Simplify /reset command to use shared fire-and-forget flush - Add 10 tests for expiry detection, callback removal, tracking
This commit is contained in:
parent
e64d646bad
commit
d80c30cc92
3 changed files with 282 additions and 42 deletions
|
|
@ -178,7 +178,6 @@ class GatewayRunner:
|
||||||
self.session_store = SessionStore(
|
self.session_store = SessionStore(
|
||||||
self.config.sessions_dir, self.config,
|
self.config.sessions_dir, self.config,
|
||||||
has_active_processes_fn=lambda key: process_registry.has_active_for_session(key),
|
has_active_processes_fn=lambda key: process_registry.has_active_for_session(key),
|
||||||
on_auto_reset=self._flush_memories_before_reset,
|
|
||||||
)
|
)
|
||||||
self.delivery_router = DeliveryRouter(self.config)
|
self.delivery_router = DeliveryRouter(self.config)
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
@ -209,15 +208,14 @@ class GatewayRunner:
|
||||||
from gateway.hooks import HookRegistry
|
from gateway.hooks import HookRegistry
|
||||||
self.hooks = HookRegistry()
|
self.hooks = HookRegistry()
|
||||||
|
|
||||||
def _flush_memories_before_reset(self, old_entry):
|
def _flush_memories_for_session(self, old_session_id: str):
|
||||||
"""Prompt the agent to save memories/skills before an auto-reset.
|
"""Prompt the agent to save memories/skills before context is lost.
|
||||||
|
|
||||||
Called synchronously by SessionStore before destroying an expired session.
|
Synchronous worker — meant to be called via run_in_executor from
|
||||||
Loads the transcript, gives the agent a real turn with memory + skills
|
an async context so it doesn't block the event loop.
|
||||||
tools, and explicitly asks it to preserve anything worth keeping.
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
history = self.session_store.load_transcript(old_entry.session_id)
|
history = self.session_store.load_transcript(old_session_id)
|
||||||
if not history or len(history) < 4:
|
if not history or len(history) < 4:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -231,7 +229,7 @@ class GatewayRunner:
|
||||||
max_iterations=8,
|
max_iterations=8,
|
||||||
quiet_mode=True,
|
quiet_mode=True,
|
||||||
enabled_toolsets=["memory", "skills"],
|
enabled_toolsets=["memory", "skills"],
|
||||||
session_id=old_entry.session_id,
|
session_id=old_session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build conversation history from transcript
|
# Build conversation history from transcript
|
||||||
|
|
@ -260,9 +258,14 @@ class GatewayRunner:
|
||||||
user_message=flush_prompt,
|
user_message=flush_prompt,
|
||||||
conversation_history=msgs,
|
conversation_history=msgs,
|
||||||
)
|
)
|
||||||
logger.info("Pre-reset save completed for session %s", old_entry.session_id)
|
logger.info("Pre-reset memory flush completed for session %s", old_session_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Pre-reset save failed for session %s: %s", old_entry.session_id, e)
|
logger.debug("Pre-reset memory flush failed for session %s: %s", old_session_id, e)
|
||||||
|
|
||||||
|
async def _async_flush_memories(self, old_session_id: str):
|
||||||
|
"""Run the sync memory flush in a thread pool so it won't block the event loop."""
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
await loop.run_in_executor(None, self._flush_memories_for_session, old_session_id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_prefill_messages() -> List[Dict[str, Any]]:
|
def _load_prefill_messages() -> List[Dict[str, Any]]:
|
||||||
|
|
@ -464,10 +467,50 @@ class GatewayRunner:
|
||||||
# Check if we're restarting after a /update command
|
# Check if we're restarting after a /update command
|
||||||
await self._send_update_notification()
|
await self._send_update_notification()
|
||||||
|
|
||||||
|
# Start background session expiry watcher for proactive memory flushing
|
||||||
|
asyncio.create_task(self._session_expiry_watcher())
|
||||||
|
|
||||||
logger.info("Press Ctrl+C to stop")
|
logger.info("Press Ctrl+C to stop")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def _session_expiry_watcher(self, interval: int = 300):
|
||||||
|
"""Background task that proactively flushes memories for expired sessions.
|
||||||
|
|
||||||
|
Runs every `interval` seconds (default 5 min). For each session that
|
||||||
|
has expired according to its reset policy, flushes memories in a thread
|
||||||
|
pool and marks the session so it won't be flushed again.
|
||||||
|
|
||||||
|
This means memories are already saved by the time the user sends their
|
||||||
|
next message, so there's no blocking delay.
|
||||||
|
"""
|
||||||
|
await asyncio.sleep(60) # initial delay — let the gateway fully start
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
self.session_store._ensure_loaded()
|
||||||
|
for key, entry in list(self.session_store._entries.items()):
|
||||||
|
if entry.session_id in self.session_store._pre_flushed_sessions:
|
||||||
|
continue # already flushed this session
|
||||||
|
if not self.session_store._is_session_expired(entry):
|
||||||
|
continue # session still active
|
||||||
|
# Session has expired — flush memories in the background
|
||||||
|
logger.info(
|
||||||
|
"Session %s expired (key=%s), flushing memories proactively",
|
||||||
|
entry.session_id, key,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await self._async_flush_memories(entry.session_id)
|
||||||
|
self.session_store._pre_flushed_sessions.add(entry.session_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Proactive memory flush failed for %s: %s", entry.session_id, e)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Session expiry watcher error: %s", e)
|
||||||
|
# Sleep in small increments so we can stop quickly
|
||||||
|
for _ in range(interval):
|
||||||
|
if not self._running:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
"""Stop the gateway and disconnect all adapters."""
|
"""Stop the gateway and disconnect all adapters."""
|
||||||
logger.info("Stopping gateway...")
|
logger.info("Stopping gateway...")
|
||||||
|
|
@ -1012,33 +1055,12 @@ class GatewayRunner:
|
||||||
# Get existing session key
|
# Get existing session key
|
||||||
session_key = self.session_store._generate_session_key(source)
|
session_key = self.session_store._generate_session_key(source)
|
||||||
|
|
||||||
# Memory flush before reset: load the old transcript and let a
|
# Flush memories in the background (fire-and-forget) so the user
|
||||||
# temporary agent save memories before the session is wiped.
|
# gets the "Session reset!" response immediately.
|
||||||
try:
|
try:
|
||||||
old_entry = self.session_store._entries.get(session_key)
|
old_entry = self.session_store._entries.get(session_key)
|
||||||
if old_entry:
|
if old_entry:
|
||||||
old_history = self.session_store.load_transcript(old_entry.session_id)
|
asyncio.create_task(self._async_flush_memories(old_entry.session_id))
|
||||||
if old_history:
|
|
||||||
from run_agent import AIAgent
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
_flush_kwargs = _resolve_runtime_agent_kwargs()
|
|
||||||
def _do_flush():
|
|
||||||
tmp_agent = AIAgent(
|
|
||||||
**_flush_kwargs,
|
|
||||||
max_iterations=5,
|
|
||||||
quiet_mode=True,
|
|
||||||
enabled_toolsets=["memory"],
|
|
||||||
session_id=old_entry.session_id,
|
|
||||||
)
|
|
||||||
# Build simple message list from transcript
|
|
||||||
msgs = []
|
|
||||||
for m in old_history:
|
|
||||||
role = m.get("role")
|
|
||||||
content = m.get("content")
|
|
||||||
if role in ("user", "assistant") and content:
|
|
||||||
msgs.append({"role": role, "content": content})
|
|
||||||
tmp_agent.flush_memories(msgs)
|
|
||||||
await loop.run_in_executor(None, _do_flush)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Gateway memory flush on reset failed: %s", e)
|
logger.debug("Gateway memory flush on reset failed: %s", e)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -311,7 +311,9 @@ class SessionStore:
|
||||||
self._entries: Dict[str, SessionEntry] = {}
|
self._entries: Dict[str, SessionEntry] = {}
|
||||||
self._loaded = False
|
self._loaded = False
|
||||||
self._has_active_processes_fn = has_active_processes_fn
|
self._has_active_processes_fn = has_active_processes_fn
|
||||||
self._on_auto_reset = on_auto_reset # callback(old_entry) before auto-reset
|
# on_auto_reset is deprecated — memory flush now runs proactively
|
||||||
|
# via the background session expiry watcher in GatewayRunner.
|
||||||
|
self._pre_flushed_sessions: set = set() # session_ids already flushed by watcher
|
||||||
|
|
||||||
# Initialize SQLite session database
|
# Initialize SQLite session database
|
||||||
self._db = None
|
self._db = None
|
||||||
|
|
@ -353,6 +355,44 @@ class SessionStore:
|
||||||
"""Generate a session key from a source."""
|
"""Generate a session key from a source."""
|
||||||
return build_session_key(source)
|
return build_session_key(source)
|
||||||
|
|
||||||
|
def _is_session_expired(self, entry: SessionEntry) -> bool:
|
||||||
|
"""Check if a session has expired based on its reset policy.
|
||||||
|
|
||||||
|
Works from the entry alone — no SessionSource needed.
|
||||||
|
Used by the background expiry watcher to proactively flush memories.
|
||||||
|
Sessions with active background processes are never considered expired.
|
||||||
|
"""
|
||||||
|
if self._has_active_processes_fn:
|
||||||
|
if self._has_active_processes_fn(entry.session_key):
|
||||||
|
return False
|
||||||
|
|
||||||
|
policy = self.config.get_reset_policy(
|
||||||
|
platform=entry.platform,
|
||||||
|
session_type=entry.chat_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
if policy.mode == "none":
|
||||||
|
return False
|
||||||
|
|
||||||
|
now = datetime.now()
|
||||||
|
|
||||||
|
if policy.mode in ("idle", "both"):
|
||||||
|
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
|
||||||
|
if now > idle_deadline:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if policy.mode in ("daily", "both"):
|
||||||
|
today_reset = now.replace(
|
||||||
|
hour=policy.at_hour,
|
||||||
|
minute=0, second=0, microsecond=0,
|
||||||
|
)
|
||||||
|
if now.hour < policy.at_hour:
|
||||||
|
today_reset -= timedelta(days=1)
|
||||||
|
if entry.updated_at < today_reset:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
def _should_reset(self, entry: SessionEntry, source: SessionSource) -> bool:
|
def _should_reset(self, entry: SessionEntry, source: SessionSource) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if a session should be reset based on policy.
|
Check if a session should be reset based on policy.
|
||||||
|
|
@ -439,13 +479,11 @@ class SessionStore:
|
||||||
self._save()
|
self._save()
|
||||||
return entry
|
return entry
|
||||||
else:
|
else:
|
||||||
# Session is being auto-reset — flush memories before destroying
|
# Session is being auto-reset. The background expiry watcher
|
||||||
|
# should have already flushed memories proactively; discard
|
||||||
|
# the marker so it doesn't accumulate.
|
||||||
was_auto_reset = True
|
was_auto_reset = True
|
||||||
if self._on_auto_reset:
|
self._pre_flushed_sessions.discard(entry.session_id)
|
||||||
try:
|
|
||||||
self._on_auto_reset(entry)
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug("Auto-reset callback failed: %s", e)
|
|
||||||
if self._db:
|
if self._db:
|
||||||
try:
|
try:
|
||||||
self._db.end_session(entry.session_id, "session_reset")
|
self._db.end_session(entry.session_id, "session_reset")
|
||||||
|
|
|
||||||
180
tests/gateway/test_async_memory_flush.py
Normal file
180
tests/gateway/test_async_memory_flush.py
Normal file
|
|
@ -0,0 +1,180 @@
|
||||||
|
"""Tests for proactive memory flush on session expiry.
|
||||||
|
|
||||||
|
Verifies that:
|
||||||
|
1. _is_session_expired() works from a SessionEntry alone (no source needed)
|
||||||
|
2. The sync callback is no longer called in get_or_create_session
|
||||||
|
3. _pre_flushed_sessions tracking works correctly
|
||||||
|
4. The background watcher can detect expired sessions
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from gateway.config import Platform, GatewayConfig, SessionResetPolicy
|
||||||
|
from gateway.session import SessionSource, SessionStore, SessionEntry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def idle_store(tmp_path):
|
||||||
|
"""SessionStore with a 60-minute idle reset policy."""
|
||||||
|
config = GatewayConfig(
|
||||||
|
default_reset_policy=SessionResetPolicy(mode="idle", idle_minutes=60),
|
||||||
|
)
|
||||||
|
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||||
|
s = SessionStore(sessions_dir=tmp_path, config=config)
|
||||||
|
s._db = None
|
||||||
|
s._loaded = True
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def no_reset_store(tmp_path):
|
||||||
|
"""SessionStore with no reset policy (mode=none)."""
|
||||||
|
config = GatewayConfig(
|
||||||
|
default_reset_policy=SessionResetPolicy(mode="none"),
|
||||||
|
)
|
||||||
|
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||||
|
s = SessionStore(sessions_dir=tmp_path, config=config)
|
||||||
|
s._db = None
|
||||||
|
s._loaded = True
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsSessionExpired:
|
||||||
|
"""_is_session_expired should detect expiry from entry alone."""
|
||||||
|
|
||||||
|
def test_idle_session_expired(self, idle_store):
|
||||||
|
entry = SessionEntry(
|
||||||
|
session_key="agent:main:telegram:dm",
|
||||||
|
session_id="sid_1",
|
||||||
|
created_at=datetime.now() - timedelta(hours=3),
|
||||||
|
updated_at=datetime.now() - timedelta(minutes=120),
|
||||||
|
platform=Platform.TELEGRAM,
|
||||||
|
chat_type="dm",
|
||||||
|
)
|
||||||
|
assert idle_store._is_session_expired(entry) is True
|
||||||
|
|
||||||
|
def test_active_session_not_expired(self, idle_store):
|
||||||
|
entry = SessionEntry(
|
||||||
|
session_key="agent:main:telegram:dm",
|
||||||
|
session_id="sid_2",
|
||||||
|
created_at=datetime.now() - timedelta(hours=1),
|
||||||
|
updated_at=datetime.now() - timedelta(minutes=10),
|
||||||
|
platform=Platform.TELEGRAM,
|
||||||
|
chat_type="dm",
|
||||||
|
)
|
||||||
|
assert idle_store._is_session_expired(entry) is False
|
||||||
|
|
||||||
|
def test_none_mode_never_expires(self, no_reset_store):
|
||||||
|
entry = SessionEntry(
|
||||||
|
session_key="agent:main:telegram:dm",
|
||||||
|
session_id="sid_3",
|
||||||
|
created_at=datetime.now() - timedelta(days=30),
|
||||||
|
updated_at=datetime.now() - timedelta(days=30),
|
||||||
|
platform=Platform.TELEGRAM,
|
||||||
|
chat_type="dm",
|
||||||
|
)
|
||||||
|
assert no_reset_store._is_session_expired(entry) is False
|
||||||
|
|
||||||
|
def test_active_processes_prevent_expiry(self, idle_store):
|
||||||
|
"""Sessions with active background processes should never expire."""
|
||||||
|
idle_store._has_active_processes_fn = lambda key: True
|
||||||
|
entry = SessionEntry(
|
||||||
|
session_key="agent:main:telegram:dm",
|
||||||
|
session_id="sid_4",
|
||||||
|
created_at=datetime.now() - timedelta(hours=5),
|
||||||
|
updated_at=datetime.now() - timedelta(hours=5),
|
||||||
|
platform=Platform.TELEGRAM,
|
||||||
|
chat_type="dm",
|
||||||
|
)
|
||||||
|
assert idle_store._is_session_expired(entry) is False
|
||||||
|
|
||||||
|
def test_daily_mode_expired(self, tmp_path):
|
||||||
|
"""Daily mode should expire sessions from before today's reset hour."""
|
||||||
|
config = GatewayConfig(
|
||||||
|
default_reset_policy=SessionResetPolicy(mode="daily", at_hour=4),
|
||||||
|
)
|
||||||
|
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||||
|
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||||
|
store._db = None
|
||||||
|
store._loaded = True
|
||||||
|
|
||||||
|
entry = SessionEntry(
|
||||||
|
session_key="agent:main:telegram:dm",
|
||||||
|
session_id="sid_5",
|
||||||
|
created_at=datetime.now() - timedelta(days=2),
|
||||||
|
updated_at=datetime.now() - timedelta(days=2),
|
||||||
|
platform=Platform.TELEGRAM,
|
||||||
|
chat_type="dm",
|
||||||
|
)
|
||||||
|
assert store._is_session_expired(entry) is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetOrCreateSessionNoCallback:
|
||||||
|
"""get_or_create_session should NOT call a sync flush callback."""
|
||||||
|
|
||||||
|
def test_auto_reset_cleans_pre_flushed_marker(self, idle_store):
|
||||||
|
"""When a session auto-resets, the pre_flushed marker should be discarded."""
|
||||||
|
source = SessionSource(
|
||||||
|
platform=Platform.TELEGRAM,
|
||||||
|
chat_id="123",
|
||||||
|
chat_type="dm",
|
||||||
|
)
|
||||||
|
# Create initial session
|
||||||
|
entry1 = idle_store.get_or_create_session(source)
|
||||||
|
old_sid = entry1.session_id
|
||||||
|
|
||||||
|
# Simulate the watcher having flushed it
|
||||||
|
idle_store._pre_flushed_sessions.add(old_sid)
|
||||||
|
|
||||||
|
# Simulate the session going idle
|
||||||
|
entry1.updated_at = datetime.now() - timedelta(minutes=120)
|
||||||
|
idle_store._save()
|
||||||
|
|
||||||
|
# Next call should auto-reset
|
||||||
|
entry2 = idle_store.get_or_create_session(source)
|
||||||
|
assert entry2.session_id != old_sid
|
||||||
|
assert entry2.was_auto_reset is True
|
||||||
|
|
||||||
|
# The old session_id should be removed from pre_flushed
|
||||||
|
assert old_sid not in idle_store._pre_flushed_sessions
|
||||||
|
|
||||||
|
def test_no_sync_callback_invoked(self, idle_store):
|
||||||
|
"""No synchronous callback should block during auto-reset."""
|
||||||
|
source = SessionSource(
|
||||||
|
platform=Platform.TELEGRAM,
|
||||||
|
chat_id="123",
|
||||||
|
chat_type="dm",
|
||||||
|
)
|
||||||
|
entry1 = idle_store.get_or_create_session(source)
|
||||||
|
entry1.updated_at = datetime.now() - timedelta(minutes=120)
|
||||||
|
idle_store._save()
|
||||||
|
|
||||||
|
# Verify no _on_auto_reset attribute
|
||||||
|
assert not hasattr(idle_store, '_on_auto_reset')
|
||||||
|
|
||||||
|
# This should NOT block (no sync LLM call)
|
||||||
|
entry2 = idle_store.get_or_create_session(source)
|
||||||
|
assert entry2.was_auto_reset is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestPreFlushedSessionsTracking:
|
||||||
|
"""The _pre_flushed_sessions set should prevent double-flushing."""
|
||||||
|
|
||||||
|
def test_starts_empty(self, idle_store):
|
||||||
|
assert len(idle_store._pre_flushed_sessions) == 0
|
||||||
|
|
||||||
|
def test_add_and_check(self, idle_store):
|
||||||
|
idle_store._pre_flushed_sessions.add("sid_old")
|
||||||
|
assert "sid_old" in idle_store._pre_flushed_sessions
|
||||||
|
assert "sid_other" not in idle_store._pre_flushed_sessions
|
||||||
|
|
||||||
|
def test_discard_on_reset(self, idle_store):
|
||||||
|
"""discard should remove without raising if not present."""
|
||||||
|
idle_store._pre_flushed_sessions.add("sid_a")
|
||||||
|
idle_store._pre_flushed_sessions.discard("sid_a")
|
||||||
|
assert "sid_a" not in idle_store._pre_flushed_sessions
|
||||||
|
# discard on non-existent should not raise
|
||||||
|
idle_store._pre_flushed_sessions.discard("sid_nonexistent")
|
||||||
Loading…
Add table
Add a link
Reference in a new issue