diff --git a/gateway/config.py b/gateway/config.py index 5d3dfa9f..d325abcd 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -304,6 +304,8 @@ def load_gateway_config() -> GatewayConfig: if isinstance(frc, list): frc = ",".join(str(v) for v in frc) os.environ["DISCORD_FREE_RESPONSE_CHANNELS"] = str(frc) + if "auto_thread" in discord_cfg and not os.getenv("DISCORD_AUTO_THREAD"): + os.environ["DISCORD_AUTO_THREAD"] = str(discord_cfg["auto_thread"]).lower() except Exception: pass diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index c7ae2ada..257756ad 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -14,6 +14,8 @@ from typing import Dict, List, Optional, Any logger = logging.getLogger(__name__) +VALID_THREAD_AUTO_ARCHIVE_MINUTES = {60, 1440, 4320, 10080} + try: import discord from discord import Message as DiscordMessage, Intents @@ -251,6 +253,7 @@ class DiscordAdapter(BasePlatformAdapter): audio_path: str, caption: Optional[str] = None, reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, ) -> SendResult: """Send audio as a Discord file attachment.""" if not self._client: @@ -289,6 +292,7 @@ class DiscordAdapter(BasePlatformAdapter): image_path: str, caption: Optional[str] = None, reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, ) -> SendResult: """Send a local image file natively as a Discord file attachment.""" if not self._client: @@ -326,6 +330,7 @@ class DiscordAdapter(BasePlatformAdapter): image_url: str, caption: Optional[str] = None, reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, ) -> SendResult: """Send an image natively as a Discord file attachment.""" if not self._client: @@ -711,6 +716,21 @@ class DiscordAdapter(BasePlatformAdapter): except Exception as e: logger.debug("Discord followup failed: %s", e) + @tree.command(name="thread", description="Create a new thread and start a Hermes session in it") + @discord.app_commands.describe( + name="Thread name", + message="Optional first message to send to Hermes in the thread", + auto_archive_duration="Auto-archive in minutes (60, 1440, 4320, 10080)", + ) + async def slash_thread( + interaction: discord.Interaction, + name: str, + message: str = "", + auto_archive_duration: int = 1440, + ): + await interaction.response.defer(ephemeral=True) + await self._handle_thread_create_slash(interaction, name, message, auto_archive_duration) + def _build_slash_event(self, interaction: discord.Interaction, text: str) -> MessageEvent: """Build a MessageEvent from a Discord slash command interaction.""" is_dm = isinstance(interaction.channel, discord.DMChannel) @@ -741,6 +761,188 @@ class DiscordAdapter(BasePlatformAdapter): raw_message=interaction, ) + # ------------------------------------------------------------------ + # Thread creation helpers + # ------------------------------------------------------------------ + + async def _handle_thread_create_slash( + self, + interaction: discord.Interaction, + name: str, + message: str = "", + auto_archive_duration: int = 1440, + ) -> None: + """Create a Discord thread from a slash command and start a session in it.""" + result = await self._create_thread( + interaction, + name=name, + message=message, + auto_archive_duration=auto_archive_duration, + ) + + if not result.get("success"): + error = result.get("error", "unknown error") + await interaction.followup.send(f"Failed to create thread: {error}", ephemeral=True) + return + + thread_id = result.get("thread_id") + thread_name = result.get("thread_name") or name + + # Tell the user where the thread is + link = f"<#{thread_id}>" if thread_id else f"**{thread_name}**" + await interaction.followup.send(f"Created thread {link}", ephemeral=True) + + # If a message was provided, kick off a new Hermes session in the thread + starter = (message or "").strip() + if starter and thread_id: + await self._dispatch_thread_session(interaction, thread_id, thread_name, starter) + + async def _dispatch_thread_session( + self, + interaction: discord.Interaction, + thread_id: str, + thread_name: str, + text: str, + ) -> None: + """Build a MessageEvent pointing at a thread and send it through handle_message.""" + guild_name = "" + if hasattr(interaction, "guild") and interaction.guild: + guild_name = interaction.guild.name + + chat_name = f"{guild_name} / {thread_name}" if guild_name else thread_name + + source = self.build_source( + chat_id=thread_id, + chat_name=chat_name, + chat_type="thread", + user_id=str(interaction.user.id), + user_name=interaction.user.display_name, + thread_id=thread_id, + ) + + event = MessageEvent( + text=text, + message_type=MessageType.TEXT, + source=source, + raw_message=interaction, + ) + await self.handle_message(event) + + def _thread_parent_channel(self, channel: Any) -> Any: + """Return the parent text channel when invoked from a thread.""" + return getattr(channel, "parent", None) or channel + + async def _resolve_interaction_channel(self, interaction: discord.Interaction) -> Optional[Any]: + """Return the interaction channel, fetching it if the payload is partial.""" + channel = getattr(interaction, "channel", None) + if channel is not None: + return channel + if not self._client: + return None + channel_id = getattr(interaction, "channel_id", None) + if channel_id is None: + return None + channel = self._client.get_channel(int(channel_id)) + if channel is not None: + return channel + try: + return await self._client.fetch_channel(int(channel_id)) + except Exception: + return None + + async def _create_thread( + self, + interaction: discord.Interaction, + *, + name: str, + message: str = "", + auto_archive_duration: int = 1440, + ) -> Dict[str, Any]: + """Create a thread in the current Discord channel. + + Tries ``parent_channel.create_thread()`` first. If Discord rejects + that (e.g. permission issues), falls back to sending a seed message + and creating the thread from it. + """ + name = (name or "").strip() + if not name: + return {"error": "Thread name is required."} + + if auto_archive_duration not in VALID_THREAD_AUTO_ARCHIVE_MINUTES: + allowed = ", ".join(str(v) for v in sorted(VALID_THREAD_AUTO_ARCHIVE_MINUTES)) + return {"error": f"auto_archive_duration must be one of: {allowed}."} + + channel = await self._resolve_interaction_channel(interaction) + if channel is None: + return {"error": "Could not resolve the current Discord channel."} + if isinstance(channel, discord.DMChannel): + return {"error": "Discord threads can only be created inside server text channels, not DMs."} + + parent_channel = self._thread_parent_channel(channel) + if parent_channel is None: + return {"error": "Could not determine a parent text channel for the new thread."} + + display_name = getattr(getattr(interaction, "user", None), "display_name", None) or "unknown user" + reason = f"Requested by {display_name} via /thread" + starter_message = (message or "").strip() + + try: + thread = await parent_channel.create_thread( + name=name, + auto_archive_duration=auto_archive_duration, + reason=reason, + ) + if starter_message: + await thread.send(starter_message) + return { + "success": True, + "thread_id": str(thread.id), + "thread_name": getattr(thread, "name", None) or name, + } + except Exception as direct_error: + try: + seed_content = starter_message or f"\U0001f9f5 Thread created by Hermes: **{name}**" + seed_msg = await parent_channel.send(seed_content) + thread = await seed_msg.create_thread( + name=name, + auto_archive_duration=auto_archive_duration, + reason=reason, + ) + return { + "success": True, + "thread_id": str(thread.id), + "thread_name": getattr(thread, "name", None) or name, + } + except Exception as fallback_error: + return { + "error": ( + "Discord rejected direct thread creation and the fallback also failed. " + f"Direct error: {direct_error}. Fallback error: {fallback_error}" + ) + } + + # ------------------------------------------------------------------ + # Auto-thread helpers + # ------------------------------------------------------------------ + + async def _auto_create_thread(self, message: 'DiscordMessage') -> Optional[Any]: + """Create a thread from a user message for auto-threading. + + Returns the created thread object, or ``None`` on failure. + """ + # Build a short thread name from the message + content = (message.content or "").strip() + thread_name = content[:80] if content else "Hermes" + if len(content) > 80: + thread_name = thread_name[:77] + "..." + + try: + thread = await message.create_thread(name=thread_name, auto_archive_duration=1440) + return thread + except Exception as e: + logger.warning("[%s] Auto-thread creation failed: %s", self.name, e) + return None + async def send_exec_approval( self, chat_id: str, command: str, approval_id: str ) -> SendResult: @@ -852,6 +1054,19 @@ class DiscordAdapter(BasePlatformAdapter): message.content = message.content.replace(f"<@{self._client.user.id}>", "").strip() message.content = message.content.replace(f"<@!{self._client.user.id}>", "").strip() + # Auto-thread: when enabled, automatically create a thread for every + # new message in a text channel so each conversation is isolated. + # Messages already inside threads or DMs are unaffected. + auto_threaded_channel = None + if not is_thread and not isinstance(message.channel, discord.DMChannel): + auto_thread = os.getenv("DISCORD_AUTO_THREAD", "").lower() in ("true", "1", "yes") + if auto_thread: + thread = await self._auto_create_thread(message) + if thread: + is_thread = True + thread_id = str(thread.id) + auto_threaded_channel = thread + # Determine message type msg_type = MessageType.TEXT if message.content.startswith("/"): @@ -870,13 +1085,16 @@ class DiscordAdapter(BasePlatformAdapter): msg_type = MessageType.DOCUMENT break + # When auto-threading kicked in, route responses to the new thread + effective_channel = auto_threaded_channel or message.channel + # Determine chat type if isinstance(message.channel, discord.DMChannel): chat_type = "dm" chat_name = message.author.name elif is_thread: chat_type = "thread" - chat_name = self._format_thread_chat_name(message.channel) + chat_name = self._format_thread_chat_name(effective_channel) else: chat_type = "group" chat_name = getattr(message.channel, "name", str(message.channel.id)) @@ -888,7 +1106,7 @@ class DiscordAdapter(BasePlatformAdapter): # Build source source = self.build_source( - chat_id=str(message.channel.id), + chat_id=str(effective_channel.id), chat_name=chat_name, chat_type=chat_type, user_id=str(message.author.id), diff --git a/tests/gateway/test_discord_free_response.py b/tests/gateway/test_discord_free_response.py index fd9eacab..ff15326d 100644 --- a/tests/gateway/test_discord_free_response.py +++ b/tests/gateway/test_discord_free_response.py @@ -27,6 +27,9 @@ def _ensure_discord_mock(): discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4) discord_mod.Interaction = object discord_mod.Embed = MagicMock + discord_mod.app_commands = SimpleNamespace( + describe=lambda **kwargs: (lambda fn: fn), + ) ext_mod = MagicMock() commands_mod = MagicMock() diff --git a/tests/gateway/test_discord_media_metadata.py b/tests/gateway/test_discord_media_metadata.py new file mode 100644 index 00000000..a98ac4fc --- /dev/null +++ b/tests/gateway/test_discord_media_metadata.py @@ -0,0 +1,9 @@ +import inspect + +from gateway.platforms.discord import DiscordAdapter + + +def test_discord_media_methods_accept_metadata_kwarg(): + for method_name in ("send_voice", "send_image_file", "send_image"): + signature = inspect.signature(getattr(DiscordAdapter, method_name)) + assert "metadata" in signature.parameters, method_name diff --git a/tests/gateway/test_discord_slash_commands.py b/tests/gateway/test_discord_slash_commands.py new file mode 100644 index 00000000..0b8cb04c --- /dev/null +++ b/tests/gateway/test_discord_slash_commands.py @@ -0,0 +1,434 @@ +"""Tests for native Discord slash command fast-paths (thread creation & auto-thread).""" + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch +import sys + +import pytest + +from gateway.config import PlatformConfig + + +def _ensure_discord_mock(): + if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"): + return + + discord_mod = MagicMock() + discord_mod.Intents.default.return_value = MagicMock() + discord_mod.DMChannel = type("DMChannel", (), {}) + discord_mod.Thread = type("Thread", (), {}) + discord_mod.ForumChannel = type("ForumChannel", (), {}) + discord_mod.Interaction = object + discord_mod.app_commands = SimpleNamespace( + describe=lambda **kwargs: (lambda fn: fn), + ) + + ext_mod = MagicMock() + commands_mod = MagicMock() + commands_mod.Bot = MagicMock + ext_mod.commands = commands_mod + + sys.modules.setdefault("discord", discord_mod) + sys.modules.setdefault("discord.ext", ext_mod) + sys.modules.setdefault("discord.ext.commands", commands_mod) + + +_ensure_discord_mock() + +from gateway.platforms.discord import DiscordAdapter # noqa: E402 + + +class FakeTree: + def __init__(self): + self.commands = {} + + def command(self, *, name, description): + def decorator(fn): + self.commands[name] = fn + return fn + + return decorator + + +@pytest.fixture +def adapter(): + config = PlatformConfig(enabled=True, token="***") + adapter = DiscordAdapter(config) + adapter._client = SimpleNamespace( + tree=FakeTree(), + get_channel=lambda _id: None, + fetch_channel=AsyncMock(), + user=SimpleNamespace(id=99999, name="HermesBot"), + ) + return adapter + + +# ------------------------------------------------------------------ +# /thread slash command registration +# ------------------------------------------------------------------ + + +@pytest.mark.asyncio +async def test_registers_native_thread_slash_command(adapter): + adapter._handle_thread_create_slash = AsyncMock() + adapter._register_slash_commands() + + command = adapter._client.tree.commands["thread"] + interaction = SimpleNamespace( + response=SimpleNamespace(defer=AsyncMock()), + ) + + await command(interaction, name="Planning", message="", auto_archive_duration=1440) + + interaction.response.defer.assert_awaited_once_with(ephemeral=True) + adapter._handle_thread_create_slash.assert_awaited_once_with(interaction, "Planning", "", 1440) + + +# ------------------------------------------------------------------ +# _handle_thread_create_slash — success, session dispatch, failure +# ------------------------------------------------------------------ + + +@pytest.mark.asyncio +async def test_handle_thread_create_slash_reports_success(adapter): + created_thread = SimpleNamespace(id=555, name="Planning", send=AsyncMock()) + parent_channel = SimpleNamespace(create_thread=AsyncMock(return_value=created_thread), send=AsyncMock()) + interaction_channel = SimpleNamespace(parent=parent_channel) + interaction = SimpleNamespace( + channel=interaction_channel, + channel_id=123, + user=SimpleNamespace(display_name="Jezza", id=42), + guild=SimpleNamespace(name="TestGuild"), + followup=SimpleNamespace(send=AsyncMock()), + ) + + await adapter._handle_thread_create_slash(interaction, "Planning", "Kickoff", 1440) + + parent_channel.create_thread.assert_awaited_once_with( + name="Planning", + auto_archive_duration=1440, + reason="Requested by Jezza via /thread", + ) + created_thread.send.assert_awaited_once_with("Kickoff") + # Thread link shown to user + interaction.followup.send.assert_awaited() + args, kwargs = interaction.followup.send.await_args + assert "<#555>" in args[0] + assert kwargs["ephemeral"] is True + + +@pytest.mark.asyncio +async def test_handle_thread_create_slash_dispatches_session_when_message_provided(adapter): + """When a message is given, _dispatch_thread_session should be called.""" + created_thread = SimpleNamespace(id=555, name="Planning", send=AsyncMock()) + parent_channel = SimpleNamespace(create_thread=AsyncMock(return_value=created_thread)) + interaction = SimpleNamespace( + channel=SimpleNamespace(parent=parent_channel), + channel_id=123, + user=SimpleNamespace(display_name="Jezza", id=42), + guild=SimpleNamespace(name="TestGuild"), + followup=SimpleNamespace(send=AsyncMock()), + ) + + adapter._dispatch_thread_session = AsyncMock() + + await adapter._handle_thread_create_slash(interaction, "Planning", "Hello Hermes", 1440) + + adapter._dispatch_thread_session.assert_awaited_once_with( + interaction, "555", "Planning", "Hello Hermes", + ) + + +@pytest.mark.asyncio +async def test_handle_thread_create_slash_no_dispatch_without_message(adapter): + """Without a message, no session dispatch should occur.""" + created_thread = SimpleNamespace(id=555, name="Planning", send=AsyncMock()) + parent_channel = SimpleNamespace(create_thread=AsyncMock(return_value=created_thread)) + interaction = SimpleNamespace( + channel=SimpleNamespace(parent=parent_channel), + channel_id=123, + user=SimpleNamespace(display_name="Jezza", id=42), + guild=SimpleNamespace(name="TestGuild"), + followup=SimpleNamespace(send=AsyncMock()), + ) + + adapter._dispatch_thread_session = AsyncMock() + + await adapter._handle_thread_create_slash(interaction, "Planning", "", 1440) + + adapter._dispatch_thread_session.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_handle_thread_create_slash_falls_back_to_seed_message(adapter): + created_thread = SimpleNamespace(id=555, name="Planning") + seed_message = SimpleNamespace(id=777, create_thread=AsyncMock(return_value=created_thread)) + channel = SimpleNamespace( + create_thread=AsyncMock(side_effect=RuntimeError("direct failed")), + send=AsyncMock(return_value=seed_message), + ) + interaction = SimpleNamespace( + channel=channel, + channel_id=123, + user=SimpleNamespace(display_name="Jezza", id=42), + guild=SimpleNamespace(name="TestGuild"), + followup=SimpleNamespace(send=AsyncMock()), + ) + + await adapter._handle_thread_create_slash(interaction, "Planning", "Kickoff", 1440) + + channel.send.assert_awaited_once_with("Kickoff") + seed_message.create_thread.assert_awaited_once_with( + name="Planning", + auto_archive_duration=1440, + reason="Requested by Jezza via /thread", + ) + interaction.followup.send.assert_awaited() + + +@pytest.mark.asyncio +async def test_handle_thread_create_slash_reports_failure(adapter): + channel = SimpleNamespace( + create_thread=AsyncMock(side_effect=RuntimeError("direct failed")), + send=AsyncMock(side_effect=RuntimeError("nope")), + ) + interaction = SimpleNamespace( + channel=channel, + channel_id=123, + user=SimpleNamespace(display_name="Jezza", id=42), + followup=SimpleNamespace(send=AsyncMock()), + ) + + await adapter._handle_thread_create_slash(interaction, "Planning", "", 1440) + + interaction.followup.send.assert_awaited_once() + args, kwargs = interaction.followup.send.await_args + assert "Failed to create thread:" in args[0] + assert "nope" in args[0] + assert kwargs["ephemeral"] is True + + +# ------------------------------------------------------------------ +# _dispatch_thread_session — builds correct event and routes it +# ------------------------------------------------------------------ + + +@pytest.mark.asyncio +async def test_dispatch_thread_session_builds_thread_event(adapter): + """Dispatched event should have chat_type=thread and chat_id=thread_id.""" + interaction = SimpleNamespace( + user=SimpleNamespace(display_name="Jezza", id=42), + guild=SimpleNamespace(name="TestGuild"), + ) + + captured_events = [] + + async def capture_handle(event): + captured_events.append(event) + + adapter.handle_message = capture_handle + + await adapter._dispatch_thread_session(interaction, "555", "Planning", "Hello!") + + assert len(captured_events) == 1 + event = captured_events[0] + assert event.text == "Hello!" + assert event.source.chat_id == "555" + assert event.source.chat_type == "thread" + assert event.source.thread_id == "555" + assert "TestGuild" in event.source.chat_name + + +# ------------------------------------------------------------------ +# Auto-thread: _auto_create_thread +# ------------------------------------------------------------------ + + +@pytest.mark.asyncio +async def test_auto_create_thread_uses_message_content_as_name(adapter): + thread = SimpleNamespace(id=999, name="Hello world") + message = SimpleNamespace( + content="Hello world, how are you?", + create_thread=AsyncMock(return_value=thread), + ) + + result = await adapter._auto_create_thread(message) + + assert result is thread + message.create_thread.assert_awaited_once() + call_kwargs = message.create_thread.await_args[1] + assert call_kwargs["name"] == "Hello world, how are you?" + assert call_kwargs["auto_archive_duration"] == 1440 + + +@pytest.mark.asyncio +async def test_auto_create_thread_truncates_long_names(adapter): + long_text = "a" * 200 + thread = SimpleNamespace(id=999, name="truncated") + message = SimpleNamespace( + content=long_text, + create_thread=AsyncMock(return_value=thread), + ) + + result = await adapter._auto_create_thread(message) + + assert result is thread + call_kwargs = message.create_thread.await_args[1] + assert len(call_kwargs["name"]) <= 80 + assert call_kwargs["name"].endswith("...") + + +@pytest.mark.asyncio +async def test_auto_create_thread_returns_none_on_failure(adapter): + message = SimpleNamespace( + content="Hello", + create_thread=AsyncMock(side_effect=RuntimeError("no perms")), + ) + + result = await adapter._auto_create_thread(message) + assert result is None + + +# ------------------------------------------------------------------ +# Auto-thread integration in _handle_message +# ------------------------------------------------------------------ + + +import discord as _discord_mod # noqa: E402 — mock or real, used below + + +class _FakeTextChannel: + """A channel that is NOT a discord.Thread or discord.DMChannel.""" + + def __init__(self, channel_id=100, name="general", guild_name="TestGuild"): + self.id = channel_id + self.name = name + self.guild = SimpleNamespace(name=guild_name, id=1) + self.topic = None + + +class _FakeThreadChannel(_discord_mod.Thread): + """isinstance(ch, discord.Thread) → True.""" + + def __init__(self, channel_id=200, name="existing-thread", guild_name="TestGuild", parent_id=100): + # Don't call super().__init__ — mock Thread is just an empty type + self.id = channel_id + self.name = name + self.guild = SimpleNamespace(name=guild_name, id=1) + self.topic = None + self.parent = SimpleNamespace(id=parent_id, name="general", guild=SimpleNamespace(name=guild_name, id=1)) + + +def _fake_message(channel, *, content="Hello", author_id=42, display_name="Jezza"): + return SimpleNamespace( + author=SimpleNamespace(id=author_id, display_name=display_name, bot=False), + content=content, + channel=channel, + attachments=[], + mentions=[], + reference=None, + created_at=None, + id=12345, + ) + + +@pytest.mark.asyncio +async def test_auto_thread_creates_thread_and_redirects(adapter, monkeypatch): + """When DISCORD_AUTO_THREAD=true, a new thread is created and the event routes there.""" + monkeypatch.setenv("DISCORD_AUTO_THREAD", "true") + monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false") + + thread = SimpleNamespace(id=999, name="Hello") + adapter._auto_create_thread = AsyncMock(return_value=thread) + + captured_events = [] + + async def capture_handle(event): + captured_events.append(event) + + adapter.handle_message = capture_handle + + msg = _fake_message(_FakeTextChannel(), content="Hello world") + + await adapter._handle_message(msg) + + adapter._auto_create_thread.assert_awaited_once_with(msg) + assert len(captured_events) == 1 + event = captured_events[0] + assert event.source.chat_id == "999" # redirected to thread + assert event.source.chat_type == "thread" + assert event.source.thread_id == "999" + + +@pytest.mark.asyncio +async def test_auto_thread_disabled_by_default(adapter, monkeypatch): + """Without DISCORD_AUTO_THREAD, messages stay in the channel.""" + monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False) + monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false") + + adapter._auto_create_thread = AsyncMock() + + captured_events = [] + + async def capture_handle(event): + captured_events.append(event) + + adapter.handle_message = capture_handle + + msg = _fake_message(_FakeTextChannel()) + + await adapter._handle_message(msg) + + adapter._auto_create_thread.assert_not_awaited() + assert len(captured_events) == 1 + assert captured_events[0].source.chat_id == "100" # stays in channel + + +@pytest.mark.asyncio +async def test_auto_thread_skips_threads_and_dms(adapter, monkeypatch): + """Auto-thread should not create threads inside existing threads.""" + monkeypatch.setenv("DISCORD_AUTO_THREAD", "true") + monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false") + + adapter._auto_create_thread = AsyncMock() + + captured_events = [] + + async def capture_handle(event): + captured_events.append(event) + + adapter.handle_message = capture_handle + + msg = _fake_message(_FakeThreadChannel()) + + await adapter._handle_message(msg) + + adapter._auto_create_thread.assert_not_awaited() # should NOT auto-thread + + +# ------------------------------------------------------------------ +# Config bridge +# ------------------------------------------------------------------ + + +def test_discord_auto_thread_config_bridge(monkeypatch, tmp_path): + """discord.auto_thread in config.yaml should be bridged to DISCORD_AUTO_THREAD env var.""" + import yaml + from pathlib import Path + + # Write a config.yaml the loader will find + hermes_dir = tmp_path / ".hermes" + hermes_dir.mkdir() + config_path = hermes_dir / "config.yaml" + config_path.write_text(yaml.dump({ + "discord": {"auto_thread": True}, + })) + + monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + from gateway.config import load_gateway_config + load_gateway_config() + + import os + assert os.getenv("DISCORD_AUTO_THREAD") == "true"