fix(mcp): resolve npx stdio connection failures (#1291)

Salvaged from PR #977 onto current main.
Preserves the MCP stdio command resolution and improved error diagnostics,
with deterministic regression tests for the npx/node PATH cases.

Co-authored-by: kshitij <82637225+kshitijk4poor@users.noreply.github.com>
This commit is contained in:
Teknium 2026-03-14 05:44:00 -07:00 committed by GitHub
parent 1a857123b3
commit b646440ca0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 203 additions and 2 deletions

View file

@ -75,6 +75,7 @@ import logging
import math
import os
import re
import shutil
import threading
import time
from typing import Any, Dict, List, Optional
@ -176,6 +177,116 @@ def _sanitize_error(text: str) -> str:
return _CREDENTIAL_PATTERN.sub("[REDACTED]", text)
def _prepend_path(env: dict, directory: str) -> dict:
"""Prepend *directory* to env PATH if it is not already present."""
updated = dict(env or {})
if not directory:
return updated
existing = updated.get("PATH", "")
parts = [part for part in existing.split(os.pathsep) if part]
if directory not in parts:
parts = [directory, *parts]
updated["PATH"] = os.pathsep.join(parts) if parts else directory
return updated
def _resolve_stdio_command(command: str, env: dict) -> tuple[str, dict]:
"""Resolve a stdio MCP command against the exact subprocess environment.
This primarily exists to make bare ``npx``/``npm``/``node`` commands work
reliably even when MCP subprocesses run under a filtered PATH.
"""
resolved_command = os.path.expanduser(str(command).strip())
resolved_env = dict(env or {})
if os.sep not in resolved_command:
path_arg = resolved_env["PATH"] if "PATH" in resolved_env else None
which_hit = shutil.which(resolved_command, path=path_arg)
if which_hit:
resolved_command = which_hit
elif resolved_command in {"npx", "npm", "node"}:
hermes_home = os.path.expanduser(
os.getenv(
"HERMES_HOME", os.path.join(os.path.expanduser("~"), ".hermes")
)
)
candidates = [
os.path.join(hermes_home, "node", "bin", resolved_command),
os.path.join(os.path.expanduser("~"), ".local", "bin", resolved_command),
]
for candidate in candidates:
if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
resolved_command = candidate
break
command_dir = os.path.dirname(resolved_command)
if command_dir:
resolved_env = _prepend_path(resolved_env, command_dir)
return resolved_command, resolved_env
def _format_connect_error(exc: BaseException) -> str:
    """Render nested MCP connection errors into an actionable short message.

    The exception is walked depth-first through any ``exceptions`` members
    (as exposed by exception groups) and through ``__cause__`` /
    ``__context__`` links. Exceptions already on the current walk path are
    skipped, so a cyclic chain (e.g. ``raise a from b`` where ``b`` was
    itself raised ``from a``) cannot cause unbounded recursion.

    Args:
        exc: The exception raised while connecting to an MCP server.

    Returns:
        ``"missing executable '<path>'"`` (with a Node.js hint when the
        basename is ``npx``, ``npm`` or ``node``) if a ``FileNotFoundError``
        is found anywhere in the tree; otherwise up to three distinct
        messages joined with ``"; "``. The result is always passed through
        ``_sanitize_error``.
    """

    def _chained(current: BaseException) -> List[BaseException]:
        """Return the explicit cause and implicit context of *current*."""
        return [
            linked
            for linked in (
                getattr(current, "__cause__", None),
                getattr(current, "__context__", None),
            )
            if isinstance(linked, BaseException)
        ]

    def _find_missing(
        current: BaseException, active: frozenset = frozenset()
    ) -> Optional[str]:
        """Return the first missing-file path found in the tree, if any."""
        # ``active`` holds ids of exceptions on the current recursion path;
        # hitting one again means the chain is cyclic.
        if id(current) in active:
            return None
        active = active | {id(current)}

        nested = getattr(current, "exceptions", None)
        if nested:
            for child in nested:
                missing = _find_missing(child, active)
                if missing:
                    return missing
            return None

        if isinstance(current, FileNotFoundError):
            if getattr(current, "filename", None):
                return str(current.filename)
            # Fall back to parsing the message when ``filename`` is unset.
            match = re.search(r"No such file or directory: '([^']+)'", str(current))
            if match:
                return match.group(1)

        for linked in _chained(current):
            missing = _find_missing(linked, active)
            if missing:
                return missing
        return None

    def _flatten_messages(
        current: BaseException, active: frozenset = frozenset()
    ) -> List[str]:
        """Collect non-empty messages from the tree in depth-first order."""
        if id(current) in active:
            return []
        active = active | {id(current)}

        nested = getattr(current, "exceptions", None)
        if nested:
            # For grouped errors only the members' messages are reported.
            flattened: List[str] = []
            for child in nested:
                flattened.extend(_flatten_messages(child, active))
            return flattened

        messages: List[str] = []
        text = str(current).strip()
        if text:
            messages.append(text)
        for linked in _chained(current):
            messages.extend(_flatten_messages(linked, active))
        # An exception with no text anywhere is represented by its class name.
        return messages or [current.__class__.__name__]

    missing = _find_missing(exc)
    if missing:
        message = f"missing executable '{missing}'"
        if os.path.basename(missing) in {"npx", "npm", "node"}:
            message += (
                " (ensure Node.js is installed and PATH includes its bin directory, "
                "or set mcp_servers.<name>.command to an absolute path and include "
                "that directory in mcp_servers.<name>.env.PATH)"
            )
        return _sanitize_error(message)

    # Order-preserving de-duplication; the same exception is often reachable
    # through both ``__cause__`` and ``__context__``.
    deduped = list(dict.fromkeys(_flatten_messages(exc)))
    return _sanitize_error("; ".join(deduped[:3]))
# ---------------------------------------------------------------------------
# Sampling -- server-initiated LLM requests (MCP sampling/createMessage)
# ---------------------------------------------------------------------------
@ -608,6 +719,7 @@ class MCPServerTask:
)
safe_env = _build_safe_env(user_env)
command, safe_env = _resolve_stdio_command(command, safe_env)
server_params = StdioServerParameters(
command=command,
args=args,
@ -1340,9 +1452,12 @@ def discover_mcp_tools() -> List[str]:
for name, result in zip(server_names, results):
if isinstance(result, Exception):
failed_count += 1
command = new_servers.get(name, {}).get("command")
logger.warning(
"Failed to connect to MCP server '%s': %s",
name, result,
"Failed to connect to MCP server '%s'%s: %s",
name,
f" (command={command})" if command else "",
_format_connect_error(result),
)
elif isinstance(result, list):
all_tools.extend(result)