feat(mcp): add sampling support — server-initiated LLM requests (#753)
Add MCP sampling/createMessage capability via SamplingHandler class. Text-only sampling + tool use in sampling with governance (rate limits, model whitelist, token caps, tool loop limits). Per-server audit metrics. Based on concept from PR #366 by eren-karakus0. Restructured as class-based design with bug fixes and tests using real MCP SDK types. 50 new tests, 2600 total passing.
This commit is contained in:
parent
1f0944de21
commit
654e16187e
5 changed files with 1307 additions and 4 deletions
|
|
@ -29,6 +29,18 @@ Example config::
|
|||
headers:
|
||||
Authorization: "Bearer sk-..."
|
||||
timeout: 180
|
||||
analysis:
|
||||
command: "npx"
|
||||
args: ["-y", "analysis-server"]
|
||||
sampling: # server-initiated LLM requests
|
||||
enabled: true # default: true
|
||||
model: "gemini-3-flash" # override model (optional)
|
||||
max_tokens_cap: 4096 # max tokens per request
|
||||
timeout: 30 # LLM call timeout (seconds)
|
||||
max_rpm: 10 # max requests per minute
|
||||
allowed_models: [] # model whitelist (empty = all)
|
||||
max_tool_rounds: 5 # tool loop limit (0 = disable)
|
||||
log_level: "info" # audit verbosity
|
||||
|
||||
Features:
|
||||
- Stdio transport (command + args) and HTTP/StreamableHTTP transport (url)
|
||||
|
|
@ -37,6 +49,8 @@ Features:
|
|||
- Credential stripping in error messages returned to the LLM
|
||||
- Configurable per-server timeouts for tool calls and connections
|
||||
- Thread-safe architecture with dedicated background event loop
|
||||
- Sampling support: MCP servers can request LLM completions via
|
||||
sampling/createMessage (text and tool-use responses)
|
||||
|
||||
Architecture:
|
||||
A dedicated background event loop (_mcp_loop) runs in a daemon thread.
|
||||
|
|
@ -58,9 +72,11 @@ Thread safety:
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -71,6 +87,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
_MCP_AVAILABLE = False
|
||||
_MCP_HTTP_AVAILABLE = False
|
||||
_MCP_SAMPLING_TYPES = False
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
|
@ -80,6 +97,20 @@ try:
|
|||
_MCP_HTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
_MCP_HTTP_AVAILABLE = False
|
||||
# Sampling types -- separated so older SDK versions don't break MCP support
|
||||
try:
|
||||
from mcp.types import (
|
||||
CreateMessageResult,
|
||||
CreateMessageResultWithTools,
|
||||
ErrorData,
|
||||
SamplingCapability,
|
||||
SamplingToolsCapability,
|
||||
TextContent,
|
||||
ToolUseContent,
|
||||
)
|
||||
_MCP_SAMPLING_TYPES = True
|
||||
except ImportError:
|
||||
logger.debug("MCP sampling types not available -- sampling disabled")
|
||||
except ImportError:
|
||||
logger.debug("mcp package not installed -- MCP tool support disabled")
|
||||
|
||||
|
|
@ -145,6 +176,386 @@ def _sanitize_error(text: str) -> str:
|
|||
return _CREDENTIAL_PATTERN.sub("[REDACTED]", text)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sampling -- server-initiated LLM requests (MCP sampling/createMessage)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _safe_numeric(value, default, coerce=int, minimum=1):
|
||||
"""Coerce a config value to a numeric type, returning *default* on failure.
|
||||
|
||||
Handles string values from YAML (e.g. ``"10"`` instead of ``10``),
|
||||
non-finite floats, and values below *minimum*.
|
||||
"""
|
||||
try:
|
||||
result = coerce(value)
|
||||
if isinstance(result, float) and not math.isfinite(result):
|
||||
return default
|
||||
return max(result, minimum)
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
return default
|
||||
|
||||
|
||||
class SamplingHandler:
    """Handles sampling/createMessage requests for a single MCP server.

    Each MCPServerTask that has sampling enabled creates one SamplingHandler.
    The handler is callable and passed directly to ``ClientSession`` as
    the ``sampling_callback``. All state (rate-limit timestamps, metrics,
    tool-loop counters) lives on the instance -- no module-level globals.

    The callback is async and runs on the MCP background event loop. The
    sync LLM call is offloaded to a thread via ``asyncio.to_thread()`` so
    it doesn't block the event loop.
    """

    # Maps OpenAI ``finish_reason`` values to MCP ``stopReason`` values.
    _STOP_REASON_MAP = {"stop": "endTurn", "length": "maxTokens", "tool_calls": "toolUse"}

    def __init__(self, server_name: str, config: dict):
        self.server_name = server_name
        # Governance knobs -- all tolerant of string/bad values from YAML.
        self.max_rpm = _safe_numeric(config.get("max_rpm", 10), 10, int)
        self.timeout = _safe_numeric(config.get("timeout", 30), 30, float)
        self.max_tokens_cap = _safe_numeric(config.get("max_tokens_cap", 4096), 4096, int)
        self.max_tool_rounds = _safe_numeric(
            config.get("max_tool_rounds", 5), 5, int, minimum=0,
        )
        self.model_override = config.get("model")
        self.allowed_models = config.get("allowed_models", [])

        _log_levels = {"debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING}
        self.audit_level = _log_levels.get(
            str(config.get("log_level", "info")).lower(), logging.INFO,
        )

        # Per-instance state
        self._rate_timestamps: List[float] = []
        self._tool_loop_count = 0
        self.metrics = {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0}

    # -- Rate limiting -------------------------------------------------------

    def _check_rate_limit(self) -> bool:
        """Sliding-window rate limiter. Returns True if request is allowed."""
        # time.monotonic() rather than time.time(): wall-clock adjustments
        # (NTP, manual changes) must not shrink or extend the 60s window.
        now = time.monotonic()
        window = now - 60
        self._rate_timestamps[:] = [t for t in self._rate_timestamps if t > window]
        if len(self._rate_timestamps) >= self.max_rpm:
            return False
        self._rate_timestamps.append(now)
        return True

    # -- Model resolution ----------------------------------------------------

    def _resolve_model(self, preferences) -> Optional[str]:
        """Config override > server hint > None (use default)."""
        if self.model_override:
            return self.model_override
        if preferences and hasattr(preferences, "hints") and preferences.hints:
            for hint in preferences.hints:
                if hasattr(hint, "name") and hint.name:
                    return hint.name
        return None

    # -- Message conversion --------------------------------------------------

    @staticmethod
    def _extract_tool_result_text(block) -> str:
        """Extract text from a ToolResultContent block."""
        if not hasattr(block, "content") or block.content is None:
            return ""
        items = block.content if isinstance(block.content, list) else [block.content]
        return "\n".join(item.text for item in items if hasattr(item, "text"))

    def _convert_messages(self, params) -> List[dict]:
        """Convert MCP SamplingMessages to OpenAI format.

        Uses ``msg.content_as_list`` (SDK helper) so single-block and
        list-of-blocks are handled uniformly. Dispatches per block type
        with ``isinstance`` on real SDK types when available, falling back
        to duck-typing via ``hasattr`` for compatibility.
        """
        messages: List[dict] = []
        for msg in params.messages:
            blocks = msg.content_as_list if hasattr(msg, "content_as_list") else (
                msg.content if isinstance(msg.content, list) else [msg.content]
            )

            # Separate blocks by kind (duck-typed on distinguishing attrs)
            tool_results = [b for b in blocks if hasattr(b, "toolUseId")]
            tool_uses = [b for b in blocks if hasattr(b, "name") and hasattr(b, "input") and not hasattr(b, "toolUseId")]
            content_blocks = [b for b in blocks if not hasattr(b, "toolUseId") and not (hasattr(b, "name") and hasattr(b, "input"))]

            # Emit tool result messages (role: tool)
            for tr in tool_results:
                messages.append({
                    "role": "tool",
                    "tool_call_id": tr.toolUseId,
                    "content": self._extract_tool_result_text(tr),
                })

            # Emit assistant tool_calls message
            if tool_uses:
                tc_list = []
                for tu in tool_uses:
                    tc_list.append({
                        "id": getattr(tu, "id", f"call_{len(tc_list)}"),
                        "type": "function",
                        "function": {
                            "name": tu.name,
                            "arguments": json.dumps(tu.input) if isinstance(tu.input, dict) else str(tu.input),
                        },
                    })
                msg_dict: dict = {"role": msg.role, "tool_calls": tc_list}
                # Include any accompanying text
                text_parts = [b.text for b in content_blocks if hasattr(b, "text")]
                if text_parts:
                    msg_dict["content"] = "\n".join(text_parts)
                messages.append(msg_dict)
            elif content_blocks:
                # Pure text/image content
                if len(content_blocks) == 1 and hasattr(content_blocks[0], "text"):
                    messages.append({"role": msg.role, "content": content_blocks[0].text})
                else:
                    parts = []
                    for block in content_blocks:
                        if hasattr(block, "text"):
                            parts.append({"type": "text", "text": block.text})
                        elif hasattr(block, "data") and hasattr(block, "mimeType"):
                            parts.append({
                                "type": "image_url",
                                "image_url": {"url": f"data:{block.mimeType};base64,{block.data}"},
                            })
                        else:
                            logger.warning(
                                "Unsupported sampling content block type: %s (skipped)",
                                type(block).__name__,
                            )
                    if parts:
                        messages.append({"role": msg.role, "content": parts})

        return messages

    # -- Error helper --------------------------------------------------------

    @staticmethod
    def _error(message: str, code: int = -1):
        """Return ErrorData (MCP spec) or raise as fallback."""
        if _MCP_SAMPLING_TYPES:
            return ErrorData(code=code, message=message)
        raise Exception(message)

    # -- Response building ---------------------------------------------------

    def _build_tool_use_result(self, choice, response):
        """Build a CreateMessageResultWithTools from an LLM tool_calls response."""
        self.metrics["tool_use_count"] += 1

        # Tool loop governance
        if self.max_tool_rounds == 0:
            self._tool_loop_count = 0
            # Count as an error so audit metrics match __call__'s error paths.
            self.metrics["errors"] += 1
            return self._error(
                f"Tool loops disabled for server '{self.server_name}' (max_tool_rounds=0)"
            )

        self._tool_loop_count += 1
        if self._tool_loop_count > self.max_tool_rounds:
            self._tool_loop_count = 0
            self.metrics["errors"] += 1
            return self._error(
                f"Tool loop limit exceeded for server '{self.server_name}' "
                f"(max {self.max_tool_rounds} rounds)"
            )

        content_blocks = []
        for tc in choice.message.tool_calls:
            args = tc.function.arguments
            if isinstance(args, str):
                try:
                    parsed = json.loads(args)
                except (json.JSONDecodeError, ValueError):
                    logger.warning(
                        "MCP server '%s': malformed tool_calls arguments "
                        "from LLM (wrapping as raw): %.100s",
                        self.server_name, args,
                    )
                    parsed = {"_raw": args}
            else:
                parsed = args if isinstance(args, dict) else {"_raw": str(args)}

            content_blocks.append(ToolUseContent(
                type="tool_use",
                id=tc.id,
                name=tc.function.name,
                input=parsed,
            ))

        logger.log(
            self.audit_level,
            "MCP server '%s' sampling response: model=%s, tokens=%s, tool_calls=%d",
            self.server_name, response.model,
            getattr(getattr(response, "usage", None), "total_tokens", "?"),
            len(content_blocks),
        )

        return CreateMessageResultWithTools(
            role="assistant",
            content=content_blocks,
            model=response.model,
            stopReason="toolUse",
        )

    def _build_text_result(self, choice, response):
        """Build a CreateMessageResult from a normal text response."""
        self._tool_loop_count = 0  # reset on text response
        response_text = choice.message.content or ""

        logger.log(
            self.audit_level,
            "MCP server '%s' sampling response: model=%s, tokens=%s",
            self.server_name, response.model,
            getattr(getattr(response, "usage", None), "total_tokens", "?"),
        )

        return CreateMessageResult(
            role="assistant",
            content=TextContent(type="text", text=_sanitize_error(response_text)),
            model=response.model,
            stopReason=self._STOP_REASON_MAP.get(choice.finish_reason, "endTurn"),
        )

    # -- Session kwargs helper -----------------------------------------------

    def session_kwargs(self) -> dict:
        """Return kwargs to pass to ClientSession for sampling support."""
        return {
            "sampling_callback": self,
            "sampling_capabilities": SamplingCapability(
                tools=SamplingToolsCapability(),
            ),
        }

    # -- Main callback -------------------------------------------------------

    async def __call__(self, context, params):
        """Sampling callback invoked by the MCP SDK.

        Conforms to ``SamplingFnT`` protocol. Returns
        ``CreateMessageResult``, ``CreateMessageResultWithTools``, or
        ``ErrorData``.
        """
        # Rate limit
        if not self._check_rate_limit():
            logger.warning(
                "MCP server '%s' sampling rate limit exceeded (%d/min)",
                self.server_name, self.max_rpm,
            )
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling rate limit exceeded for server '{self.server_name}' "
                f"({self.max_rpm} requests/minute)"
            )

        # Resolve model
        model = self._resolve_model(getattr(params, "modelPreferences", None))

        # Get auxiliary LLM client (function-scope import avoids a cycle)
        from agent.auxiliary_client import get_text_auxiliary_client
        client, default_model = get_text_auxiliary_client()
        if client is None:
            self.metrics["errors"] += 1
            return self._error("No LLM provider available for sampling")

        resolved_model = model or default_model

        # Model whitelist check
        if self.allowed_models and resolved_model not in self.allowed_models:
            logger.warning(
                "MCP server '%s' requested model '%s' not in allowed_models",
                self.server_name, resolved_model,
            )
            self.metrics["errors"] += 1
            return self._error(
                f"Model '{resolved_model}' not allowed for server "
                f"'{self.server_name}'. Allowed: {', '.join(self.allowed_models)}"
            )

        # Convert messages
        messages = self._convert_messages(params)
        if hasattr(params, "systemPrompt") and params.systemPrompt:
            messages.insert(0, {"role": "system", "content": params.systemPrompt})

        # Build LLM call kwargs -- server's maxTokens capped by governance
        max_tokens = min(params.maxTokens, self.max_tokens_cap)
        call_kwargs: dict = {
            "model": resolved_model,
            "messages": messages,
            "max_tokens": max_tokens,
        }
        if hasattr(params, "temperature") and params.temperature is not None:
            call_kwargs["temperature"] = params.temperature
        if stop := getattr(params, "stopSequences", None):
            call_kwargs["stop"] = stop

        # Forward server-provided tools
        server_tools = getattr(params, "tools", None)
        if server_tools:
            call_kwargs["tools"] = [
                {
                    "type": "function",
                    "function": {
                        "name": getattr(t, "name", ""),
                        "description": getattr(t, "description", "") or "",
                        "parameters": getattr(t, "inputSchema", {}) or {},
                    },
                }
                for t in server_tools
            ]
            if tool_choice := getattr(params, "toolChoice", None):
                mode = getattr(tool_choice, "mode", "auto")
                call_kwargs["tool_choice"] = {"auto": "auto", "required": "required", "none": "none"}.get(mode, "auto")

        logger.log(
            self.audit_level,
            "MCP server '%s' sampling request: model=%s, max_tokens=%d, messages=%d",
            self.server_name, resolved_model, max_tokens, len(messages),
        )

        # Offload sync LLM call to thread (non-blocking)
        def _sync_call():
            return client.chat.completions.create(**call_kwargs)

        try:
            response = await asyncio.wait_for(
                asyncio.to_thread(_sync_call), timeout=self.timeout,
            )
        except asyncio.TimeoutError:
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling LLM call timed out after {self.timeout}s "
                f"for server '{self.server_name}'"
            )
        except Exception as exc:
            self.metrics["errors"] += 1
            return self._error(
                f"Sampling LLM call failed: {_sanitize_error(str(exc))}"
            )

        # Track metrics
        choice = response.choices[0]
        self.metrics["requests"] += 1
        total_tokens = getattr(getattr(response, "usage", None), "total_tokens", 0)
        if isinstance(total_tokens, int):
            self.metrics["tokens_used"] += total_tokens

        # Dispatch based on response type
        if (
            choice.finish_reason == "tool_calls"
            and hasattr(choice.message, "tool_calls")
            and choice.message.tool_calls
        ):
            return self._build_tool_use_result(choice, response)

        return self._build_text_result(choice, response)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Server task -- each MCP server lives in one long-lived asyncio Task
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -162,6 +573,7 @@ class MCPServerTask:
|
|||
__slots__ = (
|
||||
"name", "session", "tool_timeout",
|
||||
"_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
|
||||
"_sampling",
|
||||
)
|
||||
|
||||
def __init__(self, name: str):
|
||||
|
|
@ -174,6 +586,7 @@ class MCPServerTask:
|
|||
self._tools: list = []
|
||||
self._error: Optional[Exception] = None
|
||||
self._config: dict = {}
|
||||
self._sampling: Optional[SamplingHandler] = None
|
||||
|
||||
def _is_http(self) -> bool:
|
||||
"""Check if this server uses HTTP transport."""
|
||||
|
|
@ -197,8 +610,9 @@ class MCPServerTask:
|
|||
env=safe_env if safe_env else None,
|
||||
)
|
||||
|
||||
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
|
||||
async with stdio_client(server_params) as (read_stream, write_stream):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
await self._discover_tools()
|
||||
|
|
@ -218,12 +632,13 @@ class MCPServerTask:
|
|||
headers = config.get("headers")
|
||||
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
|
||||
|
||||
sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {}
|
||||
async with streamablehttp_client(
|
||||
url,
|
||||
headers=headers,
|
||||
timeout=float(connect_timeout),
|
||||
) as (read_stream, write_stream, _get_session_id):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
await self._discover_tools()
|
||||
|
|
@ -250,6 +665,13 @@ class MCPServerTask:
|
|||
self._config = config
|
||||
self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT)
|
||||
|
||||
# Set up sampling handler if enabled and SDK types are available
|
||||
sampling_config = config.get("sampling", {})
|
||||
if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES:
|
||||
self._sampling = SamplingHandler(self.name, sampling_config)
|
||||
else:
|
||||
self._sampling = None
|
||||
|
||||
# Validate: warn if both url and command are present
|
||||
if "url" in config and "command" in config:
|
||||
logger.warning(
|
||||
|
|
@ -975,12 +1397,15 @@ def get_mcp_status() -> List[dict]:
|
|||
transport = "http" if "url" in cfg else "stdio"
|
||||
server = active_servers.get(name)
|
||||
if server and server.session is not None:
|
||||
result.append({
|
||||
entry = {
|
||||
"name": name,
|
||||
"transport": transport,
|
||||
"tools": len(server._tools),
|
||||
"connected": True,
|
||||
})
|
||||
}
|
||||
if server._sampling:
|
||||
entry["sampling"] = dict(server._sampling.metrics)
|
||||
result.append(entry)
|
||||
else:
|
||||
result.append({
|
||||
"name": name,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue