fix(discord): persist thread participation across gateway restarts
_bot_participated_threads was an in-memory set — lost on every restart. After restart, the bot forgot which threads it was active in, requiring fresh @mentions and potentially creating duplicate threads instead of continuing existing conversations. Changes: - Persist thread IDs to ~/.hermes/discord_threads.json - Load on adapter init, save on every new thread participation - _track_thread() replaces direct .add() calls for atomic persist - Cap at 500 tracked threads to prevent unbounded growth - /thread slash command also tracks participation - 7 new tests covering persistence, restart survival, corruption recovery, cap enforcement
This commit is contained in:
parent
0351e4fa90
commit
c8582fc4a2
2 changed files with 139 additions and 4 deletions
|
|
@ -10,6 +10,7 @@ Uses discord.py library for:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
|
|
@ -18,6 +19,7 @@ import tempfile
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, List, Optional, Any
|
from typing import Callable, Dict, List, Optional, Any
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -434,8 +436,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||||
self._voice_input_callback: Optional[Callable] = None # set by run.py
|
self._voice_input_callback: Optional[Callable] = None # set by run.py
|
||||||
self._on_voice_disconnect: Optional[Callable] = None # set by run.py
|
self._on_voice_disconnect: Optional[Callable] = None # set by run.py
|
||||||
# Track threads where the bot has participated so follow-up messages
|
# Track threads where the bot has participated so follow-up messages
|
||||||
# in those threads don't require @mention.
|
# in those threads don't require @mention. Persisted to disk so the
|
||||||
self._bot_participated_threads: set = set()
|
# set survives gateway restarts.
|
||||||
|
self._bot_participated_threads: set = self._load_participated_threads()
|
||||||
|
# Cap to prevent unbounded growth (Discord threads get archived).
|
||||||
|
self._MAX_TRACKED_THREADS = 500
|
||||||
|
|
||||||
async def connect(self) -> bool:
|
async def connect(self) -> bool:
|
||||||
"""Connect to Discord and start receiving events."""
|
"""Connect to Discord and start receiving events."""
|
||||||
|
|
@ -1573,6 +1578,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||||
link = f"<#{thread_id}>" if thread_id else f"**{thread_name}**"
|
link = f"<#{thread_id}>" if thread_id else f"**{thread_name}**"
|
||||||
await interaction.followup.send(f"Created thread {link}", ephemeral=True)
|
await interaction.followup.send(f"Created thread {link}", ephemeral=True)
|
||||||
|
|
||||||
|
# Track thread participation so follow-ups don't require @mention
|
||||||
|
if thread_id:
|
||||||
|
self._track_thread(thread_id)
|
||||||
|
|
||||||
# If a message was provided, kick off a new Hermes session in the thread
|
# If a message was provided, kick off a new Hermes session in the thread
|
||||||
starter = (message or "").strip()
|
starter = (message or "").strip()
|
||||||
if starter and thread_id:
|
if starter and thread_id:
|
||||||
|
|
@ -1798,6 +1807,49 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||||
return f"{parent_name} / {thread_name}"
|
return f"{parent_name} / {thread_name}"
|
||||||
return thread_name
|
return thread_name
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Thread participation persistence
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _thread_state_path() -> Path:
|
||||||
|
"""Path to the persisted thread participation set."""
|
||||||
|
from hermes_cli.config import get_hermes_home
|
||||||
|
return get_hermes_home() / "discord_threads.json"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _load_participated_threads(cls) -> set:
|
||||||
|
"""Load persisted thread IDs from disk."""
|
||||||
|
path = cls._thread_state_path()
|
||||||
|
try:
|
||||||
|
if path.exists():
|
||||||
|
data = json.loads(path.read_text(encoding="utf-8"))
|
||||||
|
if isinstance(data, list):
|
||||||
|
return set(data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Could not load discord thread state: %s", e)
|
||||||
|
return set()
|
||||||
|
|
||||||
|
def _save_participated_threads(self) -> None:
|
||||||
|
"""Persist the current thread set to disk (best-effort)."""
|
||||||
|
path = self._thread_state_path()
|
||||||
|
try:
|
||||||
|
# Trim to most recent entries if over cap
|
||||||
|
thread_list = list(self._bot_participated_threads)
|
||||||
|
if len(thread_list) > self._MAX_TRACKED_THREADS:
|
||||||
|
thread_list = thread_list[-self._MAX_TRACKED_THREADS:]
|
||||||
|
self._bot_participated_threads = set(thread_list)
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
path.write_text(json.dumps(thread_list), encoding="utf-8")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Could not save discord thread state: %s", e)
|
||||||
|
|
||||||
|
def _track_thread(self, thread_id: str) -> None:
|
||||||
|
"""Add a thread to the participation set and persist."""
|
||||||
|
if thread_id not in self._bot_participated_threads:
|
||||||
|
self._bot_participated_threads.add(thread_id)
|
||||||
|
self._save_participated_threads()
|
||||||
|
|
||||||
async def _handle_message(self, message: DiscordMessage) -> None:
|
async def _handle_message(self, message: DiscordMessage) -> None:
|
||||||
"""Handle incoming Discord messages."""
|
"""Handle incoming Discord messages."""
|
||||||
# In server channels (not DMs), require the bot to be @mentioned
|
# In server channels (not DMs), require the bot to be @mentioned
|
||||||
|
|
@ -1850,7 +1902,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||||
is_thread = True
|
is_thread = True
|
||||||
thread_id = str(thread.id)
|
thread_id = str(thread.id)
|
||||||
auto_threaded_channel = thread
|
auto_threaded_channel = thread
|
||||||
self._bot_participated_threads.add(thread_id)
|
self._track_thread(thread_id)
|
||||||
|
|
||||||
# Determine message type
|
# Determine message type
|
||||||
msg_type = MessageType.TEXT
|
msg_type = MessageType.TEXT
|
||||||
|
|
@ -1954,7 +2006,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||||
# Track thread participation so the bot won't require @mention for
|
# Track thread participation so the bot won't require @mention for
|
||||||
# follow-up messages in threads it has already engaged in.
|
# follow-up messages in threads it has already engaged in.
|
||||||
if thread_id:
|
if thread_id:
|
||||||
self._bot_participated_threads.add(thread_id)
|
self._track_thread(thread_id)
|
||||||
|
|
||||||
await self.handle_message(event)
|
await self.handle_message(event)
|
||||||
|
|
||||||
|
|
|
||||||
83
tests/gateway/test_discord_thread_persistence.py
Normal file
83
tests/gateway/test_discord_thread_persistence.py
Normal file
|
|
@ -0,0 +1,83 @@
|
||||||
|
"""Tests for Discord thread participation persistence.
|
||||||
|
|
||||||
|
Verifies that _bot_participated_threads survives adapter restarts by
|
||||||
|
being persisted to ~/.hermes/discord_threads.json.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class TestDiscordThreadPersistence:
|
||||||
|
"""Thread IDs are saved to disk and reloaded on init."""
|
||||||
|
|
||||||
|
def _make_adapter(self, tmp_path):
|
||||||
|
"""Build a minimal DiscordAdapter with HERMES_HOME pointed at tmp_path."""
|
||||||
|
from gateway.config import PlatformConfig
|
||||||
|
from gateway.platforms.discord import DiscordAdapter
|
||||||
|
|
||||||
|
config = PlatformConfig(enabled=True, token="test-token")
|
||||||
|
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||||
|
return DiscordAdapter(config=config)
|
||||||
|
|
||||||
|
def test_starts_empty_when_no_state_file(self, tmp_path):
|
||||||
|
adapter = self._make_adapter(tmp_path)
|
||||||
|
assert adapter._bot_participated_threads == set()
|
||||||
|
|
||||||
|
def test_track_thread_persists_to_disk(self, tmp_path):
|
||||||
|
adapter = self._make_adapter(tmp_path)
|
||||||
|
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||||
|
adapter._track_thread("111")
|
||||||
|
adapter._track_thread("222")
|
||||||
|
|
||||||
|
state_file = tmp_path / "discord_threads.json"
|
||||||
|
assert state_file.exists()
|
||||||
|
saved = json.loads(state_file.read_text())
|
||||||
|
assert set(saved) == {"111", "222"}
|
||||||
|
|
||||||
|
def test_threads_survive_restart(self, tmp_path):
|
||||||
|
"""Threads tracked by one adapter instance are visible to the next."""
|
||||||
|
adapter1 = self._make_adapter(tmp_path)
|
||||||
|
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||||
|
adapter1._track_thread("aaa")
|
||||||
|
adapter1._track_thread("bbb")
|
||||||
|
|
||||||
|
adapter2 = self._make_adapter(tmp_path)
|
||||||
|
assert "aaa" in adapter2._bot_participated_threads
|
||||||
|
assert "bbb" in adapter2._bot_participated_threads
|
||||||
|
|
||||||
|
def test_duplicate_track_does_not_double_save(self, tmp_path):
|
||||||
|
adapter = self._make_adapter(tmp_path)
|
||||||
|
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||||
|
adapter._track_thread("111")
|
||||||
|
adapter._track_thread("111") # no-op
|
||||||
|
|
||||||
|
saved = json.loads((tmp_path / "discord_threads.json").read_text())
|
||||||
|
assert saved.count("111") == 1
|
||||||
|
|
||||||
|
def test_caps_at_max_tracked_threads(self, tmp_path):
|
||||||
|
adapter = self._make_adapter(tmp_path)
|
||||||
|
adapter._MAX_TRACKED_THREADS = 5
|
||||||
|
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||||
|
for i in range(10):
|
||||||
|
adapter._track_thread(str(i))
|
||||||
|
|
||||||
|
assert len(adapter._bot_participated_threads) == 5
|
||||||
|
|
||||||
|
def test_corrupted_state_file_falls_back_to_empty(self, tmp_path):
|
||||||
|
state_file = tmp_path / "discord_threads.json"
|
||||||
|
state_file.write_text("not valid json{{{")
|
||||||
|
adapter = self._make_adapter(tmp_path)
|
||||||
|
assert adapter._bot_participated_threads == set()
|
||||||
|
|
||||||
|
def test_missing_hermes_home_does_not_crash(self, tmp_path):
|
||||||
|
"""Load/save tolerate missing directories."""
|
||||||
|
fake_home = tmp_path / "nonexistent" / "deep"
|
||||||
|
with patch.dict(os.environ, {"HERMES_HOME": str(fake_home)}):
|
||||||
|
from gateway.platforms.discord import DiscordAdapter
|
||||||
|
# _load should return empty set, not crash
|
||||||
|
threads = DiscordAdapter._load_participated_threads()
|
||||||
|
assert threads == set()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue