merge: resolve conflict with main (add mcp + homeassistant extras)
This commit is contained in:
commit
aefc330b8f
81 changed files with 8138 additions and 776 deletions
|
|
@ -736,9 +736,13 @@ class BasePlatformAdapter(ABC):
|
|||
chat_type: str = "dm",
|
||||
user_id: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
thread_id: Optional[str] = None
|
||||
thread_id: Optional[str] = None,
|
||||
chat_topic: Optional[str] = None,
|
||||
) -> SessionSource:
|
||||
"""Helper to build a SessionSource for this platform."""
|
||||
# Normalize empty topic to None
|
||||
if chat_topic is not None and not chat_topic.strip():
|
||||
chat_topic = None
|
||||
return SessionSource(
|
||||
platform=self.platform,
|
||||
chat_id=str(chat_id),
|
||||
|
|
@ -747,6 +751,7 @@ class BasePlatformAdapter(ABC):
|
|||
user_id=str(user_id) if user_id else None,
|
||||
user_name=user_name,
|
||||
thread_id=str(thread_id) if thread_id else None,
|
||||
chat_topic=chat_topic.strip() if chat_topic else None,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -542,6 +542,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
chat_name = interaction.channel.name
|
||||
if hasattr(interaction.channel, "guild") and interaction.channel.guild:
|
||||
chat_name = f"{interaction.channel.guild.name} / #{chat_name}"
|
||||
|
||||
# Get channel topic (if available)
|
||||
chat_topic = getattr(interaction.channel, "topic", None)
|
||||
|
||||
source = self.build_source(
|
||||
chat_id=str(interaction.channel_id),
|
||||
|
|
@ -549,6 +552,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
chat_type=chat_type,
|
||||
user_id=str(interaction.user.id),
|
||||
user_name=interaction.user.display_name,
|
||||
chat_topic=chat_topic,
|
||||
)
|
||||
|
||||
msg_type = MessageType.COMMAND if text.startswith("/") else MessageType.TEXT
|
||||
|
|
@ -661,6 +665,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
if isinstance(message.channel, discord.Thread):
|
||||
thread_id = str(message.channel.id)
|
||||
|
||||
# Get channel topic (if available - TextChannels have topics, DMs/threads don't)
|
||||
chat_topic = getattr(message.channel, "topic", None)
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
chat_id=str(message.channel.id),
|
||||
|
|
@ -669,6 +676,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
user_id=str(message.author.id),
|
||||
user_name=message.author.display_name,
|
||||
thread_id=thread_id,
|
||||
chat_topic=chat_topic,
|
||||
)
|
||||
|
||||
# Build media URLs -- download image attachments to local cache so the
|
||||
|
|
|
|||
|
|
@ -29,7 +29,17 @@ except ImportError:
|
|||
Bot = Any
|
||||
Message = Any
|
||||
Application = Any
|
||||
ContextTypes = Any
|
||||
CommandHandler = Any
|
||||
TelegramMessageHandler = Any
|
||||
filters = None
|
||||
ParseMode = None
|
||||
ChatType = None
|
||||
|
||||
# Mock ContextTypes so type annotations using ContextTypes.DEFAULT_TYPE
|
||||
# don't crash during class definition when the library isn't installed.
|
||||
class _MockContextTypes:
|
||||
DEFAULT_TYPE = Any
|
||||
ContextTypes = _MockContextTypes
|
||||
|
||||
import sys
|
||||
from pathlib import Path as _Path
|
||||
|
|
|
|||
|
|
@ -19,7 +19,10 @@ import asyncio
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
|
|
@ -157,16 +160,18 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
pass
|
||||
|
||||
# Start the bridge process in its own process group
|
||||
whatsapp_mode = os.getenv("WHATSAPP_MODE", "self-chat")
|
||||
self._bridge_process = subprocess.Popen(
|
||||
[
|
||||
"node",
|
||||
str(bridge_path),
|
||||
"--port", str(self._bridge_port),
|
||||
"--session", str(self._session_path),
|
||||
"--mode", whatsapp_mode,
|
||||
],
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
preexec_fn=os.setsid,
|
||||
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
||||
)
|
||||
|
||||
# Wait for bridge to be ready via HTTP health check
|
||||
|
|
@ -211,13 +216,19 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
# Kill the entire process group so child node processes die too
|
||||
import signal
|
||||
try:
|
||||
os.killpg(os.getpgid(self._bridge_process.pid), signal.SIGTERM)
|
||||
if _IS_WINDOWS:
|
||||
self._bridge_process.terminate()
|
||||
else:
|
||||
os.killpg(os.getpgid(self._bridge_process.pid), signal.SIGTERM)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
self._bridge_process.terminate()
|
||||
await asyncio.sleep(1)
|
||||
if self._bridge_process.poll() is None:
|
||||
try:
|
||||
os.killpg(os.getpgid(self._bridge_process.pid), signal.SIGKILL)
|
||||
if _IS_WINDOWS:
|
||||
self._bridge_process.kill()
|
||||
else:
|
||||
os.killpg(os.getpgid(self._bridge_process.pid), signal.SIGKILL)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
self._bridge_process.kill()
|
||||
except Exception as e:
|
||||
|
|
|
|||
143
gateway/run.py
143
gateway/run.py
|
|
@ -164,6 +164,7 @@ class GatewayRunner:
|
|||
self._prefill_messages = self._load_prefill_messages()
|
||||
self._ephemeral_system_prompt = self._load_ephemeral_system_prompt()
|
||||
self._reasoning_config = self._load_reasoning_config()
|
||||
self._provider_routing = self._load_provider_routing()
|
||||
|
||||
# Wire process registry into session store for reset protection
|
||||
from tools.process_registry import process_registry
|
||||
|
|
@ -346,6 +347,20 @@ class GatewayRunner:
|
|||
logger.warning("Unknown reasoning_effort '%s', using default (xhigh)", effort)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _load_provider_routing() -> dict:
|
||||
"""Load OpenRouter provider routing preferences from config.yaml."""
|
||||
try:
|
||||
import yaml as _y
|
||||
cfg_path = _hermes_home / "config.yaml"
|
||||
if cfg_path.exists():
|
||||
with open(cfg_path) as _f:
|
||||
cfg = _y.safe_load(_f) or {}
|
||||
return cfg.get("provider_routing", {}) or {}
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
async def start(self) -> bool:
|
||||
"""
|
||||
Start the gateway and all configured platform adapters.
|
||||
|
|
@ -643,7 +658,7 @@ class GatewayRunner:
|
|||
# Emit command:* hook for any recognized slash command
|
||||
_known_commands = {"new", "reset", "help", "status", "stop", "model",
|
||||
"personality", "retry", "undo", "sethome", "set-home",
|
||||
"compress", "usage"}
|
||||
"compress", "usage", "reload-mcp"}
|
||||
if command and command in _known_commands:
|
||||
await self.hooks.emit(f"command:{command}", {
|
||||
"platform": source.platform.value if source.platform else "",
|
||||
|
|
@ -684,6 +699,9 @@ class GatewayRunner:
|
|||
|
||||
if command == "usage":
|
||||
return await self._handle_usage_command(event)
|
||||
|
||||
if command == "reload-mcp":
|
||||
return await self._handle_reload_mcp_command(event)
|
||||
|
||||
# Skill slash commands: /skill-name loads the skill and sends to agent
|
||||
if command:
|
||||
|
|
@ -982,13 +1000,12 @@ class GatewayRunner:
|
|||
source = event.source
|
||||
|
||||
# Get existing session key
|
||||
session_key = f"agent:main:{source.platform.value}:" + \
|
||||
(f"dm" if source.chat_type == "dm" else f"{source.chat_type}:{source.chat_id}")
|
||||
session_key = self.session_store._generate_session_key(source)
|
||||
|
||||
# Memory flush before reset: load the old transcript and let a
|
||||
# temporary agent save memories before the session is wiped.
|
||||
try:
|
||||
old_entry = self.session_store._sessions.get(session_key)
|
||||
old_entry = self.session_store._entries.get(session_key)
|
||||
if old_entry:
|
||||
old_history = self.session_store.load_transcript(old_entry.session_id)
|
||||
if old_history:
|
||||
|
|
@ -1085,6 +1102,7 @@ class GatewayRunner:
|
|||
"`/sethome` — Set this chat as the home channel",
|
||||
"`/compress` — Compress conversation context",
|
||||
"`/usage` — Show token usage for this session",
|
||||
"`/reload-mcp` — Reload MCP servers from config",
|
||||
"`/help` — Show this message",
|
||||
]
|
||||
try:
|
||||
|
|
@ -1220,9 +1238,9 @@ class GatewayRunner:
|
|||
if not last_user_msg:
|
||||
return "No previous message to retry."
|
||||
|
||||
# Truncate history to before the last user message
|
||||
# Truncate history to before the last user message and persist
|
||||
truncated = history[:last_user_idx]
|
||||
session_entry.conversation_history = truncated
|
||||
self.session_store.rewrite_transcript(session_entry.session_id, truncated)
|
||||
|
||||
# Re-send by creating a fake text event with the old message
|
||||
retry_event = MessageEvent(
|
||||
|
|
@ -1254,7 +1272,7 @@ class GatewayRunner:
|
|||
|
||||
removed_msg = history[last_user_idx].get("content", "")
|
||||
removed_count = len(history) - last_user_idx
|
||||
session_entry.conversation_history = history[:last_user_idx]
|
||||
self.session_store.rewrite_transcript(session_entry.session_id, history[:last_user_idx])
|
||||
|
||||
preview = removed_msg[:40] + "..." if len(removed_msg) > 40 else removed_msg
|
||||
return f"↩️ Undid {removed_count} message(s).\nRemoved: \"{preview}\""
|
||||
|
|
@ -1328,7 +1346,7 @@ class GatewayRunner:
|
|||
lambda: tmp_agent._compress_context(msgs, "", approx_tokens=approx_tokens),
|
||||
)
|
||||
|
||||
session_entry.conversation_history = compressed
|
||||
self.session_store.rewrite_transcript(session_entry.session_id, compressed)
|
||||
new_count = len(compressed)
|
||||
new_tokens = estimate_messages_tokens_rough(compressed)
|
||||
|
||||
|
|
@ -1378,6 +1396,76 @@ class GatewayRunner:
|
|||
)
|
||||
return "No usage data available for this session."
|
||||
|
||||
async def _handle_reload_mcp_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /reload-mcp command -- disconnect and reconnect all MCP servers."""
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
from tools.mcp_tool import shutdown_mcp_servers, discover_mcp_tools, _load_mcp_config, _servers, _lock
|
||||
|
||||
# Capture old server names before shutdown
|
||||
with _lock:
|
||||
old_servers = set(_servers.keys())
|
||||
|
||||
# Read new config before shutting down, so we know what will be added/removed
|
||||
new_config = _load_mcp_config()
|
||||
new_server_names = set(new_config.keys())
|
||||
|
||||
# Shutdown existing connections
|
||||
await loop.run_in_executor(None, shutdown_mcp_servers)
|
||||
|
||||
# Reconnect by discovering tools (reads config.yaml fresh)
|
||||
new_tools = await loop.run_in_executor(None, discover_mcp_tools)
|
||||
|
||||
# Compute what changed
|
||||
with _lock:
|
||||
connected_servers = set(_servers.keys())
|
||||
|
||||
added = connected_servers - old_servers
|
||||
removed = old_servers - connected_servers
|
||||
reconnected = connected_servers & old_servers
|
||||
|
||||
lines = ["🔄 **MCP Servers Reloaded**\n"]
|
||||
if reconnected:
|
||||
lines.append(f"♻️ Reconnected: {', '.join(sorted(reconnected))}")
|
||||
if added:
|
||||
lines.append(f"➕ Added: {', '.join(sorted(added))}")
|
||||
if removed:
|
||||
lines.append(f"➖ Removed: {', '.join(sorted(removed))}")
|
||||
if not connected_servers:
|
||||
lines.append("No MCP servers connected.")
|
||||
else:
|
||||
lines.append(f"\n🔧 {len(new_tools)} tool(s) available from {len(connected_servers)} server(s)")
|
||||
|
||||
# Inject a message at the END of the session history so the
|
||||
# model knows tools changed on its next turn. Appended after
|
||||
# all existing messages to preserve prompt-cache for the prefix.
|
||||
change_parts = []
|
||||
if added:
|
||||
change_parts.append(f"Added servers: {', '.join(sorted(added))}")
|
||||
if removed:
|
||||
change_parts.append(f"Removed servers: {', '.join(sorted(removed))}")
|
||||
if reconnected:
|
||||
change_parts.append(f"Reconnected servers: {', '.join(sorted(reconnected))}")
|
||||
tool_summary = f"{len(new_tools)} MCP tool(s) now available" if new_tools else "No MCP tools available"
|
||||
change_detail = ". ".join(change_parts) + ". " if change_parts else ""
|
||||
reload_msg = {
|
||||
"role": "user",
|
||||
"content": f"[SYSTEM: MCP servers have been reloaded. {change_detail}{tool_summary}. The tool list for this conversation has been updated accordingly.]",
|
||||
}
|
||||
try:
|
||||
session_entry = self.session_store.get_or_create_session(event.source)
|
||||
self.session_store.append_to_transcript(
|
||||
session_entry.session_id, reload_msg
|
||||
)
|
||||
except Exception:
|
||||
pass # Best-effort; don't fail the reload over a transcript write
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("MCP reload failed: %s", e)
|
||||
return f"❌ MCP reload failed: {e}"
|
||||
|
||||
def _set_session_env(self, context: SessionContext) -> None:
|
||||
"""Set environment variables for the current session."""
|
||||
os.environ["HERMES_SESSION_PLATFORM"] = context.source.platform.value
|
||||
|
|
@ -1671,7 +1759,7 @@ class GatewayRunner:
|
|||
progress_queue = queue.Queue() if tool_progress_enabled else None
|
||||
last_tool = [None] # Mutable container for tracking in closure
|
||||
|
||||
def progress_callback(tool_name: str, preview: str = None):
|
||||
def progress_callback(tool_name: str, preview: str = None, args: dict = None):
|
||||
"""Callback invoked by agent when a tool is called."""
|
||||
if not progress_queue:
|
||||
return
|
||||
|
|
@ -1691,6 +1779,7 @@ class GatewayRunner:
|
|||
"write_file": "✍️",
|
||||
"patch": "🔧",
|
||||
"search": "🔎",
|
||||
"search_files": "🔎",
|
||||
"list_directory": "📂",
|
||||
"image_generate": "🎨",
|
||||
"text_to_speech": "🔊",
|
||||
|
|
@ -1716,14 +1805,28 @@ class GatewayRunner:
|
|||
"schedule_cronjob": "⏰",
|
||||
"list_cronjobs": "⏰",
|
||||
"remove_cronjob": "⏰",
|
||||
"execute_code": "🐍",
|
||||
"delegate_task": "🔀",
|
||||
"clarify": "❓",
|
||||
"skill_manage": "📝",
|
||||
}
|
||||
emoji = tool_emojis.get(tool_name, "⚙️")
|
||||
|
||||
# Verbose mode: show detailed arguments
|
||||
if progress_mode == "verbose" and args:
|
||||
import json as _json
|
||||
args_str = _json.dumps(args, ensure_ascii=False, default=str)
|
||||
if len(args_str) > 200:
|
||||
args_str = args_str[:197] + "..."
|
||||
msg = f"{emoji} {tool_name}({list(args.keys())})\n{args_str}"
|
||||
progress_queue.put(msg)
|
||||
return
|
||||
|
||||
if preview:
|
||||
# Truncate preview to keep messages clean
|
||||
if len(preview) > 40:
|
||||
preview = preview[:37] + "..."
|
||||
msg = f"{emoji} {tool_name}... \"{preview}\""
|
||||
if len(preview) > 80:
|
||||
preview = preview[:77] + "..."
|
||||
msg = f"{emoji} {tool_name}: \"{preview}\""
|
||||
else:
|
||||
msg = f"{emoji} {tool_name}..."
|
||||
|
||||
|
|
@ -1837,6 +1940,7 @@ class GatewayRunner:
|
|||
"tools": [],
|
||||
}
|
||||
|
||||
pr = self._provider_routing
|
||||
agent = AIAgent(
|
||||
model=model,
|
||||
**runtime_kwargs,
|
||||
|
|
@ -1847,6 +1951,12 @@ class GatewayRunner:
|
|||
ephemeral_system_prompt=combined_ephemeral or None,
|
||||
prefill_messages=self._prefill_messages or None,
|
||||
reasoning_config=self._reasoning_config,
|
||||
providers_allowed=pr.get("only"),
|
||||
providers_ignored=pr.get("ignore"),
|
||||
providers_order=pr.get("order"),
|
||||
provider_sort=pr.get("sort"),
|
||||
provider_require_parameters=pr.get("require_parameters", False),
|
||||
provider_data_collection=pr.get("data_collection"),
|
||||
session_id=session_id,
|
||||
tool_progress_callback=progress_callback if tool_progress_enabled else None,
|
||||
step_callback=_step_callback_sync if _hooks_ref.loaded_hooks else None,
|
||||
|
|
@ -2194,7 +2304,14 @@ async def start_gateway(config: Optional[GatewayConfig] = None) -> bool:
|
|||
# Stop cron ticker cleanly
|
||||
cron_stop.set()
|
||||
cron_thread.join(timeout=5)
|
||||
|
||||
|
||||
# Close MCP server connections
|
||||
try:
|
||||
from tools.mcp_tool import shutdown_mcp_servers
|
||||
shutdown_mcp_servers()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ class SessionSource:
|
|||
user_id: Optional[str] = None
|
||||
user_name: Optional[str] = None
|
||||
thread_id: Optional[str] = None # For forum topics, Discord threads, etc.
|
||||
chat_topic: Optional[str] = None # Channel topic/description (Discord, Slack)
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
|
|
@ -75,6 +76,7 @@ class SessionSource:
|
|||
"user_id": self.user_id,
|
||||
"user_name": self.user_name,
|
||||
"thread_id": self.thread_id,
|
||||
"chat_topic": self.chat_topic,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -87,6 +89,7 @@ class SessionSource:
|
|||
user_id=data.get("user_id"),
|
||||
user_name=data.get("user_name"),
|
||||
thread_id=data.get("thread_id"),
|
||||
chat_topic=data.get("chat_topic"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -154,6 +157,10 @@ def build_session_context_prompt(context: SessionContext) -> str:
|
|||
lines.append(f"**Source:** {platform_name} (the machine running this agent)")
|
||||
else:
|
||||
lines.append(f"**Source:** {platform_name} ({context.source.description})")
|
||||
|
||||
# Channel topic (if available - provides context about the channel's purpose)
|
||||
if context.source.chat_topic:
|
||||
lines.append(f"**Channel Topic:** {context.source.chat_topic}")
|
||||
|
||||
# User identity (especially useful for WhatsApp where multiple people DM)
|
||||
if context.source.user_name:
|
||||
|
|
@ -567,6 +574,34 @@ class SessionStore:
|
|||
with open(transcript_path, "a") as f:
|
||||
f.write(json.dumps(message, ensure_ascii=False) + "\n")
|
||||
|
||||
def rewrite_transcript(self, session_id: str, messages: List[Dict[str, Any]]) -> None:
|
||||
"""Replace the entire transcript for a session with new messages.
|
||||
|
||||
Used by /retry, /undo, and /compress to persist modified conversation history.
|
||||
Rewrites both SQLite and legacy JSONL storage.
|
||||
"""
|
||||
# SQLite: clear old messages and re-insert
|
||||
if self._db:
|
||||
try:
|
||||
self._db.clear_messages(session_id)
|
||||
for msg in messages:
|
||||
self._db.append_message(
|
||||
session_id=session_id,
|
||||
role=msg.get("role", "unknown"),
|
||||
content=msg.get("content"),
|
||||
tool_name=msg.get("tool_name"),
|
||||
tool_calls=msg.get("tool_calls"),
|
||||
tool_call_id=msg.get("tool_call_id"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to rewrite transcript in DB: %s", e)
|
||||
|
||||
# JSONL: overwrite the file
|
||||
transcript_path = self.get_transcript_path(session_id)
|
||||
with open(transcript_path, "w") as f:
|
||||
for msg in messages:
|
||||
f.write(json.dumps(msg, ensure_ascii=False) + "\n")
|
||||
|
||||
def load_transcript(self, session_id: str) -> List[Dict[str, Any]]:
|
||||
"""Load all messages from a session's transcript."""
|
||||
# Try SQLite first
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue