feat(discord): add /thread command, auto_thread config, and media metadata fix (#1178)
- Add /thread slash command that creates a Discord thread and starts a new Hermes session in it. The starter message (if provided) becomes the first user input in the new session. - Add discord.auto_thread config option (DISCORD_AUTO_THREAD env var): when enabled, every message in a text channel automatically creates a thread, allowing parallel isolated sessions. - Fix Discord media method signatures to accept metadata kwarg (send_voice, send_image_file, send_image) — prevents TypeError when the base adapter passes platform metadata. - Fix test mock isolation: add app_commands and ForumChannel to discord mocks so tests pass in full-suite runs. Based on PRs #866 and #1109 by insecurejezza, modified per review: removed /channel command (unsafe), added auto_thread feature, made /thread dispatch new sessions. Co-authored-by: insecurejezza <insecurejezza@users.noreply.github.com>
This commit is contained in:
parent
d425901bae
commit
b8b45bfb77
5 changed files with 668 additions and 2 deletions
|
|
@ -14,6 +14,8 @@ from typing import Dict, List, Optional, Any
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VALID_THREAD_AUTO_ARCHIVE_MINUTES = {60, 1440, 4320, 10080}
|
||||
|
||||
try:
|
||||
import discord
|
||||
from discord import Message as DiscordMessage, Intents
|
||||
|
|
@ -251,6 +253,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
audio_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Send audio as a Discord file attachment."""
|
||||
if not self._client:
|
||||
|
|
@ -289,6 +292,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file natively as a Discord file attachment."""
|
||||
if not self._client:
|
||||
|
|
@ -326,6 +330,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Send an image natively as a Discord file attachment."""
|
||||
if not self._client:
|
||||
|
|
@ -711,6 +716,21 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="thread", description="Create a new thread and start a Hermes session in it")
|
||||
@discord.app_commands.describe(
|
||||
name="Thread name",
|
||||
message="Optional first message to send to Hermes in the thread",
|
||||
auto_archive_duration="Auto-archive in minutes (60, 1440, 4320, 10080)",
|
||||
)
|
||||
async def slash_thread(
|
||||
interaction: discord.Interaction,
|
||||
name: str,
|
||||
message: str = "",
|
||||
auto_archive_duration: int = 1440,
|
||||
):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
await self._handle_thread_create_slash(interaction, name, message, auto_archive_duration)
|
||||
|
||||
def _build_slash_event(self, interaction: discord.Interaction, text: str) -> MessageEvent:
|
||||
"""Build a MessageEvent from a Discord slash command interaction."""
|
||||
is_dm = isinstance(interaction.channel, discord.DMChannel)
|
||||
|
|
@ -741,6 +761,188 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
raw_message=interaction,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Thread creation helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _handle_thread_create_slash(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
name: str,
|
||||
message: str = "",
|
||||
auto_archive_duration: int = 1440,
|
||||
) -> None:
|
||||
"""Create a Discord thread from a slash command and start a session in it."""
|
||||
result = await self._create_thread(
|
||||
interaction,
|
||||
name=name,
|
||||
message=message,
|
||||
auto_archive_duration=auto_archive_duration,
|
||||
)
|
||||
|
||||
if not result.get("success"):
|
||||
error = result.get("error", "unknown error")
|
||||
await interaction.followup.send(f"Failed to create thread: {error}", ephemeral=True)
|
||||
return
|
||||
|
||||
thread_id = result.get("thread_id")
|
||||
thread_name = result.get("thread_name") or name
|
||||
|
||||
# Tell the user where the thread is
|
||||
link = f"<#{thread_id}>" if thread_id else f"**{thread_name}**"
|
||||
await interaction.followup.send(f"Created thread {link}", ephemeral=True)
|
||||
|
||||
# If a message was provided, kick off a new Hermes session in the thread
|
||||
starter = (message or "").strip()
|
||||
if starter and thread_id:
|
||||
await self._dispatch_thread_session(interaction, thread_id, thread_name, starter)
|
||||
|
||||
async def _dispatch_thread_session(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
thread_id: str,
|
||||
thread_name: str,
|
||||
text: str,
|
||||
) -> None:
|
||||
"""Build a MessageEvent pointing at a thread and send it through handle_message."""
|
||||
guild_name = ""
|
||||
if hasattr(interaction, "guild") and interaction.guild:
|
||||
guild_name = interaction.guild.name
|
||||
|
||||
chat_name = f"{guild_name} / {thread_name}" if guild_name else thread_name
|
||||
|
||||
source = self.build_source(
|
||||
chat_id=thread_id,
|
||||
chat_name=chat_name,
|
||||
chat_type="thread",
|
||||
user_id=str(interaction.user.id),
|
||||
user_name=interaction.user.display_name,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
event = MessageEvent(
|
||||
text=text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
raw_message=interaction,
|
||||
)
|
||||
await self.handle_message(event)
|
||||
|
||||
def _thread_parent_channel(self, channel: Any) -> Any:
|
||||
"""Return the parent text channel when invoked from a thread."""
|
||||
return getattr(channel, "parent", None) or channel
|
||||
|
||||
async def _resolve_interaction_channel(self, interaction: discord.Interaction) -> Optional[Any]:
|
||||
"""Return the interaction channel, fetching it if the payload is partial."""
|
||||
channel = getattr(interaction, "channel", None)
|
||||
if channel is not None:
|
||||
return channel
|
||||
if not self._client:
|
||||
return None
|
||||
channel_id = getattr(interaction, "channel_id", None)
|
||||
if channel_id is None:
|
||||
return None
|
||||
channel = self._client.get_channel(int(channel_id))
|
||||
if channel is not None:
|
||||
return channel
|
||||
try:
|
||||
return await self._client.fetch_channel(int(channel_id))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def _create_thread(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
*,
|
||||
name: str,
|
||||
message: str = "",
|
||||
auto_archive_duration: int = 1440,
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a thread in the current Discord channel.
|
||||
|
||||
Tries ``parent_channel.create_thread()`` first. If Discord rejects
|
||||
that (e.g. permission issues), falls back to sending a seed message
|
||||
and creating the thread from it.
|
||||
"""
|
||||
name = (name or "").strip()
|
||||
if not name:
|
||||
return {"error": "Thread name is required."}
|
||||
|
||||
if auto_archive_duration not in VALID_THREAD_AUTO_ARCHIVE_MINUTES:
|
||||
allowed = ", ".join(str(v) for v in sorted(VALID_THREAD_AUTO_ARCHIVE_MINUTES))
|
||||
return {"error": f"auto_archive_duration must be one of: {allowed}."}
|
||||
|
||||
channel = await self._resolve_interaction_channel(interaction)
|
||||
if channel is None:
|
||||
return {"error": "Could not resolve the current Discord channel."}
|
||||
if isinstance(channel, discord.DMChannel):
|
||||
return {"error": "Discord threads can only be created inside server text channels, not DMs."}
|
||||
|
||||
parent_channel = self._thread_parent_channel(channel)
|
||||
if parent_channel is None:
|
||||
return {"error": "Could not determine a parent text channel for the new thread."}
|
||||
|
||||
display_name = getattr(getattr(interaction, "user", None), "display_name", None) or "unknown user"
|
||||
reason = f"Requested by {display_name} via /thread"
|
||||
starter_message = (message or "").strip()
|
||||
|
||||
try:
|
||||
thread = await parent_channel.create_thread(
|
||||
name=name,
|
||||
auto_archive_duration=auto_archive_duration,
|
||||
reason=reason,
|
||||
)
|
||||
if starter_message:
|
||||
await thread.send(starter_message)
|
||||
return {
|
||||
"success": True,
|
||||
"thread_id": str(thread.id),
|
||||
"thread_name": getattr(thread, "name", None) or name,
|
||||
}
|
||||
except Exception as direct_error:
|
||||
try:
|
||||
seed_content = starter_message or f"\U0001f9f5 Thread created by Hermes: **{name}**"
|
||||
seed_msg = await parent_channel.send(seed_content)
|
||||
thread = await seed_msg.create_thread(
|
||||
name=name,
|
||||
auto_archive_duration=auto_archive_duration,
|
||||
reason=reason,
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"thread_id": str(thread.id),
|
||||
"thread_name": getattr(thread, "name", None) or name,
|
||||
}
|
||||
except Exception as fallback_error:
|
||||
return {
|
||||
"error": (
|
||||
"Discord rejected direct thread creation and the fallback also failed. "
|
||||
f"Direct error: {direct_error}. Fallback error: {fallback_error}"
|
||||
)
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Auto-thread helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _auto_create_thread(self, message: 'DiscordMessage') -> Optional[Any]:
|
||||
"""Create a thread from a user message for auto-threading.
|
||||
|
||||
Returns the created thread object, or ``None`` on failure.
|
||||
"""
|
||||
# Build a short thread name from the message
|
||||
content = (message.content or "").strip()
|
||||
thread_name = content[:80] if content else "Hermes"
|
||||
if len(content) > 80:
|
||||
thread_name = thread_name[:77] + "..."
|
||||
|
||||
try:
|
||||
thread = await message.create_thread(name=thread_name, auto_archive_duration=1440)
|
||||
return thread
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Auto-thread creation failed: %s", self.name, e)
|
||||
return None
|
||||
|
||||
async def send_exec_approval(
|
||||
self, chat_id: str, command: str, approval_id: str
|
||||
) -> SendResult:
|
||||
|
|
@ -852,6 +1054,19 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
message.content = message.content.replace(f"<@{self._client.user.id}>", "").strip()
|
||||
message.content = message.content.replace(f"<@!{self._client.user.id}>", "").strip()
|
||||
|
||||
# Auto-thread: when enabled, automatically create a thread for every
|
||||
# new message in a text channel so each conversation is isolated.
|
||||
# Messages already inside threads or DMs are unaffected.
|
||||
auto_threaded_channel = None
|
||||
if not is_thread and not isinstance(message.channel, discord.DMChannel):
|
||||
auto_thread = os.getenv("DISCORD_AUTO_THREAD", "").lower() in ("true", "1", "yes")
|
||||
if auto_thread:
|
||||
thread = await self._auto_create_thread(message)
|
||||
if thread:
|
||||
is_thread = True
|
||||
thread_id = str(thread.id)
|
||||
auto_threaded_channel = thread
|
||||
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
if message.content.startswith("/"):
|
||||
|
|
@ -870,13 +1085,16 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
msg_type = MessageType.DOCUMENT
|
||||
break
|
||||
|
||||
# When auto-threading kicked in, route responses to the new thread
|
||||
effective_channel = auto_threaded_channel or message.channel
|
||||
|
||||
# Determine chat type
|
||||
if isinstance(message.channel, discord.DMChannel):
|
||||
chat_type = "dm"
|
||||
chat_name = message.author.name
|
||||
elif is_thread:
|
||||
chat_type = "thread"
|
||||
chat_name = self._format_thread_chat_name(message.channel)
|
||||
chat_name = self._format_thread_chat_name(effective_channel)
|
||||
else:
|
||||
chat_type = "group"
|
||||
chat_name = getattr(message.channel, "name", str(message.channel.id))
|
||||
|
|
@ -888,7 +1106,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
chat_id=str(message.channel.id),
|
||||
chat_id=str(effective_channel.id),
|
||||
chat_name=chat_name,
|
||||
chat_type=chat_type,
|
||||
user_id=str(message.author.id),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue