fix(discord): persist thread participation across gateway restarts

_bot_participated_threads was an in-memory set — lost on every restart.
After restart, the bot forgot which threads it was active in, requiring
fresh @mentions and potentially creating duplicate threads instead of
continuing existing conversations.

Changes:
- Persist thread IDs to ~/.hermes/discord_threads.json
- Load on adapter init, save on every new thread participation
- _track_thread() replaces direct .add() calls for atomic persist
- Cap at 500 tracked threads to prevent unbounded growth
- /thread slash command also tracks participation
- 7 new tests covering persistence, restart survival, corruption
  recovery, cap enforcement
This commit is contained in:
teknium1 2026-03-17 02:26:34 -07:00
parent 0351e4fa90
commit c8582fc4a2
2 changed files with 139 additions and 4 deletions

View file

@ -10,6 +10,7 @@ Uses discord.py library for:
""" """
import asyncio import asyncio
import json
import logging import logging
import os import os
import struct import struct
@ -18,6 +19,7 @@ import tempfile
import threading import threading
import time import time
from collections import defaultdict from collections import defaultdict
from pathlib import Path
from typing import Callable, Dict, List, Optional, Any from typing import Callable, Dict, List, Optional, Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -434,8 +436,11 @@ class DiscordAdapter(BasePlatformAdapter):
self._voice_input_callback: Optional[Callable] = None # set by run.py self._voice_input_callback: Optional[Callable] = None # set by run.py
self._on_voice_disconnect: Optional[Callable] = None # set by run.py self._on_voice_disconnect: Optional[Callable] = None # set by run.py
# Track threads where the bot has participated so follow-up messages # Track threads where the bot has participated so follow-up messages
# in those threads don't require @mention. # in those threads don't require @mention. Persisted to disk so the
self._bot_participated_threads: set = set() # set survives gateway restarts.
self._bot_participated_threads: set = self._load_participated_threads()
# Cap to prevent unbounded growth (Discord threads get archived).
self._MAX_TRACKED_THREADS = 500
async def connect(self) -> bool: async def connect(self) -> bool:
"""Connect to Discord and start receiving events.""" """Connect to Discord and start receiving events."""
@ -1573,6 +1578,10 @@ class DiscordAdapter(BasePlatformAdapter):
link = f"<#{thread_id}>" if thread_id else f"**{thread_name}**" link = f"<#{thread_id}>" if thread_id else f"**{thread_name}**"
await interaction.followup.send(f"Created thread {link}", ephemeral=True) await interaction.followup.send(f"Created thread {link}", ephemeral=True)
# Track thread participation so follow-ups don't require @mention
if thread_id:
self._track_thread(thread_id)
# If a message was provided, kick off a new Hermes session in the thread # If a message was provided, kick off a new Hermes session in the thread
starter = (message or "").strip() starter = (message or "").strip()
if starter and thread_id: if starter and thread_id:
@ -1798,6 +1807,49 @@ class DiscordAdapter(BasePlatformAdapter):
return f"{parent_name} / {thread_name}" return f"{parent_name} / {thread_name}"
return thread_name return thread_name
# ------------------------------------------------------------------
# Thread participation persistence
# ------------------------------------------------------------------
@staticmethod
def _thread_state_path() -> Path:
"""Path to the persisted thread participation set."""
from hermes_cli.config import get_hermes_home
return get_hermes_home() / "discord_threads.json"
@classmethod
def _load_participated_threads(cls) -> set:
"""Load persisted thread IDs from disk."""
path = cls._thread_state_path()
try:
if path.exists():
data = json.loads(path.read_text(encoding="utf-8"))
if isinstance(data, list):
return set(data)
except Exception as e:
logger.debug("Could not load discord thread state: %s", e)
return set()
def _save_participated_threads(self) -> None:
"""Persist the current thread set to disk (best-effort)."""
path = self._thread_state_path()
try:
# Trim to most recent entries if over cap
thread_list = list(self._bot_participated_threads)
if len(thread_list) > self._MAX_TRACKED_THREADS:
thread_list = thread_list[-self._MAX_TRACKED_THREADS:]
self._bot_participated_threads = set(thread_list)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(thread_list), encoding="utf-8")
except Exception as e:
logger.debug("Could not save discord thread state: %s", e)
def _track_thread(self, thread_id: str) -> None:
"""Add a thread to the participation set and persist."""
if thread_id not in self._bot_participated_threads:
self._bot_participated_threads.add(thread_id)
self._save_participated_threads()
async def _handle_message(self, message: DiscordMessage) -> None: async def _handle_message(self, message: DiscordMessage) -> None:
"""Handle incoming Discord messages.""" """Handle incoming Discord messages."""
# In server channels (not DMs), require the bot to be @mentioned # In server channels (not DMs), require the bot to be @mentioned
@ -1850,7 +1902,7 @@ class DiscordAdapter(BasePlatformAdapter):
is_thread = True is_thread = True
thread_id = str(thread.id) thread_id = str(thread.id)
auto_threaded_channel = thread auto_threaded_channel = thread
self._bot_participated_threads.add(thread_id) self._track_thread(thread_id)
# Determine message type # Determine message type
msg_type = MessageType.TEXT msg_type = MessageType.TEXT
@ -1954,7 +2006,7 @@ class DiscordAdapter(BasePlatformAdapter):
# Track thread participation so the bot won't require @mention for # Track thread participation so the bot won't require @mention for
# follow-up messages in threads it has already engaged in. # follow-up messages in threads it has already engaged in.
if thread_id: if thread_id:
self._bot_participated_threads.add(thread_id) self._track_thread(thread_id)
await self.handle_message(event) await self.handle_message(event)

View file

@ -0,0 +1,83 @@
"""Tests for Discord thread participation persistence.
Verifies that _bot_participated_threads survives adapter restarts by
being persisted to ~/.hermes/discord_threads.json.
"""
import json
import os
from unittest.mock import patch
import pytest
class TestDiscordThreadPersistence:
"""Thread IDs are saved to disk and reloaded on init."""
def _make_adapter(self, tmp_path):
"""Build a minimal DiscordAdapter with HERMES_HOME pointed at tmp_path."""
from gateway.config import PlatformConfig
from gateway.platforms.discord import DiscordAdapter
config = PlatformConfig(enabled=True, token="test-token")
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
return DiscordAdapter(config=config)
def test_starts_empty_when_no_state_file(self, tmp_path):
adapter = self._make_adapter(tmp_path)
assert adapter._bot_participated_threads == set()
def test_track_thread_persists_to_disk(self, tmp_path):
adapter = self._make_adapter(tmp_path)
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
adapter._track_thread("111")
adapter._track_thread("222")
state_file = tmp_path / "discord_threads.json"
assert state_file.exists()
saved = json.loads(state_file.read_text())
assert set(saved) == {"111", "222"}
def test_threads_survive_restart(self, tmp_path):
"""Threads tracked by one adapter instance are visible to the next."""
adapter1 = self._make_adapter(tmp_path)
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
adapter1._track_thread("aaa")
adapter1._track_thread("bbb")
adapter2 = self._make_adapter(tmp_path)
assert "aaa" in adapter2._bot_participated_threads
assert "bbb" in adapter2._bot_participated_threads
def test_duplicate_track_does_not_double_save(self, tmp_path):
adapter = self._make_adapter(tmp_path)
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
adapter._track_thread("111")
adapter._track_thread("111") # no-op
saved = json.loads((tmp_path / "discord_threads.json").read_text())
assert saved.count("111") == 1
def test_caps_at_max_tracked_threads(self, tmp_path):
adapter = self._make_adapter(tmp_path)
adapter._MAX_TRACKED_THREADS = 5
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
for i in range(10):
adapter._track_thread(str(i))
assert len(adapter._bot_participated_threads) == 5
def test_corrupted_state_file_falls_back_to_empty(self, tmp_path):
state_file = tmp_path / "discord_threads.json"
state_file.write_text("not valid json{{{")
adapter = self._make_adapter(tmp_path)
assert adapter._bot_participated_threads == set()
def test_missing_hermes_home_does_not_crash(self, tmp_path):
"""Load/save tolerate missing directories."""
fake_home = tmp_path / "nonexistent" / "deep"
with patch.dict(os.environ, {"HERMES_HOME": str(fake_home)}):
from gateway.platforms.discord import DiscordAdapter
# _load should return empty set, not crash
threads = DiscordAdapter._load_participated_threads()
assert threads == set()