Merge PR #291: feat: add MCP (Model Context Protocol) client support
Authored by 0xbyt4. Adds MCP client with official SDK, direct tool registration, auto-injection into hermes-* toolsets, and graceful degradation.
This commit is contained in:
commit
468b7fdbad
7 changed files with 1299 additions and 2 deletions
496
tools/mcp_tool.py
Normal file
496
tools/mcp_tool.py
Normal file
|
|
@ -0,0 +1,496 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
MCP (Model Context Protocol) Client Support
|
||||
|
||||
Connects to external MCP servers via stdio transport, discovers their tools,
|
||||
and registers them into the hermes-agent tool registry so the agent can call
|
||||
them like any built-in tool.
|
||||
|
||||
Configuration is read from ~/.hermes/config.yaml under the ``mcp_servers`` key.
|
||||
The ``mcp`` Python package is optional -- if not installed, this module is a
|
||||
no-op and logs a debug message.
|
||||
|
||||
Example config::
|
||||
|
||||
mcp_servers:
|
||||
filesystem:
|
||||
command: "npx"
|
||||
args: ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
|
||||
env: {}
|
||||
github:
|
||||
command: "npx"
|
||||
args: ["-y", "@modelcontextprotocol/server-github"]
|
||||
env:
|
||||
GITHUB_PERSONAL_ACCESS_TOKEN: "ghp_..."
|
||||
|
||||
Architecture:
|
||||
A dedicated background event loop (_mcp_loop) runs in a daemon thread.
|
||||
Each MCP server runs as a long-lived asyncio Task on this loop, keeping
|
||||
its ``async with stdio_client(...)`` context alive. Tool call coroutines
|
||||
are scheduled onto the loop via ``run_coroutine_threadsafe()``.
|
||||
|
||||
On shutdown, each server Task is signalled to exit its ``async with``
|
||||
block, ensuring the anyio cancel-scope cleanup happens in the *same*
|
||||
Task that opened the connection (required by anyio).
|
||||
|
||||
Thread safety:
|
||||
_servers and _mcp_loop/_mcp_thread are accessed from both the MCP
|
||||
background thread and caller threads. All mutations are protected by
|
||||
_lock so the code is safe regardless of GIL presence (e.g. Python 3.13+
|
||||
free-threading).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Graceful import -- MCP SDK is an optional dependency
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_MCP_AVAILABLE = False
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
_MCP_AVAILABLE = True
|
||||
except ImportError:
|
||||
logger.debug("mcp package not installed -- MCP tool support disabled")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Server task -- each MCP server lives in one long-lived asyncio Task
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MCPServerTask:
    """Manages a single MCP server connection in a dedicated asyncio Task.

    The entire connection lifecycle (connect, discover, serve, disconnect)
    runs inside one asyncio Task so that anyio cancel-scopes created by
    ``stdio_client`` are entered and exited in the same Task context.

    All coroutine methods must be awaited on the event loop that hosts the
    connection, because the instance coordinates through ``asyncio.Event``
    objects.
    """

    __slots__ = (
        "name", "session",
        "_task", "_ready", "_shutdown_event", "_tools", "_error",
    )

    def __init__(self, name: str):
        self.name = name
        # Live client session while connected; None before connecting and
        # after the connection has been torn down.
        self.session: Optional[Any] = None
        self._task: Optional[asyncio.Task] = None
        # Set once the connection attempt has finished (success or failure).
        self._ready = asyncio.Event()
        # Set by shutdown() so run() can leave its ``async with`` blocks.
        self._shutdown_event = asyncio.Event()
        self._tools: list = []
        self._error: Optional[Exception] = None

    async def run(self, config: dict):
        """Long-lived coroutine: connect, discover tools, wait, disconnect.

        Args:
            config: Server settings with ``command`` (required) and the
                optional ``args`` and ``env`` entries.

        Any failure -- including a malformed *config* -- is stored in
        ``_error`` and ``_ready`` is set, so a caller waiting in ``start()``
        is always released instead of hanging.
        """
        try:
            # Config validation and parameter construction live inside the
            # try block: an exception raised here must still set _ready.
            if not isinstance(config, dict):
                raise ValueError(
                    f"MCP server '{self.name}' config must be a mapping, "
                    f"got {type(config).__name__}"
                )

            command = config.get("command")
            if not command:
                raise ValueError(
                    f"MCP server '{self.name}' has no 'command' in config"
                )

            server_params = StdioServerParameters(
                command=command,
                # ``args: null`` in the config is treated like a missing key.
                args=config.get("args") or [],
                env=config.get("env") or None,
            )

            async with stdio_client(server_params) as (read_stream, write_stream):
                async with ClientSession(read_stream, write_stream) as session:
                    await session.initialize()
                    self.session = session

                    tools_result = await session.list_tools()
                    self._tools = (
                        tools_result.tools
                        if hasattr(tools_result, "tools")
                        else []
                    )

                    # Signal that connection is ready
                    self._ready.set()

                    # Block until shutdown is requested -- this keeps the
                    # async-with contexts alive on THIS Task.
                    await self._shutdown_event.wait()
        except Exception as exc:
            self._error = exc
            self._ready.set()
        finally:
            self.session = None

    async def start(self, config: dict):
        """Create the background Task and wait until ready (or failed).

        Raises:
            Exception: whatever error ``run()`` recorded while connecting.
        """
        self._task = asyncio.ensure_future(self.run(config))
        await self._ready.wait()
        if self._error:
            raise self._error

    async def shutdown(self):
        """Signal the Task to exit and wait for clean resource teardown.

        The task gets 10 seconds to leave its ``async with`` blocks on its
        own; after that it is cancelled.
        """
        self._shutdown_event.set()
        if self._task and not self._task.done():
            try:
                await asyncio.wait_for(self._task, timeout=10)
            except asyncio.TimeoutError:
                logger.warning(
                    "MCP server '%s' shutdown timed out, cancelling task",
                    self.name,
                )
                self._task.cancel()
                try:
                    await self._task
                except asyncio.CancelledError:
                    pass
        self.session = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Connected servers, keyed by the server name used in the ``mcp_servers``
# config section.
_servers: Dict[str, MCPServerTask] = {}

# Dedicated event loop running in a background daemon thread.
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
_mcp_thread: Optional[threading.Thread] = None

# Protects _mcp_loop, _mcp_thread, and _servers from concurrent access.
# This is a plain (non-reentrant) Lock: a function that holds it must not
# call another function that acquires it.
_lock = threading.Lock()
|
||||
|
||||
|
||||
def _ensure_mcp_loop():
    """Start the background event loop thread if not already running.

    Blocks (for at most 5 seconds) until the new loop is actually running.
    ``Thread.start()`` can return before the thread has entered
    ``run_forever()``, so without this wait an immediate
    ``loop.is_running()`` check in another function could still see False.
    """
    global _mcp_loop, _mcp_thread
    with _lock:
        if _mcp_loop is not None and _mcp_loop.is_running():
            return

        loop = asyncio.new_event_loop()
        started = threading.Event()

        def _serve():
            # The callback fires from inside run_forever(), i.e. only once
            # the loop reports is_running() == True.
            loop.call_soon(started.set)
            loop.run_forever()

        thread = threading.Thread(
            target=_serve,
            name="mcp-event-loop",
            daemon=True,
        )
        thread.start()
        if not started.wait(timeout=5):
            logger.warning("MCP event loop did not start within 5 seconds")

        _mcp_loop = loop
        _mcp_thread = thread
|
||||
|
||||
|
||||
def _run_on_mcp_loop(coro, timeout: float = 30):
    """Schedule a coroutine on the MCP event loop and block until done.

    Args:
        coro: Coroutine object to run on the background loop.
        timeout: Maximum number of seconds to wait for the result.

    Returns:
        Whatever *coro* returns.

    Raises:
        RuntimeError: if the MCP event loop is not running.
        Exception: a timeout error from ``future.result()``, or any
            exception raised by *coro* itself.
    """
    with _lock:
        loop = _mcp_loop
    if loop is None or not loop.is_running():
        # The coroutine will never be scheduled; close it so Python does not
        # emit a "coroutine ... was never awaited" RuntimeWarning.
        close = getattr(coro, "close", None)
        if close is not None:
            close()
        raise RuntimeError("MCP event loop is not running")

    future = asyncio.run_coroutine_threadsafe(coro, loop)
    try:
        return future.result(timeout=timeout)
    finally:
        # Still pending here means we timed out or were interrupted: cancel
        # so the abandoned coroutine does not keep running on the loop.
        if not future.done():
            future.cancel()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _load_mcp_config() -> Dict[str, dict]:
    """Fetch the ``mcp_servers`` section of the Hermes configuration.

    Returns:
        The configured ``{server_name: {command, args, env}}`` mapping. An
        empty dict is returned when the section is absent, empty, not a
        dict, or when loading the configuration fails for any reason.
    """
    try:
        from hermes_cli.config import load_config

        section = load_config().get("mcp_servers")
        if section and isinstance(section, dict):
            return section
        return {}
    except Exception as exc:
        # Best effort: MCP support is optional, so failures are only logged.
        logger.debug("Failed to load MCP config: %s", exc)
        return {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Server connection helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _connect_server(name: str, config: dict) -> MCPServerTask:
    """Build an ``MCPServerTask`` for *name* and wait for it to come up.

    ``MCPServerTask.start()`` spawns the long-lived server Task and returns
    once the connection attempt has finished. To tear the connection down,
    await ``shutdown()`` on the returned object from the same event loop.

    Raises:
        ValueError: if ``command`` is missing from *config*.
        Exception: on connection or initialization failure.
    """
    task = MCPServerTask(name)
    await task.start(config)
    return task
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Handler / check-fn factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_tool_handler(server_name: str, tool_name: str):
|
||||
"""Return a sync handler that calls an MCP tool via the background loop.
|
||||
|
||||
The handler conforms to the registry's dispatch interface:
|
||||
``handler(args_dict, **kwargs) -> str``
|
||||
"""
|
||||
|
||||
def _handler(args: dict, **kwargs) -> str:
|
||||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
})
|
||||
|
||||
async def _call():
|
||||
result = await server.session.call_tool(tool_name, arguments=args)
|
||||
# MCP CallToolResult has .content (list of content blocks) and .isError
|
||||
if result.isError:
|
||||
error_text = ""
|
||||
for block in (result.content or []):
|
||||
if hasattr(block, "text"):
|
||||
error_text += block.text
|
||||
return json.dumps({"error": error_text or "MCP tool returned an error"})
|
||||
|
||||
# Collect text from content blocks
|
||||
parts: List[str] = []
|
||||
for block in (result.content or []):
|
||||
if hasattr(block, "text"):
|
||||
parts.append(block.text)
|
||||
return json.dumps({"result": "\n".join(parts) if parts else ""})
|
||||
|
||||
try:
|
||||
return _run_on_mcp_loop(_call(), timeout=120)
|
||||
except Exception as exc:
|
||||
logger.error("MCP tool %s/%s call failed: %s", server_name, tool_name, exc)
|
||||
return json.dumps({"error": f"MCP call failed: {type(exc).__name__}: {exc}"})
|
||||
|
||||
return _handler
|
||||
|
||||
|
||||
def _make_check_fn(server_name: str):
|
||||
"""Return a check function that verifies the MCP connection is alive."""
|
||||
|
||||
def _check() -> bool:
|
||||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
return server is not None and server.session is not None
|
||||
|
||||
return _check
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discovery & registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _convert_mcp_schema(server_name: str, mcp_tool) -> dict:
|
||||
"""Convert an MCP tool listing to the Hermes registry schema format.
|
||||
|
||||
Args:
|
||||
server_name: The logical server name for prefixing.
|
||||
mcp_tool: An MCP ``Tool`` object with ``.name``, ``.description``,
|
||||
and ``.inputSchema``.
|
||||
|
||||
Returns:
|
||||
A dict suitable for ``registry.register(schema=...)``.
|
||||
"""
|
||||
# Sanitize: replace hyphens and dots with underscores for LLM API compatibility
|
||||
safe_tool_name = mcp_tool.name.replace("-", "_").replace(".", "_")
|
||||
safe_server_name = server_name.replace("-", "_").replace(".", "_")
|
||||
prefixed_name = f"mcp_{safe_server_name}_{safe_tool_name}"
|
||||
return {
|
||||
"name": prefixed_name,
|
||||
"description": mcp_tool.description or f"MCP tool {mcp_tool.name} from {server_name}",
|
||||
"parameters": mcp_tool.inputSchema if mcp_tool.inputSchema else {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _existing_tool_names() -> List[str]:
|
||||
"""Return tool names for all currently connected servers."""
|
||||
names: List[str] = []
|
||||
for sname, server in _servers.items():
|
||||
for mcp_tool in server._tools:
|
||||
schema = _convert_mcp_schema(sname, mcp_tool)
|
||||
names.append(schema["name"])
|
||||
return names
|
||||
|
||||
|
||||
async def _discover_and_register_server(name: str, config: dict) -> List[str]:
    """Bring up one MCP server and expose its tools through the registry.

    The connected server is stored in ``_servers`` under *name*. Each tool it
    advertises is registered in the ``mcp-<name>`` toolset, and that toolset
    is created when at least one tool was registered.

    Returns:
        The prefixed names of the tools that were registered.
    """
    from tools.registry import registry
    from toolsets import create_custom_toolset

    connection = await _connect_server(name, config)
    with _lock:
        _servers[name] = connection

    toolset_name = f"mcp-{name}"
    registered_names: List[str] = []

    for remote_tool in connection._tools:
        spec = _convert_mcp_schema(name, remote_tool)
        public_name = spec["name"]
        registry.register(
            name=public_name,
            toolset=toolset_name,
            schema=spec,
            # The handler talks to the server using the tool's original
            # (unsanitized) name.
            handler=_make_tool_handler(name, remote_tool.name),
            check_fn=_make_check_fn(name),
            is_async=False,
            description=spec["description"],
        )
        registered_names.append(public_name)

    # Create a custom toolset so these tools are discoverable
    if registered_names:
        create_custom_toolset(
            name=toolset_name,
            description=f"MCP tools from {name} server",
            tools=registered_names,
        )

    logger.info(
        "MCP server '%s': registered %d tool(s): %s",
        name, len(registered_names), ", ".join(registered_names),
    )
    return registered_names
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def discover_mcp_tools() -> List[str]:
    """Entry point: load config, connect to MCP servers, register tools.

    Called from ``model_tools._discover_tools()``. Safe to call even when
    the ``mcp`` package is not installed (returns empty list).

    Idempotent for already-connected servers. If some servers failed on a
    previous call, only the missing ones are retried.

    A failure of the discovery run as a whole (for example a timeout) is
    logged instead of raised, so tools from servers that did connect are
    still injected and returned.

    Returns:
        List of all registered MCP tool names.
    """
    if not _MCP_AVAILABLE:
        logger.debug("MCP SDK not available -- skipping MCP tool discovery")
        return []

    servers = _load_mcp_config()
    if not servers:
        logger.debug("No MCP servers configured")
        return []

    # Only attempt servers that aren't already connected
    with _lock:
        new_servers = {k: v for k, v in servers.items() if k not in _servers}

    if not new_servers:
        return _existing_tool_names()

    # Start the background event loop for MCP connections
    _ensure_mcp_loop()

    all_tools: List[str] = []

    async def _discover_all():
        for name, cfg in new_servers.items():
            try:
                registered = await _discover_and_register_server(name, cfg)
                all_tools.extend(registered)
            except Exception as exc:
                logger.warning("Failed to connect to MCP server '%s': %s", name, exc)

    try:
        _run_on_mcp_loop(_discover_all(), timeout=60)
    except Exception as exc:
        # Graceful degradation: one slow or broken server must not take the
        # caller down. str() of a timeout error is empty, so log the type.
        logger.warning(
            "MCP tool discovery did not complete: %s: %s",
            type(exc).__name__, exc,
        )

    # Snapshot: after a timeout the loop thread may still append to all_tools.
    discovered = list(all_tools)

    if discovered:
        # Dynamically inject into all hermes-* platform toolsets
        from toolsets import TOOLSETS
        for ts_name, ts in TOOLSETS.items():
            if ts_name.startswith("hermes-"):
                for tool_name in discovered:
                    if tool_name not in ts["tools"]:
                        ts["tools"].append(tool_name)

    # Return ALL registered tools (existing + newly discovered)
    return _existing_tool_names()
|
||||
|
||||
|
||||
def shutdown_mcp_servers():
    """Close all MCP server connections and stop the background loop.

    Each server Task is signalled to exit its ``async with`` block so that
    the anyio cancel-scope cleanup happens in the same Task that opened it.
    All servers are shut down in parallel via ``asyncio.gather``.

    ``_servers`` is always emptied before the loop is stopped -- also when
    the loop is not running or the shutdown times out -- so a later
    ``discover_mcp_tools()`` call reconnects instead of treating dead
    servers as still connected.
    """
    # Take one consistent snapshot. _lock is released before anything below
    # runs because _stop_mcp_loop() acquires it as well.
    with _lock:
        servers_snapshot = list(_servers.values())
        loop = _mcp_loop

    if servers_snapshot and loop is not None and loop.is_running():

        async def _shutdown():
            results = await asyncio.gather(
                *(server.shutdown() for server in servers_snapshot),
                return_exceptions=True,
            )
            for server, result in zip(servers_snapshot, results):
                if isinstance(result, Exception):
                    logger.debug(
                        "Error closing MCP server '%s': %s", server.name, result,
                    )

        try:
            future = asyncio.run_coroutine_threadsafe(_shutdown(), loop)
            future.result(timeout=15)
        except Exception as exc:
            logger.debug("Error during MCP shutdown: %s", exc)

    with _lock:
        _servers.clear()

    _stop_mcp_loop()
|
||||
|
||||
|
||||
def _stop_mcp_loop():
    """Stop the background event loop and join its thread.

    The module-level loop and thread references are cleared first, under
    ``_lock``; the stop and join then happen outside the lock. The loop is
    closed only once it has really stopped, because calling ``close()`` on
    a running loop raises ``RuntimeError``.
    """
    global _mcp_loop, _mcp_thread
    with _lock:
        loop = _mcp_loop
        thread = _mcp_thread
        _mcp_loop = None
        _mcp_thread = None

    if loop is None or loop.is_closed():
        return

    loop.call_soon_threadsafe(loop.stop)
    if thread is not None:
        thread.join(timeout=5)

    if loop.is_running():
        # The join timed out. The thread is a daemon, so leaving the loop
        # open is safer than raising from close().
        logger.warning(
            "MCP event loop did not stop within 5 seconds; leaving it open"
        )
        return
    loop.close()
|
||||
Loading…
Add table
Add a link
Reference in a new issue