"""Smoke-check agent WebSocket endpoints listed in the Matrix agent registry.

For every selected agent the script opens a WebSocket, waits for an initial
``STATUS`` message and, when ``--message`` is given, sends one ``USER_MESSAGE``
and waits for ``AGENT_EVENT_END``.  The process exit code is 0 only when every
check succeeded.
"""
from __future__ import annotations

import argparse
import asyncio
import json
import os
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from urllib.parse import urljoin

import aiohttp

from adapter.matrix.agent_registry import AgentDefinition, load_agent_registry
from sdk.real import RealPlatformClient


@dataclass
class AgentCheckResult:
    """Outcome of a single agent check, as printed in the table / JSON output."""

    agent_id: str
    label: str
    chat_id: str
    base_url: str
    ws_url: str
    ok: bool
    # One of "config", "status", "message", "timeout" or "connect".
    stage: str
    # Wall-clock time from the start of the check until the result was built.
    latency_ms: int
    error: str = ""
    # The "type" field of the last decisive message received, if any.
    response_type: str = ""


def build_agent_ws_url(base_url: str, chat_id: str) -> str:
    """Return the agent WebSocket URL for *chat_id* under *base_url*.

    The base URL is first passed through
    ``RealPlatformClient._normalize_agent_base_url`` and the path
    ``v1/agent_ws/<chat_id>/`` is then joined onto the result.
    """
    # NOTE(review): urljoin only appends when the normalized URL ends with "/";
    # presumably the normalizer guarantees that - confirm against sdk.real.
    normalized = RealPlatformClient._normalize_agent_base_url(base_url)
    return urljoin(normalized, f"v1/agent_ws/{chat_id}/")


def _message_type(payload: str) -> str:
    """Return the string ``"type"`` field of a JSON object payload, else ``""``.

    An empty string is returned when *payload* is not valid JSON, is valid JSON
    but not an object, has no ``"type"`` key, or has a non-string ``"type"``.
    """
    try:
        data = json.loads(payload)
    except json.JSONDecodeError:
        return ""
    # Valid JSON need not be an object (e.g. "[]" or "1"); those have no .get().
    if not isinstance(data, dict):
        return ""
    value = data.get("type")
    return value if isinstance(value, str) else ""


async def _receive_text(ws: aiohttp.ClientWebSocketResponse, timeout: float) -> str:
    """Receive the next WebSocket message and return it as text.

    Raises:
        asyncio.TimeoutError: no message arrived within *timeout* seconds.
        RuntimeError: the message was an error frame or any non-text type.
    """
    msg = await asyncio.wait_for(ws.receive(), timeout=timeout)
    if msg.type == aiohttp.WSMsgType.TEXT:
        return str(msg.data)
    if msg.type == aiohttp.WSMsgType.ERROR:
        raise RuntimeError(f"websocket error: {ws.exception()}")
    raise RuntimeError(f"unexpected websocket message type: {msg.type.name}")


async def check_agent(
    agent: AgentDefinition,
    *,
    fallback_base_url: str,
    chat_id: str,
    timeout: float,
    message: str | None,
) -> AgentCheckResult:
    """Run one WebSocket smoke check against *agent*.

    Args:
        agent: Registry entry; ``agent_id``, ``label`` and ``base_url`` are read.
        fallback_base_url: Used when the agent has no ``base_url`` of its own.
        chat_id: Chat id embedded in the WebSocket URL.
        timeout: Seconds allowed for connecting and for each individual receive.
        message: Optional text sent as a ``USER_MESSAGE`` after ``STATUS``.

    Returns:
        An ``AgentCheckResult``; connection errors and timeouts are reported in
        the result (stages ``"connect"`` / ``"timeout"``) rather than raised.
    """
    base_url = agent.base_url or fallback_base_url
    ws_url = build_agent_ws_url(base_url, chat_id) if base_url else ""
    started = time.perf_counter()

    def result(
        ok: bool,
        stage: str,
        error: str = "",
        response_type: str = "",
    ) -> AgentCheckResult:
        # Latency is measured at the moment the result is produced.
        return AgentCheckResult(
            agent_id=agent.agent_id,
            label=agent.label,
            chat_id=chat_id,
            base_url=base_url,
            ws_url=ws_url,
            ok=ok,
            stage=stage,
            latency_ms=int((time.perf_counter() - started) * 1000),
            error=error,
            response_type=response_type,
        )

    if not base_url:
        return result(False, "config", "missing base_url and AGENT_BASE_URL")

    try:
        client_timeout = aiohttp.ClientTimeout(
            total=timeout,
            connect=timeout,
            sock_connect=timeout,
            sock_read=timeout,
        )
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.ws_connect(ws_url, heartbeat=30) as ws:
                raw_status = await _receive_text(ws, timeout)
                status_type = _message_type(raw_status)
                if status_type != "STATUS":
                    return result(
                        False,
                        "status",
                        f"expected STATUS, got {raw_status[:200]}",
                        status_type,
                    )
                if not message:
                    return result(True, "status", response_type=status_type)

                payload = {
                    "type": "USER_MESSAGE",
                    "text": message,
                    "attachments": [],
                }
                await ws.send_str(json.dumps(payload))

                # Each receive has its own timeout; events of any other type
                # are skipped until a terminal ERROR / AGENT_EVENT_END arrives.
                while True:
                    raw_event = await _receive_text(ws, timeout)
                    event_type = _message_type(raw_event)
                    if event_type == "ERROR":
                        return result(False, "message", raw_event[:200], event_type)
                    if event_type == "AGENT_EVENT_END":
                        return result(True, "message", response_type=event_type)
                    if not event_type:
                        return result(
                            False,
                            "message",
                            f"invalid JSON event: {raw_event[:200]}",
                        )
    except (asyncio.TimeoutError, TimeoutError):
        # Before Python 3.11 asyncio.TimeoutError is a distinct class from the
        # builtin TimeoutError, so both must be named to classify timeouts.
        return result(False, "timeout", f"no response within {timeout:g}s")
    except Exception as exc:
        # Many exceptions stringify to "", so fall back to the class name.
        return result(False, "connect", str(exc) or type(exc).__name__)


def _select_agents(
    agents: tuple[AgentDefinition, ...],
    selected: set[str],
) -> list[AgentDefinition]:
    """Return the agents whose id is in *selected*, or all agents if it is empty."""
    if not selected:
        return list(agents)
    return [agent for agent in agents if agent.agent_id in selected]


async def run_checks(args: argparse.Namespace) -> list[AgentCheckResult]:
    """Load the registry and check the selected agents concurrently.

    Results are returned in the same order as the selected agents.

    Raises:
        SystemExit: ``args.concurrency`` is below 1, or no agent matched.
    """
    # Semaphore(0) would block every check forever; reject it up front.
    if args.concurrency < 1:
        raise SystemExit("--concurrency must be at least 1")

    registry = load_agent_registry(args.config)
    selected = _select_agents(registry.agents, set(args.agent))
    if not selected:
        raise SystemExit("no matching agents selected")

    fallback_base_url = args.base_url or os.environ.get("AGENT_BASE_URL", "")
    semaphore = asyncio.Semaphore(args.concurrency)

    async def run_one(index: int, agent: AgentDefinition) -> AgentCheckResult:
        # Without an explicit --chat-id each agent gets its own id: base + index.
        chat_id = str(
            args.chat_id if args.chat_id is not None else args.chat_id_base + index
        )
        async with semaphore:
            return await check_agent(
                agent,
                fallback_base_url=fallback_base_url,
                chat_id=chat_id,
                timeout=args.timeout,
                message=args.message,
            )

    return await asyncio.gather(
        *(run_one(index, agent) for index, agent in enumerate(selected))
    )


def print_table(results: list[AgentCheckResult]) -> None:
    """Print one human-readable line per result to stdout."""
    for item in results:
        status = "OK" if item.ok else "FAIL"
        # On failure the error text is the useful part; showing only the
        # response type (e.g. "ERROR") would hide what actually went wrong.
        if item.ok:
            detail = item.response_type or item.error
        else:
            detail = item.error or item.response_type
        print(
            f"{status:4} {item.agent_id:20} {item.stage:8} "
            f"{item.latency_ms:5}ms chat={item.chat_id} url={item.ws_url} {detail}"
        )


def _positive_int(value: str) -> int:
    """Argparse ``type=`` converter accepting only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be at least 1, got {number}")
    return number


def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the smoke checker."""
    parser = argparse.ArgumentParser(
        description="Smoke-check Matrix agent WebSocket endpoints from matrix-agents.yaml."
    )
    parser.add_argument("--config", type=Path, default=Path("config/matrix-agents.yaml"))
    parser.add_argument("--agent", action="append", default=[], help="Agent id to check")
    parser.add_argument("--base-url", default="", help="Fallback base URL when an agent has none")
    parser.add_argument("--timeout", type=float, default=10.0)
    parser.add_argument("--concurrency", type=_positive_int, default=5)
    parser.add_argument("--chat-id", type=int, default=None, help="Use one explicit chat id")
    parser.add_argument("--chat-id-base", type=int, default=900000)
    parser.add_argument("--message", default=None, help="Optional test message after STATUS")
    parser.add_argument("--json", action="store_true", help="Print machine-readable JSON")
    return parser.parse_args()


def main() -> int:
    """Run all checks, print the report and return the process exit code.

    Returns 0 when every check succeeded and 1 otherwise.
    """
    args = parse_args()
    results = asyncio.run(run_checks(args))
    if args.json:
        print(json.dumps([asdict(result) for result in results], ensure_ascii=False, indent=2))
    else:
        print_table(results)
    return 0 if all(result.ok for result in results) else 1


if __name__ == "__main__":
    raise SystemExit(main())