Merge pull request #796 from 0xbyt4/fix/discovery-failed-count
Clean bug fix — failed MCP server connections were silently swallowed, making failed_count dead code. Well-tested.
This commit is contained in:
commit
24a0c08d58
2 changed files with 131 additions and 13 deletions
|
|
@ -2326,3 +2326,127 @@ class TestMCPServerTaskSamplingIntegration:
|
||||||
kwargs = server._sampling.session_kwargs()
|
kwargs = server._sampling.session_kwargs()
|
||||||
assert "sampling_callback" in kwargs
|
assert "sampling_callback" in kwargs
|
||||||
assert "sampling_capabilities" in kwargs
|
assert "sampling_capabilities" in kwargs
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Discovery failed_count tracking
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestDiscoveryFailedCount:
|
||||||
|
"""Verify discover_mcp_tools() correctly tracks failed server connections."""
|
||||||
|
|
||||||
|
def test_failed_server_increments_failed_count(self):
|
||||||
|
"""When _discover_and_register_server raises, failed_count increments."""
|
||||||
|
from tools.mcp_tool import discover_mcp_tools, _servers, _ensure_mcp_loop
|
||||||
|
|
||||||
|
fake_config = {
|
||||||
|
"good_server": {"command": "npx", "args": ["good"]},
|
||||||
|
"bad_server": {"command": "npx", "args": ["bad"]},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def fake_register(name, cfg):
|
||||||
|
if name == "bad_server":
|
||||||
|
raise ConnectionError("Connection refused")
|
||||||
|
# Simulate successful registration
|
||||||
|
from tools.mcp_tool import MCPServerTask
|
||||||
|
server = MCPServerTask(name)
|
||||||
|
server.session = MagicMock()
|
||||||
|
server._tools = [_make_mcp_tool("tool_a")]
|
||||||
|
_servers[name] = server
|
||||||
|
return [f"mcp_{name}_tool_a"]
|
||||||
|
|
||||||
|
with patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
|
||||||
|
patch("tools.mcp_tool._discover_and_register_server", side_effect=fake_register), \
|
||||||
|
patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||||
|
patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_good_server_tool_a"]):
|
||||||
|
_ensure_mcp_loop()
|
||||||
|
|
||||||
|
# Capture the logger to verify failed_count in summary
|
||||||
|
with patch("tools.mcp_tool.logger") as mock_logger:
|
||||||
|
discover_mcp_tools()
|
||||||
|
|
||||||
|
# Find the summary info call
|
||||||
|
info_calls = [
|
||||||
|
str(call)
|
||||||
|
for call in mock_logger.info.call_args_list
|
||||||
|
if "failed" in str(call).lower() or "MCP:" in str(call)
|
||||||
|
]
|
||||||
|
# The summary should mention the failure
|
||||||
|
assert any("1 failed" in str(c) for c in info_calls), (
|
||||||
|
f"Summary should report 1 failed server, got: {info_calls}"
|
||||||
|
)
|
||||||
|
|
||||||
|
_servers.pop("good_server", None)
|
||||||
|
_servers.pop("bad_server", None)
|
||||||
|
|
||||||
|
def test_all_servers_fail_still_prints_summary(self):
|
||||||
|
"""When all servers fail, a summary with failure count is still printed."""
|
||||||
|
from tools.mcp_tool import discover_mcp_tools, _servers, _ensure_mcp_loop
|
||||||
|
|
||||||
|
fake_config = {
|
||||||
|
"srv1": {"command": "npx", "args": ["a"]},
|
||||||
|
"srv2": {"command": "npx", "args": ["b"]},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def always_fail(name, cfg):
|
||||||
|
raise ConnectionError(f"Server {name} refused")
|
||||||
|
|
||||||
|
with patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
|
||||||
|
patch("tools.mcp_tool._discover_and_register_server", side_effect=always_fail), \
|
||||||
|
patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||||
|
patch("tools.mcp_tool._existing_tool_names", return_value=[]):
|
||||||
|
_ensure_mcp_loop()
|
||||||
|
|
||||||
|
with patch("tools.mcp_tool.logger") as mock_logger:
|
||||||
|
discover_mcp_tools()
|
||||||
|
|
||||||
|
# Summary must be printed even when all servers fail
|
||||||
|
info_calls = [str(call) for call in mock_logger.info.call_args_list]
|
||||||
|
assert any("2 failed" in str(c) for c in info_calls), (
|
||||||
|
f"Summary should report 2 failed servers, got: {info_calls}"
|
||||||
|
)
|
||||||
|
|
||||||
|
_servers.pop("srv1", None)
|
||||||
|
_servers.pop("srv2", None)
|
||||||
|
|
||||||
|
def test_ok_servers_excludes_failures(self):
|
||||||
|
"""ok_servers count correctly excludes failed servers."""
|
||||||
|
from tools.mcp_tool import discover_mcp_tools, _servers, _ensure_mcp_loop
|
||||||
|
|
||||||
|
fake_config = {
|
||||||
|
"ok1": {"command": "npx", "args": ["ok1"]},
|
||||||
|
"ok2": {"command": "npx", "args": ["ok2"]},
|
||||||
|
"fail1": {"command": "npx", "args": ["fail"]},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def selective_register(name, cfg):
|
||||||
|
if name == "fail1":
|
||||||
|
raise ConnectionError("Refused")
|
||||||
|
from tools.mcp_tool import MCPServerTask
|
||||||
|
server = MCPServerTask(name)
|
||||||
|
server.session = MagicMock()
|
||||||
|
server._tools = [_make_mcp_tool("t")]
|
||||||
|
_servers[name] = server
|
||||||
|
return [f"mcp_{name}_t"]
|
||||||
|
|
||||||
|
with patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
|
||||||
|
patch("tools.mcp_tool._discover_and_register_server", side_effect=selective_register), \
|
||||||
|
patch("tools.mcp_tool._MCP_AVAILABLE", True), \
|
||||||
|
patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_ok1_t", "mcp_ok2_t"]):
|
||||||
|
_ensure_mcp_loop()
|
||||||
|
|
||||||
|
with patch("tools.mcp_tool.logger") as mock_logger:
|
||||||
|
discover_mcp_tools()
|
||||||
|
|
||||||
|
info_calls = [str(call) for call in mock_logger.info.call_args_list]
|
||||||
|
# Should say "2 server(s)" not "3 server(s)"
|
||||||
|
assert any("2 server" in str(c) for c in info_calls), (
|
||||||
|
f"Summary should report 2 ok servers, got: {info_calls}"
|
||||||
|
)
|
||||||
|
assert any("1 failed" in str(c) for c in info_calls), (
|
||||||
|
f"Summary should report 1 failed, got: {info_calls}"
|
||||||
|
)
|
||||||
|
|
||||||
|
_servers.pop("ok1", None)
|
||||||
|
_servers.pop("ok2", None)
|
||||||
|
_servers.pop("fail1", None)
|
||||||
|
|
|
||||||
|
|
@ -1331,29 +1331,23 @@ def discover_mcp_tools() -> List[str]:
|
||||||
|
|
||||||
async def _discover_one(name: str, cfg: dict) -> List[str]:
|
async def _discover_one(name: str, cfg: dict) -> List[str]:
|
||||||
"""Connect to a single server and return its registered tool names."""
|
"""Connect to a single server and return its registered tool names."""
|
||||||
transport_desc = cfg.get("url", f'{cfg.get("command", "?")} {" ".join(cfg.get("args", [])[:2])}')
|
return await _discover_and_register_server(name, cfg)
|
||||||
try:
|
|
||||||
registered = await _discover_and_register_server(name, cfg)
|
|
||||||
transport_type = "HTTP" if "url" in cfg else "stdio"
|
|
||||||
return registered
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"Failed to connect to MCP server '%s': %s",
|
|
||||||
name, exc,
|
|
||||||
)
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def _discover_all():
|
async def _discover_all():
|
||||||
nonlocal failed_count
|
nonlocal failed_count
|
||||||
|
server_names = list(new_servers.keys())
|
||||||
# Connect to all servers in PARALLEL
|
# Connect to all servers in PARALLEL
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*(_discover_one(name, cfg) for name, cfg in new_servers.items()),
|
*(_discover_one(name, cfg) for name, cfg in new_servers.items()),
|
||||||
return_exceptions=True,
|
return_exceptions=True,
|
||||||
)
|
)
|
||||||
for result in results:
|
for name, result in zip(server_names, results):
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
failed_count += 1
|
failed_count += 1
|
||||||
logger.warning("MCP discovery error: %s", result)
|
logger.warning(
|
||||||
|
"Failed to connect to MCP server '%s': %s",
|
||||||
|
name, result,
|
||||||
|
)
|
||||||
elif isinstance(result, list):
|
elif isinstance(result, list):
|
||||||
all_tools.extend(result)
|
all_tools.extend(result)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue