diff --git a/hermes_cli/models.py b/hermes_cli/models.py index 528273f9..bd346376 100644 --- a/hermes_cli/models.py +++ b/hermes_cli/models.py @@ -162,7 +162,7 @@ def list_available_providers() -> list[dict[str, str]]: _PROVIDER_ORDER = [ "openrouter", "nous", "openai-codex", "zai", "kimi-coding", "minimax", "minimax-cn", "anthropic", - "ai-gateway", "deepseek", + "ai-gateway", "deepseek", "custom", ] # Build reverse alias map aliases_for: dict[str, list[str]] = {} @@ -176,9 +176,12 @@ def list_available_providers() -> list[dict[str, str]]: # Check if this provider has credentials available has_creds = False try: - from hermes_cli.runtime_provider import resolve_runtime_provider - runtime = resolve_runtime_provider(requested=pid) - has_creds = bool(runtime.get("api_key")) + if pid == "custom": + has_creds = bool(_get_custom_base_url()) + else: + from hermes_cli.runtime_provider import resolve_runtime_provider + runtime = resolve_runtime_provider(requested=pid) + has_creds = bool(runtime.get("api_key")) except Exception: pass result.append({ @@ -217,6 +220,19 @@ def parse_model_input(raw: str, current_provider: str) -> tuple[str, str]: return (current_provider, stripped) +def _get_custom_base_url() -> str: + """Get the custom endpoint base_url from config.yaml.""" + try: + from hermes_cli.config import load_config + config = load_config() + model_cfg = config.get("model", {}) + if isinstance(model_cfg, dict): + return str(model_cfg.get("base_url", "")).strip() + except Exception: + pass + return "" + + def curated_models_for_provider(provider: Optional[str]) -> list[tuple[str, str]]: """Return ``(model_id, description)`` tuples for a provider's model list. @@ -396,6 +412,18 @@ def provider_model_ids(provider: Optional[str]) -> list[str]: live = _fetch_ai_gateway_models() if live: return live + if normalized == "custom": + base_url = _get_custom_base_url() + if base_url: + # Try common API key env vars for custom endpoints + api_key = ( + os.getenv("CUSTOM_API_KEY", "") + or os.getenv("OPENAI_API_KEY", "") + or os.getenv("OPENROUTER_API_KEY", "") + ) + live = fetch_api_models(api_key, base_url) + if live: + return live return list(_PROVIDER_MODELS.get(normalized, []))