* fix(mcp-oauth): port mismatch, path traversal, and shared state in OAuth flow Three bugs in the new MCP OAuth 2.1 PKCE implementation: 1. CRITICAL: OAuth redirect port mismatch — build_oauth_auth() calls _find_free_port() to register the redirect_uri, but _wait_for_callback() calls _find_free_port() again getting a DIFFERENT port. Browser redirects to port A, server listens on port B — callback never arrives, 120s timeout. Fix: share the port via module-level _oauth_port variable. 2. MEDIUM: Path traversal via unsanitized server_name — HermesTokenStorage uses server_name directly in filenames. A name like "../../.ssh/config" writes token files outside ~/.hermes/mcp-tokens/. Fix: sanitize server_name with the same regex pattern used elsewhere. 3. MEDIUM: Class-level auth_code/state on _CallbackHandler causes data races if concurrent OAuth flows run. Second callback overwrites first. Fix: factory function _make_callback_handler() returns a handler class with a closure-scoped result dict, isolating each flow. * test: add tests for MCP OAuth path traversal, handler isolation, and port sharing 7 new tests covering: - Path traversal blocked (../../.ssh/config stays in mcp-tokens/) - Dots/slashes sanitized and resolved within base dir - Normal server names preserved - Special characters sanitized (@, :, /) - Concurrent handler result dicts are independent - Handler writes to its own result dict, not class-level - build_oauth_auth stores port in module-level _oauth_port --------- Co-authored-by: 0xbyt4 <35742124+0xbyt4@users.noreply.github.com>
This commit is contained in:
parent
fa6f069577
commit
ed805f57ff
2 changed files with 126 additions and 26 deletions
|
|
@ -35,11 +35,19 @@ _TOKEN_DIR_NAME = "mcp-tokens"
|
|||
# Token storage — persists tokens + client info to ~/.hermes/mcp-tokens/
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _sanitize_server_name(name: str) -> str:
|
||||
"""Sanitize server name for safe use as a filename."""
|
||||
import re
|
||||
clean = re.sub(r"[^\w\-]", "-", name.strip().lower())
|
||||
clean = re.sub(r"-+", "-", clean).strip("-")
|
||||
return clean[:60] or "unnamed"
|
||||
|
||||
|
||||
class HermesTokenStorage:
|
||||
"""File-backed token storage implementing the MCP SDK's TokenStorage protocol."""
|
||||
|
||||
def __init__(self, server_name: str):
|
||||
self._server_name = server_name
|
||||
self._server_name = _sanitize_server_name(server_name)
|
||||
|
||||
def _base_dir(self) -> Path:
|
||||
home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
|
|
@ -119,21 +127,28 @@ def _find_free_port() -> int:
|
|||
return s.getsockname()[1]
|
||||
|
||||
|
||||
class _CallbackHandler(BaseHTTPRequestHandler):
|
||||
auth_code: str | None = None
|
||||
state: str | None = None
|
||||
def _make_callback_handler():
|
||||
"""Create a callback handler class with instance-scoped result storage."""
|
||||
result = {"auth_code": None, "state": None}
|
||||
|
||||
def do_GET(self):
|
||||
qs = parse_qs(urlparse(self.path).query)
|
||||
_CallbackHandler.auth_code = (qs.get("code") or [None])[0]
|
||||
_CallbackHandler.state = (qs.get("state") or [None])[0]
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(b"<html><body><h3>Authorization complete. You can close this tab.</h3></body></html>")
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
qs = parse_qs(urlparse(self.path).query)
|
||||
result["auth_code"] = (qs.get("code") or [None])[0]
|
||||
result["state"] = (qs.get("state") or [None])[0]
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(b"<html><body><h3>Authorization complete. You can close this tab.</h3></body></html>")
|
||||
|
||||
def log_message(self, *_args: Any) -> None:
|
||||
pass # suppress HTTP log noise
|
||||
def log_message(self, *_args: Any) -> None:
|
||||
pass
|
||||
|
||||
return Handler, result
|
||||
|
||||
|
||||
# Port chosen at build time and shared with the callback handler via closure.
|
||||
_oauth_port: int | None = None
|
||||
|
||||
|
||||
async def _redirect_to_browser(auth_url: str) -> None:
|
||||
|
|
@ -149,11 +164,11 @@ async def _redirect_to_browser(auth_url: str) -> None:
|
|||
|
||||
|
||||
async def _wait_for_callback() -> tuple[str, str | None]:
|
||||
"""Start a local HTTP server and wait for the OAuth redirect callback."""
|
||||
port = _find_free_port()
|
||||
server = HTTPServer(("127.0.0.1", port), _CallbackHandler)
|
||||
_CallbackHandler.auth_code = None
|
||||
_CallbackHandler.state = None
|
||||
"""Start a local HTTP server on the pre-registered port and wait for the OAuth redirect."""
|
||||
global _oauth_port
|
||||
port = _oauth_port or _find_free_port()
|
||||
HandlerClass, result = _make_callback_handler()
|
||||
server = HTTPServer(("127.0.0.1", port), HandlerClass)
|
||||
|
||||
def _serve():
|
||||
server.timeout = 120
|
||||
|
|
@ -162,17 +177,15 @@ async def _wait_for_callback() -> tuple[str, str | None]:
|
|||
thread = threading.Thread(target=_serve, daemon=True)
|
||||
thread.start()
|
||||
|
||||
# Wait for the callback
|
||||
for _ in range(1200): # 120 seconds
|
||||
await asyncio.sleep(0.1)
|
||||
if _CallbackHandler.auth_code is not None:
|
||||
if result["auth_code"] is not None:
|
||||
break
|
||||
|
||||
server.server_close()
|
||||
code = _CallbackHandler.auth_code or ""
|
||||
state = _CallbackHandler.state
|
||||
code = result["auth_code"] or ""
|
||||
state = result["state"]
|
||||
if not code:
|
||||
# Fallback to manual entry
|
||||
print(" Browser callback timed out. Paste the authorization code manually:")
|
||||
code = input(" Code: ").strip()
|
||||
return code, state
|
||||
|
|
@ -206,8 +219,9 @@ def build_oauth_auth(server_name: str, server_url: str):
|
|||
logger.warning("MCP SDK auth module not available — OAuth disabled")
|
||||
return None
|
||||
|
||||
port = _find_free_port()
|
||||
redirect_uri = f"http://127.0.0.1:{port}/callback"
|
||||
global _oauth_port
|
||||
_oauth_port = _find_free_port()
|
||||
redirect_uri = f"http://127.0.0.1:{_oauth_port}/callback"
|
||||
|
||||
client_metadata = OAuthClientMetadata(
|
||||
client_name="Hermes Agent",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue