From 3c252ae44b524ed20861e681b14c4d66a6fb4bdf Mon Sep 17 00:00:00 2001 From: 0xbyt4 <35742124+0xbyt4@users.noreply.github.com> Date: Mon, 2 Mar 2026 21:03:14 +0300 Subject: [PATCH 1/6] feat: add MCP (Model Context Protocol) client support Connect to external MCP servers via stdio transport, discover their tools at startup, and register them into the hermes-agent tool registry. - New tools/mcp_tool.py: config loading, server connection via background event loop, tool handler factories, discovery, and graceful shutdown - model_tools.py: trigger MCP discovery after built-in tool imports - cli.py: call shutdown_mcp_servers in _run_cleanup - pyproject.toml: add mcp>=1.2.0 as optional dependency - 27 unit tests covering config, schema conversion, handlers, registration, SDK interaction, toolset injection, graceful fallback, and shutdown Config format (in ~/.hermes/config.yaml): mcp_servers: filesystem: command: "npx" args: ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"] --- cli.py | 5 + model_tools.py | 7 + pyproject.toml | 2 + tests/tools/test_mcp_tool.py | 588 +++++++++++++++++++++++++++++++++++ tools/mcp_tool.py | 380 ++++++++++++++++++++++ uv.lock | 82 ++++- 6 files changed, 1063 insertions(+), 1 deletion(-) create mode 100644 tests/tools/test_mcp_tool.py create mode 100644 tools/mcp_tool.py diff --git a/cli.py b/cli.py index faa6586d..a2519460 100755 --- a/cli.py +++ b/cli.py @@ -386,6 +386,11 @@ def _run_cleanup(): _cleanup_all_browsers() except Exception: pass + try: + from tools.mcp_tool import shutdown_mcp_servers + shutdown_mcp_servers() + except Exception: + pass # ============================================================================ # ASCII Art & Branding diff --git a/model_tools.py b/model_tools.py index 036bb34b..8da3d67e 100644 --- a/model_tools.py +++ b/model_tools.py @@ -105,6 +105,13 @@ def _discover_tools(): _discover_tools() +# MCP tool discovery (external MCP servers from config) +try: + from tools.mcp_tool import discover_mcp_tools + discover_mcp_tools() +except Exception as e: + logger.debug("MCP tool discovery failed: %s", e) + # ============================================================================= # Backward-compat constants (built once after discovery) diff --git a/pyproject.toml b/pyproject.toml index 152b4730..2f241b3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ cli = ["simple-term-menu"] tts-premium = ["elevenlabs"] pty = ["ptyprocess>=0.7.0"] honcho = ["honcho-ai>=2.0.1"] +mcp = ["mcp>=1.2.0"] all = [ "hermes-agent[modal]", "hermes-agent[messaging]", @@ -57,6 +58,7 @@ all = [ "hermes-agent[slack]", "hermes-agent[pty]", "hermes-agent[honcho]", + "hermes-agent[mcp]", ] [project.scripts] diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py new file mode 100644 index 00000000..caaffd48 --- /dev/null +++ b/tests/tools/test_mcp_tool.py @@ -0,0 +1,588 @@ +"""Tests for the MCP (Model Context Protocol) client support. + +All tests use mocks -- no real MCP servers or subprocesses are started. +""" + +import asyncio +import json +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_mcp_tool(name="read_file", description="Read a file", input_schema=None): + """Create a fake MCP Tool object matching the SDK interface.""" + tool = SimpleNamespace() + tool.name = name + tool.description = description + tool.inputSchema = input_schema or { + "type": "object", + "properties": { + "path": {"type": "string", "description": "File path"}, + }, + "required": ["path"], + } + return tool + + +def _make_call_result(text="file contents here", is_error=False): + """Create a fake MCP CallToolResult.""" + block = SimpleNamespace(text=text) + return SimpleNamespace(content=[block], isError=is_error) + + +# --------------------------------------------------------------------------- +# Config loading +# --------------------------------------------------------------------------- + +class TestLoadMCPConfig: + def test_no_config_returns_empty(self): + """No mcp_servers key in config -> empty dict.""" + with patch("tools.mcp_tool.load_config", create=True) as mock_lc: + # Patch the actual import inside the function + with patch("hermes_cli.config.load_config", return_value={"model": "test"}): + from tools.mcp_tool import _load_mcp_config + result = _load_mcp_config() + assert result == {} + + def test_valid_config_parsed(self): + """Valid mcp_servers config is returned as-is.""" + servers = { + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"], + "env": {}, + } + } + with patch("hermes_cli.config.load_config", return_value={"mcp_servers": servers}): + from tools.mcp_tool import _load_mcp_config + result = _load_mcp_config() + assert "filesystem" in result + assert result["filesystem"]["command"] == "npx" + + def test_mcp_servers_not_dict_returns_empty(self): + """mcp_servers set to non-dict value -> empty dict.""" + with patch("hermes_cli.config.load_config", return_value={"mcp_servers": "invalid"}): + from tools.mcp_tool import _load_mcp_config + result = _load_mcp_config() + assert result == {} + + +# --------------------------------------------------------------------------- +# Schema conversion +# --------------------------------------------------------------------------- + +class TestSchemaConversion: + def test_converts_mcp_tool_to_hermes_schema(self): + from tools.mcp_tool import _convert_mcp_schema + + mcp_tool = _make_mcp_tool(name="read_file", description="Read a file") + schema = _convert_mcp_schema("filesystem", mcp_tool) + + assert schema["name"] == "mcp_filesystem_read_file" + assert schema["description"] == "Read a file" + assert "properties" in schema["parameters"] + + def test_empty_input_schema_gets_default(self): + from tools.mcp_tool import _convert_mcp_schema + + mcp_tool = _make_mcp_tool(name="ping", description="Ping", input_schema=None) + mcp_tool.inputSchema = None + schema = _convert_mcp_schema("test", mcp_tool) + + assert schema["parameters"]["type"] == "object" + assert schema["parameters"]["properties"] == {} + + def test_tool_name_prefix_format(self): + from tools.mcp_tool import _convert_mcp_schema + + mcp_tool = _make_mcp_tool(name="list_dir") + schema = _convert_mcp_schema("my_server", mcp_tool) + + assert schema["name"] == "mcp_my_server_list_dir" + + def test_hyphens_sanitized_to_underscores(self): + """Hyphens in tool/server names are replaced with underscores for LLM compat.""" + from tools.mcp_tool import _convert_mcp_schema + + mcp_tool = _make_mcp_tool(name="get-sum") + schema = _convert_mcp_schema("my-server", mcp_tool) + + assert schema["name"] == "mcp_my_server_get_sum" + assert "-" not in schema["name"] + + +# --------------------------------------------------------------------------- +# Check function +# --------------------------------------------------------------------------- + +class TestCheckFunction: + def test_disconnected_returns_false(self): + from tools.mcp_tool import _make_check_fn, _connections + + # Ensure no connection exists + _connections.pop("test_server", None) + check = _make_check_fn("test_server") + assert check() is False + + def test_connected_returns_true(self): + from tools.mcp_tool import _make_check_fn, _connections, MCPConnection + + conn = MCPConnection( + server_name="test_server", + session=MagicMock(), + stack=MagicMock(), + ) + _connections["test_server"] = conn + try: + check = _make_check_fn("test_server") + assert check() is True + finally: + _connections.pop("test_server", None) + + def test_session_none_returns_false(self): + from tools.mcp_tool import _make_check_fn, _connections, MCPConnection + + conn = MCPConnection( + server_name="test_server", + session=None, + stack=MagicMock(), + ) + _connections["test_server"] = conn + try: + check = _make_check_fn("test_server") + assert check() is False + finally: + _connections.pop("test_server", None) + + +# --------------------------------------------------------------------------- +# Tool handler (async) +# --------------------------------------------------------------------------- + +class TestToolHandler: + """Tool handlers are sync functions that schedule work on the MCP loop.""" + + def _patch_mcp_loop(self, coro_side_effect=None): + """Return a patch for _run_on_mcp_loop that runs the coroutine directly.""" + def fake_run(coro, timeout=30): + return asyncio.get_event_loop().run_until_complete(coro) + if coro_side_effect: + return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=coro_side_effect) + return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run) + + def test_successful_call(self): + from tools.mcp_tool import _make_tool_handler, _connections, MCPConnection + + mock_session = MagicMock() + mock_session.call_tool = AsyncMock( + return_value=_make_call_result("hello world", is_error=False) + ) + conn = MCPConnection("test_srv", session=mock_session, stack=MagicMock()) + _connections["test_srv"] = conn + + try: + handler = _make_tool_handler("test_srv", "greet") + with self._patch_mcp_loop(): + result = json.loads(handler({"name": "world"})) + assert result["result"] == "hello world" + mock_session.call_tool.assert_called_once_with("greet", arguments={"name": "world"}) + finally: + _connections.pop("test_srv", None) + + def test_mcp_error_result(self): + from tools.mcp_tool import _make_tool_handler, _connections, MCPConnection + + mock_session = MagicMock() + mock_session.call_tool = AsyncMock( + return_value=_make_call_result("something went wrong", is_error=True) + ) + conn = MCPConnection("test_srv", session=mock_session, stack=MagicMock()) + _connections["test_srv"] = conn + + try: + handler = _make_tool_handler("test_srv", "fail_tool") + with self._patch_mcp_loop(): + result = json.loads(handler({})) + assert "error" in result + assert "something went wrong" in result["error"] + finally: + _connections.pop("test_srv", None) + + def test_disconnected_server(self): + from tools.mcp_tool import _make_tool_handler, _connections + + _connections.pop("ghost", None) + handler = _make_tool_handler("ghost", "any_tool") + # Disconnected check happens before _run_on_mcp_loop, no patch needed + result = json.loads(handler({})) + assert "error" in result + assert "not connected" in result["error"] + + def test_exception_during_call(self): + from tools.mcp_tool import _make_tool_handler, _connections, MCPConnection + + mock_session = MagicMock() + mock_session.call_tool = AsyncMock(side_effect=RuntimeError("connection lost")) + conn = MCPConnection("test_srv", session=mock_session, stack=MagicMock()) + _connections["test_srv"] = conn + + try: + handler = _make_tool_handler("test_srv", "broken_tool") + with self._patch_mcp_loop(): + result = json.loads(handler({})) + assert "error" in result + assert "connection lost" in result["error"] + finally: + _connections.pop("test_srv", None) + + +# --------------------------------------------------------------------------- +# Tool registration (discovery + register) +# --------------------------------------------------------------------------- + +class TestDiscoverAndRegister: + def test_tools_registered_in_registry(self): + """_discover_and_register_server registers tools with correct names.""" + from tools.registry import ToolRegistry, registry as real_registry + from tools.mcp_tool import _discover_and_register_server, _connections, MCPConnection + + mock_registry = ToolRegistry() + mock_tools = [ + _make_mcp_tool("read_file", "Read a file"), + _make_mcp_tool("write_file", "Write a file"), + ] + + mock_session = MagicMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock( + return_value=SimpleNamespace(tools=mock_tools) + ) + + async def fake_connect(name, config): + return MCPConnection(name, session=mock_session, stack=MagicMock()) + + with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ + patch("tools.registry.registry", mock_registry): + registered = asyncio.run( + _discover_and_register_server("fs", {"command": "npx", "args": []}) + ) + + assert "mcp_fs_read_file" in registered + assert "mcp_fs_write_file" in registered + assert "mcp_fs_read_file" in mock_registry.get_all_tool_names() + assert "mcp_fs_write_file" in mock_registry.get_all_tool_names() + + _connections.pop("fs", None) + + def test_toolset_created(self): + """A custom toolset is created for the MCP server.""" + from tools.mcp_tool import _discover_and_register_server, _connections, MCPConnection + + mock_tools = [_make_mcp_tool("ping", "Ping")] + + mock_session = MagicMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock( + return_value=SimpleNamespace(tools=mock_tools) + ) + + async def fake_connect(name, config): + return MCPConnection(name, session=mock_session, stack=MagicMock()) + + mock_create = MagicMock() + with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ + patch("toolsets.create_custom_toolset", mock_create): + asyncio.run( + _discover_and_register_server("myserver", {"command": "test"}) + ) + + mock_create.assert_called_once() + call_kwargs = mock_create.call_args + assert call_kwargs[1]["name"] == "mcp-myserver" or call_kwargs[0][0] == "mcp-myserver" + + _connections.pop("myserver", None) + + def test_schema_format_correct(self): + """Registered schemas have the correct format.""" + from tools.registry import ToolRegistry, registry as real_registry + from tools.mcp_tool import _discover_and_register_server, _connections, MCPConnection + + mock_registry = ToolRegistry() + mock_tools = [_make_mcp_tool("do_thing", "Do something")] + + mock_session = MagicMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock( + return_value=SimpleNamespace(tools=mock_tools) + ) + + async def fake_connect(name, config): + return MCPConnection(name, session=mock_session, stack=MagicMock()) + + with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ + patch("tools.registry.registry", mock_registry): + asyncio.run( + _discover_and_register_server("srv", {"command": "test"}) + ) + + entry = mock_registry._tools.get("mcp_srv_do_thing") + assert entry is not None + assert entry.schema["name"] == "mcp_srv_do_thing" + assert "parameters" in entry.schema + assert entry.is_async is False + assert entry.toolset == "mcp-srv" + + _connections.pop("srv", None) + + +# --------------------------------------------------------------------------- +# _connect_server (SDK interaction) +# --------------------------------------------------------------------------- + +class TestConnectServer: + def test_calls_sdk_with_correct_params(self): + """_connect_server creates StdioServerParameters and calls stdio_client.""" + from tools.mcp_tool import _connect_server, MCPConnection + + mock_session = MagicMock() + mock_session.initialize = AsyncMock() + + mock_read = MagicMock() + mock_write = MagicMock() + + with patch("tools.mcp_tool.StdioServerParameters") as mock_params, \ + patch("tools.mcp_tool.stdio_client") as mock_stdio, \ + patch("tools.mcp_tool.ClientSession") as mock_cs, \ + patch("tools.mcp_tool.AsyncExitStack") as mock_stack_cls: + + mock_stack = MagicMock() + mock_stack.enter_async_context = AsyncMock( + side_effect=[(mock_read, mock_write), mock_session] + ) + mock_stack_cls.return_value = mock_stack + + conn = asyncio.run(_connect_server("test_srv", { + "command": "npx", + "args": ["-y", "some-server"], + "env": {"MY_KEY": "secret"}, + })) + + # StdioServerParameters called with correct values + mock_params.assert_called_once_with( + command="npx", + args=["-y", "some-server"], + env={"MY_KEY": "secret"}, + ) + # ClientSession created with the streams + mock_cs.assert_called_once_with(mock_read, mock_write) + # initialize() was called + mock_session.initialize.assert_called_once() + # Returned connection is valid + assert conn.server_name == "test_srv" + assert conn.session is mock_session + + def test_no_command_raises(self): + """Missing 'command' in config raises ValueError.""" + from tools.mcp_tool import _connect_server + + with pytest.raises(ValueError, match="no 'command'"): + asyncio.run(_connect_server("bad", {"args": []})) + + def test_empty_env_passed_as_none(self): + """Empty env dict is passed as None to StdioServerParameters.""" + from tools.mcp_tool import _connect_server + + mock_session = MagicMock() + mock_session.initialize = AsyncMock() + + with patch("tools.mcp_tool.StdioServerParameters") as mock_params, \ + patch("tools.mcp_tool.stdio_client"), \ + patch("tools.mcp_tool.ClientSession", return_value=mock_session), \ + patch("tools.mcp_tool.AsyncExitStack") as mock_stack_cls: + + mock_stack = MagicMock() + mock_stack.enter_async_context = AsyncMock( + side_effect=[ + (MagicMock(), MagicMock()), + mock_session, + ] + ) + mock_stack_cls.return_value = mock_stack + + asyncio.run(_connect_server("srv", { + "command": "node", + "env": {}, + })) + + # Empty dict -> None + assert mock_params.call_args[1]["env"] is None or \ + mock_params.call_args.kwargs.get("env") is None + + +# --------------------------------------------------------------------------- +# discover_mcp_tools toolset injection +# --------------------------------------------------------------------------- + +class TestToolsetInjection: + def test_mcp_tools_added_to_platform_toolsets(self): + """Discovered MCP tools are injected into hermes-cli and platform toolsets.""" + from tools.mcp_tool import _connections, MCPConnection + + mock_tools = [_make_mcp_tool("list_files", "List files")] + mock_session = MagicMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock( + return_value=SimpleNamespace(tools=mock_tools) + ) + + async def fake_connect(name, config): + return MCPConnection(name, session=mock_session, stack=MagicMock()) + + fake_toolsets = { + "hermes-cli": {"tools": ["terminal", "web_search"], "description": "CLI", "includes": []}, + "hermes-telegram": {"tools": ["terminal"], "description": "Telegram", "includes": []}, + } + fake_config = { + "fs": {"command": "npx", "args": []}, + } + + with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ + patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \ + patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ + patch("tools.mcp_tool.TOOLSETS", fake_toolsets, create=True), \ + patch("toolsets.TOOLSETS", fake_toolsets): + from tools.mcp_tool import discover_mcp_tools + result = discover_mcp_tools() + + assert "mcp_fs_list_files" in result + assert "mcp_fs_list_files" in fake_toolsets["hermes-cli"]["tools"] + assert "mcp_fs_list_files" in fake_toolsets["hermes-telegram"]["tools"] + # Original tools preserved + assert "terminal" in fake_toolsets["hermes-cli"]["tools"] + + _connections.pop("fs", None) + + def test_server_connection_failure_skipped(self): + """If one server fails to connect, others still proceed.""" + from tools.mcp_tool import _connections, MCPConnection + + mock_tools = [_make_mcp_tool("ping", "Ping")] + mock_session = MagicMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock( + return_value=SimpleNamespace(tools=mock_tools) + ) + + call_count = 0 + + async def flaky_connect(name, config): + nonlocal call_count + call_count += 1 + if name == "broken": + raise ConnectionError("cannot reach server") + return MCPConnection(name, session=mock_session, stack=MagicMock()) + + fake_config = { + "broken": {"command": "bad"}, + "good": {"command": "npx", "args": []}, + } + fake_toolsets = { + "hermes-cli": {"tools": [], "description": "CLI", "includes": []}, + } + + with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ + patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \ + patch("tools.mcp_tool._connect_server", side_effect=flaky_connect), \ + patch("toolsets.TOOLSETS", fake_toolsets): + from tools.mcp_tool import discover_mcp_tools + result = discover_mcp_tools() + + # Only good server's tool registered + assert "mcp_good_ping" in result + assert "mcp_broken_ping" not in result + assert call_count == 2 # Both were attempted + + _connections.pop("good", None) + + +# --------------------------------------------------------------------------- +# Graceful fallback +# --------------------------------------------------------------------------- + +class TestGracefulFallback: + def test_mcp_unavailable_returns_empty(self): + """When _MCP_AVAILABLE is False, discover_mcp_tools is a no-op.""" + with patch("tools.mcp_tool._MCP_AVAILABLE", False): + from tools.mcp_tool import discover_mcp_tools + result = discover_mcp_tools() + assert result == [] + + def test_no_servers_returns_empty(self): + """No MCP servers configured -> empty list.""" + with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ + patch("tools.mcp_tool._load_mcp_config", return_value={}): + from tools.mcp_tool import discover_mcp_tools + result = discover_mcp_tools() + assert result == [] + + +# --------------------------------------------------------------------------- +# Shutdown +# --------------------------------------------------------------------------- + +class TestShutdown: + def test_no_connections_safe(self): + """shutdown_mcp_servers with no connections does nothing.""" + from tools.mcp_tool import shutdown_mcp_servers, _connections + + _connections.clear() + shutdown_mcp_servers() # Should not raise + + def test_shutdown_clears_connections(self): + """shutdown_mcp_servers closes stacks and clears the dict.""" + import tools.mcp_tool as mcp_mod + from tools.mcp_tool import shutdown_mcp_servers, _connections, MCPConnection + + _connections.clear() + mock_stack = MagicMock() + mock_stack.aclose = AsyncMock() + conn = MCPConnection("test", session=MagicMock(), stack=mock_stack) + _connections["test"] = conn + + # Start a real background loop so shutdown can schedule on it + mcp_mod._ensure_mcp_loop() + try: + shutdown_mcp_servers() + finally: + # _stop_mcp_loop is called by shutdown, but ensure cleanup + mcp_mod._mcp_loop = None + mcp_mod._mcp_thread = None + + assert len(_connections) == 0 + mock_stack.aclose.assert_called_once() + + def test_shutdown_handles_errors(self): + """shutdown_mcp_servers handles errors during close gracefully.""" + import tools.mcp_tool as mcp_mod + from tools.mcp_tool import shutdown_mcp_servers, _connections, MCPConnection + + _connections.clear() + mock_stack = MagicMock() + mock_stack.aclose = AsyncMock(side_effect=RuntimeError("close failed")) + conn = MCPConnection("broken", session=MagicMock(), stack=mock_stack) + _connections["broken"] = conn + + mcp_mod._ensure_mcp_loop() + try: + shutdown_mcp_servers() # Should not raise + finally: + mcp_mod._mcp_loop = None + mcp_mod._mcp_thread = None + + assert len(_connections) == 0 diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py new file mode 100644 index 00000000..eecbaa29 --- /dev/null +++ b/tools/mcp_tool.py @@ -0,0 +1,380 @@ +#!/usr/bin/env python3 +""" +MCP (Model Context Protocol) Client Support + +Connects to external MCP servers via stdio transport, discovers their tools, +and registers them into the hermes-agent tool registry so the agent can call +them like any built-in tool. + +Configuration is read from ~/.hermes/config.yaml under the ``mcp_servers`` key. +The ``mcp`` Python package is optional -- if not installed, this module is a +no-op and logs a debug message. + +Example config:: + + mcp_servers: + filesystem: + command: "npx" + args: ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"] + env: {} + github: + command: "npx" + args: ["-y", "@modelcontextprotocol/server-github"] + env: + GITHUB_PERSONAL_ACCESS_TOKEN: "ghp_..." + +Architecture: + A dedicated background event loop (_mcp_loop) runs in a daemon thread. + All MCP connections live on this loop. Tool handlers schedule coroutines + onto it via run_coroutine_threadsafe(), so they work from any thread. +""" + +import asyncio +import json +import logging +import threading +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Graceful import -- MCP SDK is an optional dependency +# --------------------------------------------------------------------------- + +_MCP_AVAILABLE = False +try: + from mcp import ClientSession, StdioServerParameters + from mcp.client.stdio import stdio_client + from contextlib import AsyncExitStack + _MCP_AVAILABLE = True +except ImportError: + logger.debug("mcp package not installed -- MCP tool support disabled") + + +# --------------------------------------------------------------------------- +# Connection tracking +# --------------------------------------------------------------------------- + +class MCPConnection: + """Holds a live MCP server connection and its async resource stack.""" + + __slots__ = ("server_name", "session", "stack") + + def __init__(self, server_name: str, session: Any, stack: Any): + self.server_name = server_name + self.session: Optional[Any] = session + self.stack: Optional[Any] = stack + + +_connections: Dict[str, MCPConnection] = {} + +# Dedicated event loop running in a background daemon thread. +# All MCP async operations (connect, call_tool, shutdown) run here. +_mcp_loop: Optional[asyncio.AbstractEventLoop] = None +_mcp_thread: Optional[threading.Thread] = None + + +def _ensure_mcp_loop(): + """Start the background event loop thread if not already running.""" + global _mcp_loop, _mcp_thread + if _mcp_loop is not None and _mcp_loop.is_running(): + return + _mcp_loop = asyncio.new_event_loop() + _mcp_thread = threading.Thread( + target=_mcp_loop.run_forever, + name="mcp-event-loop", + daemon=True, + ) + _mcp_thread.start() + + +def _run_on_mcp_loop(coro, timeout: float = 30): + """Schedule a coroutine on the MCP event loop and block until done.""" + if _mcp_loop is None or not _mcp_loop.is_running(): + raise RuntimeError("MCP event loop is not running") + future = asyncio.run_coroutine_threadsafe(coro, _mcp_loop) + return future.result(timeout=timeout) + + +# --------------------------------------------------------------------------- +# Config loading +# --------------------------------------------------------------------------- + +def _load_mcp_config() -> Dict[str, dict]: + """Read ``mcp_servers`` from the Hermes config file. + + Returns a dict of ``{server_name: {command, args, env}}`` or empty dict. + """ + try: + from hermes_cli.config import load_config + config = load_config() + servers = config.get("mcp_servers") + if not servers or not isinstance(servers, dict): + return {} + return servers + except Exception as exc: + logger.debug("Failed to load MCP config: %s", exc) + return {} + + +# --------------------------------------------------------------------------- +# Server connection +# --------------------------------------------------------------------------- + +async def _connect_server(name: str, config: dict) -> MCPConnection: + """Start an MCP server subprocess and initialize a ClientSession. + + Args: + name: Logical server name (e.g. "filesystem"). + config: Dict with ``command``, ``args``, and optional ``env``. + + Returns: + An ``MCPConnection`` with a live session. + + Raises: + Exception on connection or initialization failure. + """ + command = config.get("command") + args = config.get("args", []) + env = config.get("env") + + if not command: + raise ValueError(f"MCP server '{name}' has no 'command' in config") + + server_params = StdioServerParameters( + command=command, + args=args, + env=env if env else None, + ) + + stack = AsyncExitStack() + stdio_transport = await stack.enter_async_context(stdio_client(server_params)) + read_stream, write_stream = stdio_transport + session = await stack.enter_async_context(ClientSession(read_stream, write_stream)) + await session.initialize() + + return MCPConnection(server_name=name, session=session, stack=stack) + + +# --------------------------------------------------------------------------- +# Handler / check-fn factories +# --------------------------------------------------------------------------- + +def _make_tool_handler(server_name: str, tool_name: str): + """Return a sync handler that calls an MCP tool via the background loop. + + The handler conforms to the registry's dispatch interface: + ``handler(args_dict, **kwargs) -> str`` + """ + + def _handler(args: dict, **kwargs) -> str: + conn = _connections.get(server_name) + if not conn or not conn.session: + return json.dumps({ + "error": f"MCP server '{server_name}' is not connected" + }) + + async def _call(): + result = await conn.session.call_tool(tool_name, arguments=args) + # MCP CallToolResult has .content (list of content blocks) and .isError + if result.isError: + error_text = "" + for block in (result.content or []): + if hasattr(block, "text"): + error_text += block.text + return json.dumps({"error": error_text or "MCP tool returned an error"}) + + # Collect text from content blocks + parts: List[str] = [] + for block in (result.content or []): + if hasattr(block, "text"): + parts.append(block.text) + return json.dumps({"result": "\n".join(parts) if parts else ""}) + + try: + return _run_on_mcp_loop(_call(), timeout=120) + except Exception as exc: + logger.error("MCP tool %s/%s call failed: %s", server_name, tool_name, exc) + return json.dumps({"error": f"MCP call failed: {type(exc).__name__}: {exc}"}) + + return _handler + + +def _make_check_fn(server_name: str): + """Return a check function that verifies the MCP connection is alive.""" + + def _check() -> bool: + conn = _connections.get(server_name) + return conn is not None and conn.session is not None + + return _check + + +# --------------------------------------------------------------------------- +# Discovery & registration +# --------------------------------------------------------------------------- + +def _convert_mcp_schema(server_name: str, mcp_tool) -> dict: + """Convert an MCP tool listing to the Hermes registry schema format. + + Args: + server_name: The logical server name for prefixing. + mcp_tool: An MCP ``Tool`` object with ``.name``, ``.description``, + and ``.inputSchema``. + + Returns: + A dict suitable for ``registry.register(schema=...)``. + """ + # Sanitize: replace hyphens and dots with underscores for LLM API compatibility + safe_tool_name = mcp_tool.name.replace("-", "_").replace(".", "_") + safe_server_name = server_name.replace("-", "_").replace(".", "_") + prefixed_name = f"mcp_{safe_server_name}_{safe_tool_name}" + return { + "name": prefixed_name, + "description": mcp_tool.description or f"MCP tool {mcp_tool.name} from {server_name}", + "parameters": mcp_tool.inputSchema if mcp_tool.inputSchema else { + "type": "object", + "properties": {}, + }, + } + + +async def _discover_and_register_server(name: str, config: dict) -> List[str]: + """Connect to a single MCP server, discover tools, and register them. + + Returns list of registered tool names. + """ + from tools.registry import registry + from toolsets import create_custom_toolset + + conn = await _connect_server(name, config) + _connections[name] = conn + + # Discover tools + tools_result = await conn.session.list_tools() + tools = tools_result.tools if hasattr(tools_result, "tools") else [] + + registered_names: List[str] = [] + toolset_name = f"mcp-{name}" + + for mcp_tool in tools: + schema = _convert_mcp_schema(name, mcp_tool) + tool_name_prefixed = schema["name"] + + registry.register( + name=tool_name_prefixed, + toolset=toolset_name, + schema=schema, + handler=_make_tool_handler(name, mcp_tool.name), + check_fn=_make_check_fn(name), + is_async=False, + description=schema["description"], + ) + registered_names.append(tool_name_prefixed) + + # Create a custom toolset so these tools are discoverable + if registered_names: + create_custom_toolset( + name=toolset_name, + description=f"MCP tools from {name} server", + tools=registered_names, + ) + + logger.info( + "MCP server '%s': registered %d tool(s): %s", + name, len(registered_names), ", ".join(registered_names), + ) + return registered_names + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def discover_mcp_tools() -> List[str]: + """Entry point: load config, connect to MCP servers, register tools. + + Called from ``model_tools._discover_tools()``. Safe to call even when + the ``mcp`` package is not installed (returns empty list). + + Returns: + List of all registered MCP tool names. + """ + if not _MCP_AVAILABLE: + logger.debug("MCP SDK not available -- skipping MCP tool discovery") + return [] + + servers = _load_mcp_config() + if not servers: + logger.debug("No MCP servers configured") + return [] + + # Start the background event loop for MCP connections + _ensure_mcp_loop() + + all_tools: List[str] = [] + + async def _discover_all(): + for name, cfg in servers.items(): + try: + registered = await _discover_and_register_server(name, cfg) + all_tools.extend(registered) + except Exception as exc: + logger.warning("Failed to connect to MCP server '%s': %s", name, exc) + + _run_on_mcp_loop(_discover_all(), timeout=60) + + if all_tools: + # Add MCP tools to hermes-cli and other platform toolsets + from toolsets import TOOLSETS + for ts_name in ("hermes-cli", "hermes-telegram", "hermes-discord", + "hermes-whatsapp", "hermes-slack"): + ts = TOOLSETS.get(ts_name) + if ts: + for tool_name in all_tools: + if tool_name not in ts["tools"]: + ts["tools"].append(tool_name) + + return all_tools + + +def shutdown_mcp_servers(): + """Close all MCP server connections and stop the background loop.""" + global _mcp_loop, _mcp_thread + + if not _connections: + _stop_mcp_loop() + return + + async def _shutdown(): + for name, conn in list(_connections.items()): + try: + if conn.stack: + await conn.stack.aclose() + except Exception as exc: + logger.debug("Error closing MCP server '%s': %s", name, exc) + finally: + conn.session = None + conn.stack = None + _connections.clear() + + if _mcp_loop is not None and _mcp_loop.is_running(): + try: + future = asyncio.run_coroutine_threadsafe(_shutdown(), _mcp_loop) + future.result(timeout=10) + except Exception as exc: + logger.debug("Error during MCP shutdown: %s", exc) + + _stop_mcp_loop() + + +def _stop_mcp_loop(): + """Stop the background event loop and join its thread.""" + global _mcp_loop, _mcp_thread + if _mcp_loop is not None: + _mcp_loop.call_soon_threadsafe(_mcp_loop.stop) + if _mcp_thread is not None: + _mcp_thread.join(timeout=5) + _mcp_thread = None + _mcp_loop.close() + _mcp_loop = None diff --git a/uv.lock b/uv.lock index 54863389..a768b72c 100644 --- a/uv.lock +++ b/uv.lock @@ -1015,6 +1015,7 @@ all = [ { name = "discord-py" }, { name = "elevenlabs" }, { name = "honcho-ai" }, + { name = "mcp" }, { name = "ptyprocess" }, { name = "pytest" }, { name = "pytest-asyncio" }, @@ -1037,6 +1038,9 @@ dev = [ honcho = [ { name = "honcho-ai" }, ] +mcp = [ + { name = "mcp" }, +] messaging = [ { name = "aiohttp" }, { name = "discord-py" }, @@ -1072,6 +1076,7 @@ requires-dist = [ { name = "hermes-agent", extras = ["cron"], marker = "extra == 'all'" }, { name = "hermes-agent", extras = ["dev"], marker = "extra == 'all'" }, { name = "hermes-agent", extras = ["honcho"], marker = "extra == 'all'" }, + { name = "hermes-agent", extras = ["mcp"], marker = "extra == 'all'" }, { name = "hermes-agent", extras = ["messaging"], marker = "extra == 'all'" }, { name = "hermes-agent", extras = ["modal"], marker = "extra == 'all'" }, { name = "hermes-agent", extras = ["pty"], marker = "extra == 'all'" }, @@ -1081,6 +1086,7 @@ requires-dist = [ { name = "httpx" }, { name = "jinja2" }, { name = "litellm", specifier = ">=1.75.5" }, + { name = "mcp", marker = "extra == 'mcp'", specifier = ">=1.2.0" }, { name = "openai" }, { name = "platformdirs" }, { name = "prompt-toolkit" }, @@ -1103,7 +1109,7 @@ requires-dist = [ { name = "tenacity" }, { name = "typer" }, ] -provides-extras = ["modal", "dev", "messaging", "cron", "slack", "cli", "tts-premium", "pty", "honcho", "all"] +provides-extras = ["modal", "dev", "messaging", "cron", "slack", "cli", "tts-premium", "pty", "honcho", "mcp", "all"] [[package]] name = "hf-xet" @@ -1522,6 +1528,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/70/bc/6f1c2f612465f5fa89b95bead1f44dcb607670fd42891d8fdcd5d039f4f4/markupsafe-3.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:32001d6a8fc98c8cb5c947787c5d08b0a50663d139f1305bac5885d98d9b40fa", size = 14146, upload-time = "2025-09-27T18:37:28.327Z" }, ] +[[package]] +name = "mcp" +version = "1.26.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "httpx" }, + { name = "httpx-sse" }, + { name = "jsonschema" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "pyjwt", extra = ["crypto"] }, + { name = "python-multipart" }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "sse-starlette" }, + { name = "starlette" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, + { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fc/6d/62e76bbb8144d6ed86e202b5edd8a4cb631e7c8130f3f4893c3f90262b10/mcp-1.26.0.tar.gz", hash = "sha256:db6e2ef491eecc1a0d93711a76f28dec2e05999f93afd48795da1c1137142c66", size = 608005, upload-time = "2026-01-24T19:40:32.468Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fd/d9/eaa1f80170d2b7c5ba23f3b59f766f3a0bb41155fbc32a69adfa1adaaef9/mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca", size = 233615, upload-time = "2026-01-24T19:40:30.652Z" }, +] + [[package]] name = "mdurl" version = "0.1.2" @@ -2114,6 +2145,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/36/c7/cfc8e811f061c841d7990b0201912c3556bfeb99cdcb7ed24adc8d6f8704/pydantic_core-2.41.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:56121965f7a4dc965bff783d70b907ddf3d57f6eba29b6d2e5dabfaf07799c51", size = 2145302, upload-time = "2025-11-04T13:43:46.64Z" }, ] +[[package]] +name = "pydantic-settings" +version = "2.13.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/52/6d/fffca34caecc4a3f97bda81b2098da5e8ab7efc9a66e819074a11955d87e/pydantic_settings-2.13.1.tar.gz", hash = "sha256:b4c11847b15237fb0171e1462bf540e294affb9b86db4d9aa5c01730bdbe4025", size = 223826, upload-time = "2026-02-19T13:45:08.055Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/4b/ccc026168948fec4f7555b9164c724cf4125eac006e176541483d2c959be/pydantic_settings-2.13.1-py3-none-any.whl", hash = "sha256:d56fd801823dbeae7f0975e1f8c8e25c258eb75d278ea7abb5d9cebb01b56237", size = 58929, upload-time = "2026-02-19T13:45:06.034Z" }, +] + [[package]] name = "pygments" version = "2.19.2" @@ -2221,6 +2266,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" }, ] +[[package]] +name = "pywin32" +version = "311" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/40/44efbb0dfbd33aca6a6483191dae0716070ed99e2ecb0c53683f400a0b4f/pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3", size = 8760432, upload-time = "2025-07-14T20:13:05.9Z" }, + { url = "https://files.pythonhosted.org/packages/5e/bf/360243b1e953bd254a82f12653974be395ba880e7ec23e3731d9f73921cc/pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b", size = 9590103, upload-time = "2025-07-14T20:13:07.698Z" }, + { url = "https://files.pythonhosted.org/packages/57/38/d290720e6f138086fb3d5ffe0b6caa019a791dd57866940c82e4eeaf2012/pywin32-311-cp310-cp310-win_arm64.whl", hash = "sha256:0502d1facf1fed4839a9a51ccbcc63d952cf318f78ffc00a7e78528ac27d7a2b", size = 8778557, upload-time = "2025-07-14T20:13:11.11Z" }, + { url = "https://files.pythonhosted.org/packages/7c/af/449a6a91e5d6db51420875c54f6aff7c97a86a3b13a0b4f1a5c13b988de3/pywin32-311-cp311-cp311-win32.whl", hash = "sha256:184eb5e436dea364dcd3d2316d577d625c0351bf237c4e9a5fabbcfa5a58b151", size = 8697031, upload-time = "2025-07-14T20:13:13.266Z" }, + { url = "https://files.pythonhosted.org/packages/51/8f/9bb81dd5bb77d22243d33c8397f09377056d5c687aa6d4042bea7fbf8364/pywin32-311-cp311-cp311-win_amd64.whl", hash = "sha256:3ce80b34b22b17ccbd937a6e78e7225d80c52f5ab9940fe0506a1a16f3dab503", size = 9508308, upload-time = "2025-07-14T20:13:15.147Z" }, + { url = "https://files.pythonhosted.org/packages/44/7b/9c2ab54f74a138c491aba1b1cd0795ba61f144c711daea84a88b63dc0f6c/pywin32-311-cp311-cp311-win_arm64.whl", hash = "sha256:a733f1388e1a842abb67ffa8e7aad0e70ac519e09b0f6a784e65a136ec7cefd2", size = 8703930, upload-time = "2025-07-14T20:13:16.945Z" }, + { url = "https://files.pythonhosted.org/packages/e7/ab/01ea1943d4eba0f850c3c61e78e8dd59757ff815ff3ccd0a84de5f541f42/pywin32-311-cp312-cp312-win32.whl", hash = "sha256:750ec6e621af2b948540032557b10a2d43b0cee2ae9758c54154d711cc852d31", size = 8706543, upload-time = "2025-07-14T20:13:20.765Z" }, + { url = "https://files.pythonhosted.org/packages/d1/a8/a0e8d07d4d051ec7502cd58b291ec98dcc0c3fff027caad0470b72cfcc2f/pywin32-311-cp312-cp312-win_amd64.whl", hash = "sha256:b8c095edad5c211ff31c05223658e71bf7116daa0ecf3ad85f3201ea3190d067", size = 9495040, upload-time = "2025-07-14T20:13:22.543Z" }, + { url = "https://files.pythonhosted.org/packages/ba/3a/2ae996277b4b50f17d61f0603efd8253cb2d79cc7ae159468007b586396d/pywin32-311-cp312-cp312-win_arm64.whl", hash = "sha256:e286f46a9a39c4a18b319c28f59b61de793654af2f395c102b4f819e584b5852", size = 8710102, upload-time = "2025-07-14T20:13:24.682Z" }, + { url = "https://files.pythonhosted.org/packages/a5/be/3fd5de0979fcb3994bfee0d65ed8ca9506a8a1260651b86174f6a86f52b3/pywin32-311-cp313-cp313-win32.whl", hash = "sha256:f95ba5a847cba10dd8c4d8fefa9f2a6cf283b8b88ed6178fa8a6c1ab16054d0d", size = 8705700, upload-time = "2025-07-14T20:13:26.471Z" }, + { url = "https://files.pythonhosted.org/packages/e3/28/e0a1909523c6890208295a29e05c2adb2126364e289826c0a8bc7297bd5c/pywin32-311-cp313-cp313-win_amd64.whl", hash = "sha256:718a38f7e5b058e76aee1c56ddd06908116d35147e133427e59a3983f703a20d", size = 9494700, upload-time = "2025-07-14T20:13:28.243Z" }, + { url = "https://files.pythonhosted.org/packages/04/bf/90339ac0f55726dce7d794e6d79a18a91265bdf3aa70b6b9ca52f35e022a/pywin32-311-cp313-cp313-win_arm64.whl", hash = "sha256:7b4075d959648406202d92a2310cb990fea19b535c7f4a78d3f5e10b926eeb8a", size = 8709318, upload-time = "2025-07-14T20:13:30.348Z" }, + { url = "https://files.pythonhosted.org/packages/c9/31/097f2e132c4f16d99a22bfb777e0fd88bd8e1c634304e102f313af69ace5/pywin32-311-cp314-cp314-win32.whl", hash = "sha256:b7a2c10b93f8986666d0c803ee19b5990885872a7de910fc460f9b0c2fbf92ee", size = 8840714, upload-time = "2025-07-14T20:13:32.449Z" }, + { url = "https://files.pythonhosted.org/packages/90/4b/07c77d8ba0e01349358082713400435347df8426208171ce297da32c313d/pywin32-311-cp314-cp314-win_amd64.whl", hash = "sha256:3aca44c046bd2ed8c90de9cb8427f581c479e594e99b5c0bb19b29c10fd6cb87", size = 9656800, upload-time = "2025-07-14T20:13:34.312Z" }, + { url = "https://files.pythonhosted.org/packages/c0/d2/21af5c535501a7233e734b8af901574572da66fcc254cb35d0609c9080dd/pywin32-311-cp314-cp314-win_arm64.whl", hash = "sha256:a508e2d9025764a8270f93111a970e1d0fbfc33f4153b388bb649b7eec4f9b42", size = 8932540, upload-time = "2025-07-14T20:13:36.379Z" }, +] + [[package]] name = "pyyaml" version = "6.0.3" @@ -2639,6 +2706,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, ] +[[package]] +name = "sse-starlette" +version = "3.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "starlette" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5a/9f/c3695c2d2d4ef70072c3a06992850498b01c6bc9be531950813716b426fa/sse_starlette-3.3.2.tar.gz", hash = "sha256:678fca55a1945c734d8472a6cad186a55ab02840b4f6786f5ee8770970579dcd", size = 32326, upload-time = "2026-02-28T11:24:34.36Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/28/8cb142d3fe80c4a2d8af54ca0b003f47ce0ba920974e7990fa6e016402d1/sse_starlette-3.3.2-py3-none-any.whl", hash = "sha256:5c3ea3dad425c601236726af2f27689b74494643f57017cafcb6f8c9acfbb862", size = 14270, upload-time = "2026-02-28T11:24:32.984Z" }, +] + [[package]] name = "starlette" version = "0.52.1" From 0eb0bec74cac9e5022087e40deba27ff466d4f6b Mon Sep 17 00:00:00 2001 From: 0xbyt4 <35742124+0xbyt4@users.noreply.github.com> Date: Mon, 2 Mar 2026 21:06:17 +0300 Subject: [PATCH 2/6] feat(gateway): add MCP server shutdown on gateway exit Ensures MCP subprocess connections are closed when the messaging gateway shuts down, preventing orphan processes. --- gateway/run.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/gateway/run.py b/gateway/run.py index 8154b76f..2a40149e 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -2202,7 +2202,14 @@ async def start_gateway(config: Optional[GatewayConfig] = None) -> bool: # Stop cron ticker cleanly cron_stop.set() cron_thread.join(timeout=5) - + + # Close MCP server connections + try: + from tools.mcp_tool import shutdown_mcp_servers + shutdown_mcp_servers() + except Exception: + pass + return True From aa2ecaef29fd13eae1df704857039cbba6c05849 Mon Sep 17 00:00:00 2001 From: 0xbyt4 <35742124+0xbyt4@users.noreply.github.com> Date: Mon, 2 Mar 2026 21:22:00 +0300 Subject: [PATCH 3/6] fix: resolve orphan subprocess leak on MCP server shutdown Refactor MCP connections from AsyncExitStack to task-per-server architecture. Each server now runs as a long-lived asyncio Task with `async with stdio_client(...)`, ensuring anyio cancel-scope cleanup happens in the same Task that opened the connection. --- tests/tools/test_mcp_tool.py | 358 +++++++++++++++++++---------------- tools/mcp_tool.py | 194 ++++++++++++------- 2 files changed, 319 insertions(+), 233 deletions(-) diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index caaffd48..f12a6c93 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -36,6 +36,15 @@ def _make_call_result(text="file contents here", is_error=False): return SimpleNamespace(content=[block], isError=is_error) +def _make_mock_server(name, session=None, tools=None): + """Create an MCPServerTask with mock attributes for testing.""" + from tools.mcp_tool import MCPServerTask + server = MCPServerTask(name) + server.session = session + server._tools = tools or [] + return server + + # --------------------------------------------------------------------------- # Config loading # --------------------------------------------------------------------------- @@ -43,12 +52,10 @@ def _make_call_result(text="file contents here", is_error=False): class TestLoadMCPConfig: def test_no_config_returns_empty(self): """No mcp_servers key in config -> empty dict.""" - with patch("tools.mcp_tool.load_config", create=True) as mock_lc: - # Patch the actual import inside the function - with patch("hermes_cli.config.load_config", return_value={"model": "test"}): - from tools.mcp_tool import _load_mcp_config - result = _load_mcp_config() - assert result == {} + with patch("hermes_cli.config.load_config", return_value={"model": "test"}): + from tools.mcp_tool import _load_mcp_config + result = _load_mcp_config() + assert result == {} def test_valid_config_parsed(self): """Valid mcp_servers config is returned as-is.""" @@ -123,46 +130,37 @@ class TestSchemaConversion: class TestCheckFunction: def test_disconnected_returns_false(self): - from tools.mcp_tool import _make_check_fn, _connections + from tools.mcp_tool import _make_check_fn, _servers - # Ensure no connection exists - _connections.pop("test_server", None) + _servers.pop("test_server", None) check = _make_check_fn("test_server") assert check() is False def test_connected_returns_true(self): - from tools.mcp_tool import _make_check_fn, _connections, MCPConnection + from tools.mcp_tool import _make_check_fn, _servers - conn = MCPConnection( - server_name="test_server", - session=MagicMock(), - stack=MagicMock(), - ) - _connections["test_server"] = conn + server = _make_mock_server("test_server", session=MagicMock()) + _servers["test_server"] = server try: check = _make_check_fn("test_server") assert check() is True finally: - _connections.pop("test_server", None) + _servers.pop("test_server", None) def test_session_none_returns_false(self): - from tools.mcp_tool import _make_check_fn, _connections, MCPConnection + from tools.mcp_tool import _make_check_fn, _servers - conn = MCPConnection( - server_name="test_server", - session=None, - stack=MagicMock(), - ) - _connections["test_server"] = conn + server = _make_mock_server("test_server", session=None) + _servers["test_server"] = server try: check = _make_check_fn("test_server") assert check() is False finally: - _connections.pop("test_server", None) + _servers.pop("test_server", None) # --------------------------------------------------------------------------- -# Tool handler (async) +# Tool handler # --------------------------------------------------------------------------- class TestToolHandler: @@ -171,20 +169,24 @@ class TestToolHandler: def _patch_mcp_loop(self, coro_side_effect=None): """Return a patch for _run_on_mcp_loop that runs the coroutine directly.""" def fake_run(coro, timeout=30): - return asyncio.get_event_loop().run_until_complete(coro) + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() if coro_side_effect: return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=coro_side_effect) return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run) def test_successful_call(self): - from tools.mcp_tool import _make_tool_handler, _connections, MCPConnection + from tools.mcp_tool import _make_tool_handler, _servers mock_session = MagicMock() mock_session.call_tool = AsyncMock( return_value=_make_call_result("hello world", is_error=False) ) - conn = MCPConnection("test_srv", session=mock_session, stack=MagicMock()) - _connections["test_srv"] = conn + server = _make_mock_server("test_srv", session=mock_session) + _servers["test_srv"] = server try: handler = _make_tool_handler("test_srv", "greet") @@ -193,17 +195,17 @@ class TestToolHandler: assert result["result"] == "hello world" mock_session.call_tool.assert_called_once_with("greet", arguments={"name": "world"}) finally: - _connections.pop("test_srv", None) + _servers.pop("test_srv", None) def test_mcp_error_result(self): - from tools.mcp_tool import _make_tool_handler, _connections, MCPConnection + from tools.mcp_tool import _make_tool_handler, _servers mock_session = MagicMock() mock_session.call_tool = AsyncMock( return_value=_make_call_result("something went wrong", is_error=True) ) - conn = MCPConnection("test_srv", session=mock_session, stack=MagicMock()) - _connections["test_srv"] = conn + server = _make_mock_server("test_srv", session=mock_session) + _servers["test_srv"] = server try: handler = _make_tool_handler("test_srv", "fail_tool") @@ -212,25 +214,24 @@ class TestToolHandler: assert "error" in result assert "something went wrong" in result["error"] finally: - _connections.pop("test_srv", None) + _servers.pop("test_srv", None) def test_disconnected_server(self): - from tools.mcp_tool import _make_tool_handler, _connections + from tools.mcp_tool import _make_tool_handler, _servers - _connections.pop("ghost", None) + _servers.pop("ghost", None) handler = _make_tool_handler("ghost", "any_tool") - # Disconnected check happens before _run_on_mcp_loop, no patch needed result = json.loads(handler({})) assert "error" in result assert "not connected" in result["error"] def test_exception_during_call(self): - from tools.mcp_tool import _make_tool_handler, _connections, MCPConnection + from tools.mcp_tool import _make_tool_handler, _servers mock_session = MagicMock() mock_session.call_tool = AsyncMock(side_effect=RuntimeError("connection lost")) - conn = MCPConnection("test_srv", session=mock_session, stack=MagicMock()) - _connections["test_srv"] = conn + server = _make_mock_server("test_srv", session=mock_session) + _servers["test_srv"] = server try: handler = _make_tool_handler("test_srv", "broken_tool") @@ -239,7 +240,7 @@ class TestToolHandler: assert "error" in result assert "connection lost" in result["error"] finally: - _connections.pop("test_srv", None) + _servers.pop("test_srv", None) # --------------------------------------------------------------------------- @@ -249,23 +250,21 @@ class TestToolHandler: class TestDiscoverAndRegister: def test_tools_registered_in_registry(self): """_discover_and_register_server registers tools with correct names.""" - from tools.registry import ToolRegistry, registry as real_registry - from tools.mcp_tool import _discover_and_register_server, _connections, MCPConnection + from tools.registry import ToolRegistry + from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask mock_registry = ToolRegistry() mock_tools = [ _make_mcp_tool("read_file", "Read a file"), _make_mcp_tool("write_file", "Write a file"), ] - mock_session = MagicMock() - mock_session.initialize = AsyncMock() - mock_session.list_tools = AsyncMock( - return_value=SimpleNamespace(tools=mock_tools) - ) async def fake_connect(name, config): - return MCPConnection(name, session=mock_session, stack=MagicMock()) + server = MCPServerTask(name) + server.session = mock_session + server._tools = mock_tools + return server with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ patch("tools.registry.registry", mock_registry): @@ -278,22 +277,20 @@ class TestDiscoverAndRegister: assert "mcp_fs_read_file" in mock_registry.get_all_tool_names() assert "mcp_fs_write_file" in mock_registry.get_all_tool_names() - _connections.pop("fs", None) + _servers.pop("fs", None) def test_toolset_created(self): """A custom toolset is created for the MCP server.""" - from tools.mcp_tool import _discover_and_register_server, _connections, MCPConnection + from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask mock_tools = [_make_mcp_tool("ping", "Ping")] - mock_session = MagicMock() - mock_session.initialize = AsyncMock() - mock_session.list_tools = AsyncMock( - return_value=SimpleNamespace(tools=mock_tools) - ) async def fake_connect(name, config): - return MCPConnection(name, session=mock_session, stack=MagicMock()) + server = MCPServerTask(name) + server.session = mock_session + server._tools = mock_tools + return server mock_create = MagicMock() with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ @@ -306,24 +303,22 @@ class TestDiscoverAndRegister: call_kwargs = mock_create.call_args assert call_kwargs[1]["name"] == "mcp-myserver" or call_kwargs[0][0] == "mcp-myserver" - _connections.pop("myserver", None) + _servers.pop("myserver", None) def test_schema_format_correct(self): """Registered schemas have the correct format.""" - from tools.registry import ToolRegistry, registry as real_registry - from tools.mcp_tool import _discover_and_register_server, _connections, MCPConnection + from tools.registry import ToolRegistry + from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask mock_registry = ToolRegistry() mock_tools = [_make_mcp_tool("do_thing", "Do something")] - mock_session = MagicMock() - mock_session.initialize = AsyncMock() - mock_session.list_tools = AsyncMock( - return_value=SimpleNamespace(tools=mock_tools) - ) async def fake_connect(name, config): - return MCPConnection(name, session=mock_session, stack=MagicMock()) + server = MCPServerTask(name) + server.session = mock_session + server._tools = mock_tools + return server with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ patch("tools.registry.registry", mock_registry): @@ -338,91 +333,125 @@ class TestDiscoverAndRegister: assert entry.is_async is False assert entry.toolset == "mcp-srv" - _connections.pop("srv", None) + _servers.pop("srv", None) # --------------------------------------------------------------------------- -# _connect_server (SDK interaction) +# MCPServerTask (run / start / shutdown) # --------------------------------------------------------------------------- -class TestConnectServer: - def test_calls_sdk_with_correct_params(self): - """_connect_server creates StdioServerParameters and calls stdio_client.""" - from tools.mcp_tool import _connect_server, MCPConnection +class TestMCPServerTask: + """Test the MCPServerTask lifecycle with mocked MCP SDK.""" + def _mock_stdio_and_session(self, session): + """Return patches for stdio_client and ClientSession as async CMs.""" + mock_read, mock_write = MagicMock(), MagicMock() + + mock_stdio_cm = MagicMock() + mock_stdio_cm.__aenter__ = AsyncMock(return_value=(mock_read, mock_write)) + mock_stdio_cm.__aexit__ = AsyncMock(return_value=False) + + mock_cs_cm = MagicMock() + mock_cs_cm.__aenter__ = AsyncMock(return_value=session) + mock_cs_cm.__aexit__ = AsyncMock(return_value=False) + + return ( + patch("tools.mcp_tool.stdio_client", return_value=mock_stdio_cm), + patch("tools.mcp_tool.ClientSession", return_value=mock_cs_cm), + mock_read, mock_write, + ) + + def test_start_connects_and_discovers_tools(self): + """start() creates a Task that connects, discovers tools, and waits.""" + from tools.mcp_tool import MCPServerTask + + mock_tools = [_make_mcp_tool("echo")] mock_session = MagicMock() mock_session.initialize = AsyncMock() - - mock_read = MagicMock() - mock_write = MagicMock() - - with patch("tools.mcp_tool.StdioServerParameters") as mock_params, \ - patch("tools.mcp_tool.stdio_client") as mock_stdio, \ - patch("tools.mcp_tool.ClientSession") as mock_cs, \ - patch("tools.mcp_tool.AsyncExitStack") as mock_stack_cls: - - mock_stack = MagicMock() - mock_stack.enter_async_context = AsyncMock( - side_effect=[(mock_read, mock_write), mock_session] - ) - mock_stack_cls.return_value = mock_stack - - conn = asyncio.run(_connect_server("test_srv", { - "command": "npx", - "args": ["-y", "some-server"], - "env": {"MY_KEY": "secret"}, - })) - - # StdioServerParameters called with correct values - mock_params.assert_called_once_with( - command="npx", - args=["-y", "some-server"], - env={"MY_KEY": "secret"}, + mock_session.list_tools = AsyncMock( + return_value=SimpleNamespace(tools=mock_tools) ) - # ClientSession created with the streams - mock_cs.assert_called_once_with(mock_read, mock_write) - # initialize() was called - mock_session.initialize.assert_called_once() - # Returned connection is valid - assert conn.server_name == "test_srv" - assert conn.session is mock_session + + p_stdio, p_cs, _, _ = self._mock_stdio_and_session(mock_session) + + async def _test(): + with patch("tools.mcp_tool.StdioServerParameters"), p_stdio, p_cs: + server = MCPServerTask("test_srv") + await server.start({"command": "npx", "args": ["-y", "test"]}) + + assert server.session is mock_session + assert len(server._tools) == 1 + assert server._tools[0].name == "echo" + mock_session.initialize.assert_called_once() + + await server.shutdown() + assert server.session is None + + asyncio.run(_test()) def test_no_command_raises(self): """Missing 'command' in config raises ValueError.""" - from tools.mcp_tool import _connect_server + from tools.mcp_tool import MCPServerTask - with pytest.raises(ValueError, match="no 'command'"): - asyncio.run(_connect_server("bad", {"args": []})) + async def _test(): + server = MCPServerTask("bad") + with pytest.raises(ValueError, match="no 'command'"): + await server.start({"args": []}) + + asyncio.run(_test()) def test_empty_env_passed_as_none(self): """Empty env dict is passed as None to StdioServerParameters.""" - from tools.mcp_tool import _connect_server + from tools.mcp_tool import MCPServerTask mock_session = MagicMock() mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock( + return_value=SimpleNamespace(tools=[]) + ) - with patch("tools.mcp_tool.StdioServerParameters") as mock_params, \ - patch("tools.mcp_tool.stdio_client"), \ - patch("tools.mcp_tool.ClientSession", return_value=mock_session), \ - patch("tools.mcp_tool.AsyncExitStack") as mock_stack_cls: + p_stdio, p_cs, _, _ = self._mock_stdio_and_session(mock_session) - mock_stack = MagicMock() - mock_stack.enter_async_context = AsyncMock( - side_effect=[ - (MagicMock(), MagicMock()), - mock_session, - ] - ) - mock_stack_cls.return_value = mock_stack + async def _test(): + with patch("tools.mcp_tool.StdioServerParameters") as mock_params, \ + p_stdio, p_cs: + server = MCPServerTask("srv") + await server.start({"command": "node", "env": {}}) - asyncio.run(_connect_server("srv", { - "command": "node", - "env": {}, - })) + # Empty dict -> None + call_kwargs = mock_params.call_args + assert call_kwargs.kwargs.get("env") is None - # Empty dict -> None - assert mock_params.call_args[1]["env"] is None or \ - mock_params.call_args.kwargs.get("env") is None + await server.shutdown() + + asyncio.run(_test()) + + def test_shutdown_signals_task_exit(self): + """shutdown() signals the event and waits for task completion.""" + from tools.mcp_tool import MCPServerTask + + mock_session = MagicMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock( + return_value=SimpleNamespace(tools=[]) + ) + + p_stdio, p_cs, _, _ = self._mock_stdio_and_session(mock_session) + + async def _test(): + with patch("tools.mcp_tool.StdioServerParameters"), p_stdio, p_cs: + server = MCPServerTask("srv") + await server.start({"command": "npx"}) + + assert server.session is not None + assert not server._task.done() + + await server.shutdown() + + assert server.session is None + assert server._task.done() + + asyncio.run(_test()) # --------------------------------------------------------------------------- @@ -432,17 +461,16 @@ class TestConnectServer: class TestToolsetInjection: def test_mcp_tools_added_to_platform_toolsets(self): """Discovered MCP tools are injected into hermes-cli and platform toolsets.""" - from tools.mcp_tool import _connections, MCPConnection + from tools.mcp_tool import _servers, MCPServerTask mock_tools = [_make_mcp_tool("list_files", "List files")] mock_session = MagicMock() - mock_session.initialize = AsyncMock() - mock_session.list_tools = AsyncMock( - return_value=SimpleNamespace(tools=mock_tools) - ) async def fake_connect(name, config): - return MCPConnection(name, session=mock_session, stack=MagicMock()) + server = MCPServerTask(name) + server.session = mock_session + server._tools = mock_tools + return server fake_toolsets = { "hermes-cli": {"tools": ["terminal", "web_search"], "description": "CLI", "includes": []}, @@ -455,7 +483,6 @@ class TestToolsetInjection: with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \ patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ - patch("tools.mcp_tool.TOOLSETS", fake_toolsets, create=True), \ patch("toolsets.TOOLSETS", fake_toolsets): from tools.mcp_tool import discover_mcp_tools result = discover_mcp_tools() @@ -466,18 +493,14 @@ class TestToolsetInjection: # Original tools preserved assert "terminal" in fake_toolsets["hermes-cli"]["tools"] - _connections.pop("fs", None) + _servers.pop("fs", None) def test_server_connection_failure_skipped(self): """If one server fails to connect, others still proceed.""" - from tools.mcp_tool import _connections, MCPConnection + from tools.mcp_tool import _servers, MCPServerTask mock_tools = [_make_mcp_tool("ping", "Ping")] mock_session = MagicMock() - mock_session.initialize = AsyncMock() - mock_session.list_tools = AsyncMock( - return_value=SimpleNamespace(tools=mock_tools) - ) call_count = 0 @@ -486,7 +509,10 @@ class TestToolsetInjection: call_count += 1 if name == "broken": raise ConnectionError("cannot reach server") - return MCPConnection(name, session=mock_session, stack=MagicMock()) + server = MCPServerTask(name) + server.session = mock_session + server._tools = mock_tools + return server fake_config = { "broken": {"command": "bad"}, @@ -508,7 +534,7 @@ class TestToolsetInjection: assert "mcp_broken_ping" not in result assert call_count == 2 # Both were attempted - _connections.pop("good", None) + _servers.pop("good", None) # --------------------------------------------------------------------------- @@ -533,50 +559,46 @@ class TestGracefulFallback: # --------------------------------------------------------------------------- -# Shutdown +# Shutdown (public API) # --------------------------------------------------------------------------- class TestShutdown: - def test_no_connections_safe(self): - """shutdown_mcp_servers with no connections does nothing.""" - from tools.mcp_tool import shutdown_mcp_servers, _connections + def test_no_servers_safe(self): + """shutdown_mcp_servers with no servers does nothing.""" + from tools.mcp_tool import shutdown_mcp_servers, _servers - _connections.clear() + _servers.clear() shutdown_mcp_servers() # Should not raise - def test_shutdown_clears_connections(self): - """shutdown_mcp_servers closes stacks and clears the dict.""" + def test_shutdown_clears_servers(self): + """shutdown_mcp_servers calls shutdown() on each server and clears dict.""" import tools.mcp_tool as mcp_mod - from tools.mcp_tool import shutdown_mcp_servers, _connections, MCPConnection + from tools.mcp_tool import shutdown_mcp_servers, _servers - _connections.clear() - mock_stack = MagicMock() - mock_stack.aclose = AsyncMock() - conn = MCPConnection("test", session=MagicMock(), stack=mock_stack) - _connections["test"] = conn + _servers.clear() + mock_server = MagicMock() + mock_server.shutdown = AsyncMock() + _servers["test"] = mock_server - # Start a real background loop so shutdown can schedule on it mcp_mod._ensure_mcp_loop() try: shutdown_mcp_servers() finally: - # _stop_mcp_loop is called by shutdown, but ensure cleanup mcp_mod._mcp_loop = None mcp_mod._mcp_thread = None - assert len(_connections) == 0 - mock_stack.aclose.assert_called_once() + assert len(_servers) == 0 + mock_server.shutdown.assert_called_once() def test_shutdown_handles_errors(self): """shutdown_mcp_servers handles errors during close gracefully.""" import tools.mcp_tool as mcp_mod - from tools.mcp_tool import shutdown_mcp_servers, _connections, MCPConnection + from tools.mcp_tool import shutdown_mcp_servers, _servers - _connections.clear() - mock_stack = MagicMock() - mock_stack.aclose = AsyncMock(side_effect=RuntimeError("close failed")) - conn = MCPConnection("broken", session=MagicMock(), stack=mock_stack) - _connections["broken"] = conn + _servers.clear() + mock_server = MagicMock() + mock_server.shutdown = AsyncMock(side_effect=RuntimeError("close failed")) + _servers["broken"] = mock_server mcp_mod._ensure_mcp_loop() try: @@ -585,4 +607,4 @@ class TestShutdown: mcp_mod._mcp_loop = None mcp_mod._mcp_thread = None - assert len(_connections) == 0 + assert len(_servers) == 0 diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index eecbaa29..5225d63f 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -25,8 +25,13 @@ Example config:: Architecture: A dedicated background event loop (_mcp_loop) runs in a daemon thread. - All MCP connections live on this loop. Tool handlers schedule coroutines - onto it via run_coroutine_threadsafe(), so they work from any thread. + Each MCP server runs as a long-lived asyncio Task on this loop, keeping + its ``async with stdio_client(...)`` context alive. Tool call coroutines + are scheduled onto the loop via ``run_coroutine_threadsafe()``. + + On shutdown, each server Task is signalled to exit its ``async with`` + block, ensuring the anyio cancel-scope cleanup happens in the *same* + Task that opened the connection (required by anyio). """ import asyncio @@ -45,31 +50,114 @@ _MCP_AVAILABLE = False try: from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client - from contextlib import AsyncExitStack _MCP_AVAILABLE = True except ImportError: logger.debug("mcp package not installed -- MCP tool support disabled") # --------------------------------------------------------------------------- -# Connection tracking +# Server task -- each MCP server lives in one long-lived asyncio Task # --------------------------------------------------------------------------- -class MCPConnection: - """Holds a live MCP server connection and its async resource stack.""" +class MCPServerTask: + """Manages a single MCP server connection in a dedicated asyncio Task. - __slots__ = ("server_name", "session", "stack") + The entire connection lifecycle (connect, discover, serve, disconnect) + runs inside one asyncio Task so that anyio cancel-scopes created by + ``stdio_client`` are entered and exited in the same Task context. + """ - def __init__(self, server_name: str, session: Any, stack: Any): - self.server_name = server_name - self.session: Optional[Any] = session - self.stack: Optional[Any] = stack + __slots__ = ( + "name", "session", + "_task", "_ready", "_shutdown_event", "_tools", "_error", + ) + + def __init__(self, name: str): + self.name = name + self.session: Optional[Any] = None + self._task: Optional[asyncio.Task] = None + self._ready = asyncio.Event() + self._shutdown_event = asyncio.Event() + self._tools: list = [] + self._error: Optional[Exception] = None + + async def run(self, config: dict): + """Long-lived coroutine: connect, discover tools, wait, disconnect.""" + command = config.get("command") + args = config.get("args", []) + env = config.get("env") + + if not command: + self._error = ValueError( + f"MCP server '{self.name}' has no 'command' in config" + ) + self._ready.set() + return + + server_params = StdioServerParameters( + command=command, + args=args, + env=env if env else None, + ) + + try: + async with stdio_client(server_params) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + self.session = session + + tools_result = await session.list_tools() + self._tools = ( + tools_result.tools + if hasattr(tools_result, "tools") + else [] + ) + + # Signal that connection is ready + self._ready.set() + + # Block until shutdown is requested -- this keeps the + # async-with contexts alive on THIS Task. + await self._shutdown_event.wait() + except Exception as exc: + self._error = exc + self._ready.set() + finally: + self.session = None + + async def start(self, config: dict): + """Create the background Task and wait until ready (or failed).""" + self._task = asyncio.ensure_future(self.run(config)) + await self._ready.wait() + if self._error: + raise self._error + + async def shutdown(self): + """Signal the Task to exit and wait for clean resource teardown.""" + self._shutdown_event.set() + if self._task and not self._task.done(): + try: + await asyncio.wait_for(self._task, timeout=10) + except asyncio.TimeoutError: + logger.warning( + "MCP server '%s' shutdown timed out, cancelling task", + self.name, + ) + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + self.session = None -_connections: Dict[str, MCPConnection] = {} +# --------------------------------------------------------------------------- +# Module-level state +# --------------------------------------------------------------------------- + +_servers: Dict[str, MCPServerTask] = {} # Dedicated event loop running in a background daemon thread. -# All MCP async operations (connect, call_tool, shutdown) run here. _mcp_loop: Optional[asyncio.AbstractEventLoop] = None _mcp_thread: Optional[threading.Thread] = None @@ -118,42 +206,22 @@ def _load_mcp_config() -> Dict[str, dict]: # --------------------------------------------------------------------------- -# Server connection +# Server connection helper # --------------------------------------------------------------------------- -async def _connect_server(name: str, config: dict) -> MCPConnection: - """Start an MCP server subprocess and initialize a ClientSession. +async def _connect_server(name: str, config: dict) -> MCPServerTask: + """Create an MCPServerTask, start it, and return when ready. - Args: - name: Logical server name (e.g. "filesystem"). - config: Dict with ``command``, ``args``, and optional ``env``. - - Returns: - An ``MCPConnection`` with a live session. + The server Task keeps the subprocess alive in the background. + Call ``server.shutdown()`` (on the same event loop) to tear it down. Raises: - Exception on connection or initialization failure. + ValueError: if ``command`` is missing from *config*. + Exception: on connection or initialization failure. """ - command = config.get("command") - args = config.get("args", []) - env = config.get("env") - - if not command: - raise ValueError(f"MCP server '{name}' has no 'command' in config") - - server_params = StdioServerParameters( - command=command, - args=args, - env=env if env else None, - ) - - stack = AsyncExitStack() - stdio_transport = await stack.enter_async_context(stdio_client(server_params)) - read_stream, write_stream = stdio_transport - session = await stack.enter_async_context(ClientSession(read_stream, write_stream)) - await session.initialize() - - return MCPConnection(server_name=name, session=session, stack=stack) + server = MCPServerTask(name) + await server.start(config) + return server # --------------------------------------------------------------------------- @@ -168,14 +236,14 @@ def _make_tool_handler(server_name: str, tool_name: str): """ def _handler(args: dict, **kwargs) -> str: - conn = _connections.get(server_name) - if not conn or not conn.session: + server = _servers.get(server_name) + if not server or not server.session: return json.dumps({ "error": f"MCP server '{server_name}' is not connected" }) async def _call(): - result = await conn.session.call_tool(tool_name, arguments=args) + result = await server.session.call_tool(tool_name, arguments=args) # MCP CallToolResult has .content (list of content blocks) and .isError if result.isError: error_text = "" @@ -204,8 +272,8 @@ def _make_check_fn(server_name: str): """Return a check function that verifies the MCP connection is alive.""" def _check() -> bool: - conn = _connections.get(server_name) - return conn is not None and conn.session is not None + server = _servers.get(server_name) + return server is not None and server.session is not None return _check @@ -247,17 +315,13 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]: from tools.registry import registry from toolsets import create_custom_toolset - conn = await _connect_server(name, config) - _connections[name] = conn - - # Discover tools - tools_result = await conn.session.list_tools() - tools = tools_result.tools if hasattr(tools_result, "tools") else [] + server = await _connect_server(name, config) + _servers[name] = server registered_names: List[str] = [] toolset_name = f"mcp-{name}" - for mcp_tool in tools: + for mcp_tool in server._tools: schema = _convert_mcp_schema(name, mcp_tool) tool_name_prefixed = schema["name"] @@ -339,29 +403,29 @@ def discover_mcp_tools() -> List[str]: def shutdown_mcp_servers(): - """Close all MCP server connections and stop the background loop.""" + """Close all MCP server connections and stop the background loop. + + Each server Task is signalled to exit its ``async with`` block so that + the anyio cancel-scope cleanup happens in the same Task that opened it. + """ global _mcp_loop, _mcp_thread - if not _connections: + if not _servers: _stop_mcp_loop() return async def _shutdown(): - for name, conn in list(_connections.items()): + for name, server in list(_servers.items()): try: - if conn.stack: - await conn.stack.aclose() + await server.shutdown() except Exception as exc: logger.debug("Error closing MCP server '%s': %s", name, exc) - finally: - conn.session = None - conn.stack = None - _connections.clear() + _servers.clear() if _mcp_loop is not None and _mcp_loop.is_running(): try: future = asyncio.run_coroutine_threadsafe(_shutdown(), _mcp_loop) - future.result(timeout=10) + future.result(timeout=15) except Exception as exc: logger.debug("Error during MCP shutdown: %s", exc) From 593c549bc466f6e0b8c517320393c16731597c74 Mon Sep 17 00:00:00 2001 From: 0xbyt4 <35742124+0xbyt4@users.noreply.github.com> Date: Mon, 2 Mar 2026 21:34:21 +0300 Subject: [PATCH 4/6] fix: make discover_mcp_tools idempotent to prevent duplicate connections When discover_mcp_tools() is called multiple times (e.g. direct call then model_tools import), return existing tool names instead of opening new connections that would orphan the previous ones. --- tools/mcp_tool.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index 5225d63f..5cdce4a3 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -361,6 +361,9 @@ def discover_mcp_tools() -> List[str]: Called from ``model_tools._discover_tools()``. Safe to call even when the ``mcp`` package is not installed (returns empty list). + Idempotent: if servers are already connected, returns the existing + tool names without creating duplicate connections. + Returns: List of all registered MCP tool names. """ @@ -368,6 +371,15 @@ def discover_mcp_tools() -> List[str]: logger.debug("MCP SDK not available -- skipping MCP tool discovery") return [] + # Already connected -- return existing tool names (idempotent) + if _servers: + existing: List[str] = [] + for name, server in _servers.items(): + for mcp_tool in server._tools: + schema = _convert_mcp_schema(name, mcp_tool) + existing.append(schema["name"]) + return existing + servers = _load_mcp_config() if not servers: logger.debug("No MCP servers configured") From 151e8d896ca2296eeb836097bdb8049e70ef40f0 Mon Sep 17 00:00:00 2001 From: 0xbyt4 <35742124+0xbyt4@users.noreply.github.com> Date: Mon, 2 Mar 2026 21:38:01 +0300 Subject: [PATCH 5/6] fix(tests): isolate discover_mcp_tools tests from global _servers state Patch _servers to empty dict in tests that call discover_mcp_tools() with mocked config, preventing interference from real MCP connections that may exist when running within the full test suite. --- tests/tools/test_mcp_tool.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index f12a6c93..2e52272b 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -461,11 +461,14 @@ class TestMCPServerTask: class TestToolsetInjection: def test_mcp_tools_added_to_platform_toolsets(self): """Discovered MCP tools are injected into hermes-cli and platform toolsets.""" - from tools.mcp_tool import _servers, MCPServerTask + from tools.mcp_tool import MCPServerTask mock_tools = [_make_mcp_tool("list_files", "List files")] mock_session = MagicMock() + # Fresh _servers dict to bypass idempotency guard + fresh_servers = {} + async def fake_connect(name, config): server = MCPServerTask(name) server.session = mock_session @@ -481,6 +484,7 @@ class TestToolsetInjection: } with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ + patch("tools.mcp_tool._servers", fresh_servers), \ patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \ patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ patch("toolsets.TOOLSETS", fake_toolsets): @@ -493,15 +497,15 @@ class TestToolsetInjection: # Original tools preserved assert "terminal" in fake_toolsets["hermes-cli"]["tools"] - _servers.pop("fs", None) - def test_server_connection_failure_skipped(self): """If one server fails to connect, others still proceed.""" - from tools.mcp_tool import _servers, MCPServerTask + from tools.mcp_tool import MCPServerTask mock_tools = [_make_mcp_tool("ping", "Ping")] mock_session = MagicMock() + # Fresh _servers dict to bypass idempotency guard + fresh_servers = {} call_count = 0 async def flaky_connect(name, config): @@ -523,6 +527,7 @@ class TestToolsetInjection: } with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ + patch("tools.mcp_tool._servers", fresh_servers), \ patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \ patch("tools.mcp_tool._connect_server", side_effect=flaky_connect), \ patch("toolsets.TOOLSETS", fake_toolsets): @@ -534,8 +539,6 @@ class TestToolsetInjection: assert "mcp_broken_ping" not in result assert call_count == 2 # Both were attempted - _servers.pop("good", None) - # --------------------------------------------------------------------------- # Graceful fallback @@ -552,6 +555,7 @@ class TestGracefulFallback: def test_no_servers_returns_empty(self): """No MCP servers configured -> empty list.""" with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ + patch("tools.mcp_tool._servers", {}), \ patch("tools.mcp_tool._load_mcp_config", return_value={}): from tools.mcp_tool import discover_mcp_tools result = discover_mcp_tools() From 11a2ecb936d6bc97f67ce2574630767091a504ec Mon Sep 17 00:00:00 2001 From: 0xbyt4 <35742124+0xbyt4@users.noreply.github.com> Date: Mon, 2 Mar 2026 22:08:32 +0300 Subject: [PATCH 6/6] fix: resolve thread safety issues and shutdown deadlock in MCP client - Add threading.Lock protecting all shared state (_servers, _mcp_loop, _mcp_thread) - Fix deadlock in shutdown_mcp_servers: _stop_mcp_loop was called inside a _lock block but also acquires _lock (non-reentrant) - Fix race condition in _ensure_mcp_loop with concurrent callers - Change idempotency to per-server (retry failed servers, skip connected) - Dynamic toolset injection via startswith("hermes-") instead of hardcoded list - Parallel shutdown via asyncio.gather instead of sequential loop - Add tests for partial failure retry, parallel shutdown, dynamic injection --- tests/tools/test_mcp_tool.py | 108 +++++++++++++++++++++++++--- tools/mcp_tool.py | 134 +++++++++++++++++++++++------------ 2 files changed, 184 insertions(+), 58 deletions(-) diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index 2e52272b..065baf4a 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -459,14 +459,13 @@ class TestMCPServerTask: # --------------------------------------------------------------------------- class TestToolsetInjection: - def test_mcp_tools_added_to_platform_toolsets(self): - """Discovered MCP tools are injected into hermes-cli and platform toolsets.""" + def test_mcp_tools_added_to_all_hermes_toolsets(self): + """Discovered MCP tools are dynamically injected into all hermes-* toolsets.""" from tools.mcp_tool import MCPServerTask mock_tools = [_make_mcp_tool("list_files", "List files")] mock_session = MagicMock() - # Fresh _servers dict to bypass idempotency guard fresh_servers = {} async def fake_connect(name, config): @@ -476,12 +475,12 @@ class TestToolsetInjection: return server fake_toolsets = { - "hermes-cli": {"tools": ["terminal", "web_search"], "description": "CLI", "includes": []}, - "hermes-telegram": {"tools": ["terminal"], "description": "Telegram", "includes": []}, - } - fake_config = { - "fs": {"command": "npx", "args": []}, + "hermes-cli": {"tools": ["terminal"], "description": "CLI", "includes": []}, + "hermes-telegram": {"tools": ["terminal"], "description": "TG", "includes": []}, + "hermes-gateway": {"tools": [], "description": "GW", "includes": []}, + "non-hermes": {"tools": [], "description": "other", "includes": []}, } + fake_config = {"fs": {"command": "npx", "args": []}} with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ patch("tools.mcp_tool._servers", fresh_servers), \ @@ -492,8 +491,12 @@ class TestToolsetInjection: result = discover_mcp_tools() assert "mcp_fs_list_files" in result + # All hermes-* toolsets get injection assert "mcp_fs_list_files" in fake_toolsets["hermes-cli"]["tools"] assert "mcp_fs_list_files" in fake_toolsets["hermes-telegram"]["tools"] + assert "mcp_fs_list_files" in fake_toolsets["hermes-gateway"]["tools"] + # Non-hermes toolset should NOT get injection + assert "mcp_fs_list_files" not in fake_toolsets["non-hermes"]["tools"] # Original tools preserved assert "terminal" in fake_toolsets["hermes-cli"]["tools"] @@ -504,7 +507,6 @@ class TestToolsetInjection: mock_tools = [_make_mcp_tool("ping", "Ping")] mock_session = MagicMock() - # Fresh _servers dict to bypass idempotency guard fresh_servers = {} call_count = 0 @@ -534,10 +536,62 @@ class TestToolsetInjection: from tools.mcp_tool import discover_mcp_tools result = discover_mcp_tools() - # Only good server's tool registered assert "mcp_good_ping" in result assert "mcp_broken_ping" not in result - assert call_count == 2 # Both were attempted + assert call_count == 2 + + def test_partial_failure_retry_on_second_call(self): + """Failed servers are retried on subsequent discover_mcp_tools() calls.""" + from tools.mcp_tool import MCPServerTask + + mock_tools = [_make_mcp_tool("ping", "Ping")] + mock_session = MagicMock() + + # Use a real dict so idempotency logic works correctly + fresh_servers = {} + call_count = 0 + broken_fixed = False + + async def flaky_connect(name, config): + nonlocal call_count + call_count += 1 + if name == "broken" and not broken_fixed: + raise ConnectionError("cannot reach server") + server = MCPServerTask(name) + server.session = mock_session + server._tools = mock_tools + return server + + fake_config = { + "broken": {"command": "bad"}, + "good": {"command": "npx", "args": []}, + } + fake_toolsets = { + "hermes-cli": {"tools": [], "description": "CLI", "includes": []}, + } + + with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ + patch("tools.mcp_tool._servers", fresh_servers), \ + patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \ + patch("tools.mcp_tool._connect_server", side_effect=flaky_connect), \ + patch("toolsets.TOOLSETS", fake_toolsets): + from tools.mcp_tool import discover_mcp_tools + + # First call: good connects, broken fails + result1 = discover_mcp_tools() + assert "mcp_good_ping" in result1 + assert "mcp_broken_ping" not in result1 + first_attempts = call_count + + # "Fix" the broken server + broken_fixed = True + call_count = 0 + + # Second call: should retry broken, skip good + result2 = discover_mcp_tools() + assert "mcp_good_ping" in result2 + assert "mcp_broken_ping" in result2 + assert call_count == 1 # Only broken retried # --------------------------------------------------------------------------- @@ -581,6 +635,7 @@ class TestShutdown: _servers.clear() mock_server = MagicMock() + mock_server.name = "test" mock_server.shutdown = AsyncMock() _servers["test"] = mock_server @@ -601,6 +656,7 @@ class TestShutdown: _servers.clear() mock_server = MagicMock() + mock_server.name = "broken" mock_server.shutdown = AsyncMock(side_effect=RuntimeError("close failed")) _servers["broken"] = mock_server @@ -612,3 +668,33 @@ class TestShutdown: mcp_mod._mcp_thread = None assert len(_servers) == 0 + + def test_shutdown_is_parallel(self): + """Multiple servers are shut down in parallel via asyncio.gather.""" + import tools.mcp_tool as mcp_mod + from tools.mcp_tool import shutdown_mcp_servers, _servers + import time + + _servers.clear() + + # 3 servers each taking 1s to shut down + for i in range(3): + mock_server = MagicMock() + mock_server.name = f"srv_{i}" + async def slow_shutdown(): + await asyncio.sleep(1) + mock_server.shutdown = slow_shutdown + _servers[f"srv_{i}"] = mock_server + + mcp_mod._ensure_mcp_loop() + try: + start = time.monotonic() + shutdown_mcp_servers() + elapsed = time.monotonic() - start + finally: + mcp_mod._mcp_loop = None + mcp_mod._mcp_thread = None + + assert len(_servers) == 0 + # Parallel: ~1s, not ~3s. Allow some margin. + assert elapsed < 2.5, f"Shutdown took {elapsed:.1f}s, expected ~1s (parallel)" diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index 5cdce4a3..4ab55215 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -32,6 +32,12 @@ Architecture: On shutdown, each server Task is signalled to exit its ``async with`` block, ensuring the anyio cancel-scope cleanup happens in the *same* Task that opened the connection (required by anyio). + +Thread safety: + _servers and _mcp_loop/_mcp_thread are accessed from both the MCP + background thread and caller threads. All mutations are protected by + _lock so the code is safe regardless of GIL presence (e.g. Python 3.13+ + free-threading). """ import asyncio @@ -161,26 +167,32 @@ _servers: Dict[str, MCPServerTask] = {} _mcp_loop: Optional[asyncio.AbstractEventLoop] = None _mcp_thread: Optional[threading.Thread] = None +# Protects _mcp_loop, _mcp_thread, and _servers from concurrent access. +_lock = threading.Lock() + def _ensure_mcp_loop(): """Start the background event loop thread if not already running.""" global _mcp_loop, _mcp_thread - if _mcp_loop is not None and _mcp_loop.is_running(): - return - _mcp_loop = asyncio.new_event_loop() - _mcp_thread = threading.Thread( - target=_mcp_loop.run_forever, - name="mcp-event-loop", - daemon=True, - ) - _mcp_thread.start() + with _lock: + if _mcp_loop is not None and _mcp_loop.is_running(): + return + _mcp_loop = asyncio.new_event_loop() + _mcp_thread = threading.Thread( + target=_mcp_loop.run_forever, + name="mcp-event-loop", + daemon=True, + ) + _mcp_thread.start() def _run_on_mcp_loop(coro, timeout: float = 30): """Schedule a coroutine on the MCP event loop and block until done.""" - if _mcp_loop is None or not _mcp_loop.is_running(): + with _lock: + loop = _mcp_loop + if loop is None or not loop.is_running(): raise RuntimeError("MCP event loop is not running") - future = asyncio.run_coroutine_threadsafe(coro, _mcp_loop) + future = asyncio.run_coroutine_threadsafe(coro, loop) return future.result(timeout=timeout) @@ -236,7 +248,8 @@ def _make_tool_handler(server_name: str, tool_name: str): """ def _handler(args: dict, **kwargs) -> str: - server = _servers.get(server_name) + with _lock: + server = _servers.get(server_name) if not server or not server.session: return json.dumps({ "error": f"MCP server '{server_name}' is not connected" @@ -272,7 +285,8 @@ def _make_check_fn(server_name: str): """Return a check function that verifies the MCP connection is alive.""" def _check() -> bool: - server = _servers.get(server_name) + with _lock: + server = _servers.get(server_name) return server is not None and server.session is not None return _check @@ -307,6 +321,16 @@ def _convert_mcp_schema(server_name: str, mcp_tool) -> dict: } +def _existing_tool_names() -> List[str]: + """Return tool names for all currently connected servers.""" + names: List[str] = [] + for sname, server in _servers.items(): + for mcp_tool in server._tools: + schema = _convert_mcp_schema(sname, mcp_tool) + names.append(schema["name"]) + return names + + async def _discover_and_register_server(name: str, config: dict) -> List[str]: """Connect to a single MCP server, discover tools, and register them. @@ -316,7 +340,8 @@ async def _discover_and_register_server(name: str, config: dict) -> List[str]: from toolsets import create_custom_toolset server = await _connect_server(name, config) - _servers[name] = server + with _lock: + _servers[name] = server registered_names: List[str] = [] toolset_name = f"mcp-{name}" @@ -361,8 +386,8 @@ def discover_mcp_tools() -> List[str]: Called from ``model_tools._discover_tools()``. Safe to call even when the ``mcp`` package is not installed (returns empty list). - Idempotent: if servers are already connected, returns the existing - tool names without creating duplicate connections. + Idempotent for already-connected servers. If some servers failed on a + previous call, only the missing ones are retried. Returns: List of all registered MCP tool names. @@ -371,27 +396,25 @@ def discover_mcp_tools() -> List[str]: logger.debug("MCP SDK not available -- skipping MCP tool discovery") return [] - # Already connected -- return existing tool names (idempotent) - if _servers: - existing: List[str] = [] - for name, server in _servers.items(): - for mcp_tool in server._tools: - schema = _convert_mcp_schema(name, mcp_tool) - existing.append(schema["name"]) - return existing - servers = _load_mcp_config() if not servers: logger.debug("No MCP servers configured") return [] + # Only attempt servers that aren't already connected + with _lock: + new_servers = {k: v for k, v in servers.items() if k not in _servers} + + if not new_servers: + return _existing_tool_names() + # Start the background event loop for MCP connections _ensure_mcp_loop() all_tools: List[str] = [] async def _discover_all(): - for name, cfg in servers.items(): + for name, cfg in new_servers.items(): try: registered = await _discover_and_register_server(name, cfg) all_tools.extend(registered) @@ -401,17 +424,16 @@ def discover_mcp_tools() -> List[str]: _run_on_mcp_loop(_discover_all(), timeout=60) if all_tools: - # Add MCP tools to hermes-cli and other platform toolsets + # Dynamically inject into all hermes-* platform toolsets from toolsets import TOOLSETS - for ts_name in ("hermes-cli", "hermes-telegram", "hermes-discord", - "hermes-whatsapp", "hermes-slack"): - ts = TOOLSETS.get(ts_name) - if ts: + for ts_name, ts in TOOLSETS.items(): + if ts_name.startswith("hermes-"): for tool_name in all_tools: if tool_name not in ts["tools"]: ts["tools"].append(tool_name) - return all_tools + # Return ALL registered tools (existing + newly discovered) + return _existing_tool_names() def shutdown_mcp_servers(): @@ -419,24 +441,39 @@ def shutdown_mcp_servers(): Each server Task is signalled to exit its ``async with`` block so that the anyio cancel-scope cleanup happens in the same Task that opened it. + All servers are shut down in parallel via ``asyncio.gather``. """ - global _mcp_loop, _mcp_thread + with _lock: + if not _servers: + # No servers -- just stop the loop. _stop_mcp_loop() also + # acquires _lock, so we must release it first. + pass + else: + servers_snapshot = list(_servers.values()) + # Fast path: nothing to shut down. if not _servers: _stop_mcp_loop() return async def _shutdown(): - for name, server in list(_servers.items()): - try: - await server.shutdown() - except Exception as exc: - logger.debug("Error closing MCP server '%s': %s", name, exc) - _servers.clear() + results = await asyncio.gather( + *(server.shutdown() for server in servers_snapshot), + return_exceptions=True, + ) + for server, result in zip(servers_snapshot, results): + if isinstance(result, Exception): + logger.debug( + "Error closing MCP server '%s': %s", server.name, result, + ) + with _lock: + _servers.clear() - if _mcp_loop is not None and _mcp_loop.is_running(): + with _lock: + loop = _mcp_loop + if loop is not None and loop.is_running(): try: - future = asyncio.run_coroutine_threadsafe(_shutdown(), _mcp_loop) + future = asyncio.run_coroutine_threadsafe(_shutdown(), loop) future.result(timeout=15) except Exception as exc: logger.debug("Error during MCP shutdown: %s", exc) @@ -447,10 +484,13 @@ def shutdown_mcp_servers(): def _stop_mcp_loop(): """Stop the background event loop and join its thread.""" global _mcp_loop, _mcp_thread - if _mcp_loop is not None: - _mcp_loop.call_soon_threadsafe(_mcp_loop.stop) - if _mcp_thread is not None: - _mcp_thread.join(timeout=5) - _mcp_thread = None - _mcp_loop.close() + with _lock: + loop = _mcp_loop + thread = _mcp_thread _mcp_loop = None + _mcp_thread = None + if loop is not None: + loop.call_soon_threadsafe(loop.stop) + if thread is not None: + thread.join(timeout=5) + loop.close()