fix: resolve named custom delegation providers
This commit is contained in:
parent
6d8286f396
commit
4422637e7a
2 changed files with 168 additions and 1 deletions
|
|
@ -18,6 +18,10 @@ from hermes_cli.config import load_config
|
||||||
from hermes_constants import OPENROUTER_BASE_URL
|
from hermes_constants import OPENROUTER_BASE_URL
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_custom_provider_name(value: str) -> str:
|
||||||
|
return value.strip().lower().replace(" ", "-")
|
||||||
|
|
||||||
|
|
||||||
def _get_model_config() -> Dict[str, Any]:
|
def _get_model_config() -> Dict[str, Any]:
|
||||||
config = load_config()
|
config = load_config()
|
||||||
model_cfg = config.get("model")
|
model_cfg = config.get("model")
|
||||||
|
|
@ -47,6 +51,69 @@ def resolve_requested_provider(requested: Optional[str] = None) -> str:
|
||||||
return "auto"
|
return "auto"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, Any]]:
|
||||||
|
requested_norm = _normalize_custom_provider_name(requested_provider or "")
|
||||||
|
if not requested_norm or requested_norm == "custom":
|
||||||
|
return None
|
||||||
|
|
||||||
|
config = load_config()
|
||||||
|
custom_providers = config.get("custom_providers")
|
||||||
|
if not isinstance(custom_providers, list):
|
||||||
|
return None
|
||||||
|
|
||||||
|
for entry in custom_providers:
|
||||||
|
if not isinstance(entry, dict):
|
||||||
|
continue
|
||||||
|
name = entry.get("name")
|
||||||
|
base_url = entry.get("base_url")
|
||||||
|
if not isinstance(name, str) or not isinstance(base_url, str):
|
||||||
|
continue
|
||||||
|
name_norm = _normalize_custom_provider_name(name)
|
||||||
|
menu_key = f"custom:{name_norm}"
|
||||||
|
if requested_norm not in {name_norm, menu_key}:
|
||||||
|
continue
|
||||||
|
return {
|
||||||
|
"name": name.strip(),
|
||||||
|
"base_url": base_url.strip(),
|
||||||
|
"api_key": str(entry.get("api_key", "") or "").strip(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_named_custom_runtime(
|
||||||
|
*,
|
||||||
|
requested_provider: str,
|
||||||
|
explicit_api_key: Optional[str] = None,
|
||||||
|
explicit_base_url: Optional[str] = None,
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
custom_provider = _get_named_custom_provider(requested_provider)
|
||||||
|
if not custom_provider:
|
||||||
|
return None
|
||||||
|
|
||||||
|
base_url = (
|
||||||
|
(explicit_base_url or "").strip()
|
||||||
|
or custom_provider.get("base_url", "")
|
||||||
|
).rstrip("/")
|
||||||
|
if not base_url:
|
||||||
|
return None
|
||||||
|
|
||||||
|
api_key = (
|
||||||
|
(explicit_api_key or "").strip()
|
||||||
|
or custom_provider.get("api_key", "")
|
||||||
|
or os.getenv("OPENAI_API_KEY", "").strip()
|
||||||
|
or os.getenv("OPENROUTER_API_KEY", "").strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"provider": "openrouter",
|
||||||
|
"api_mode": "chat_completions",
|
||||||
|
"base_url": base_url,
|
||||||
|
"api_key": api_key,
|
||||||
|
"source": f"custom_provider:{custom_provider.get('name', requested_provider)}",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _resolve_openrouter_runtime(
|
def _resolve_openrouter_runtime(
|
||||||
*,
|
*,
|
||||||
requested_provider: str,
|
requested_provider: str,
|
||||||
|
|
@ -122,6 +189,15 @@ def resolve_runtime_provider(
|
||||||
"""Resolve runtime provider credentials for agent execution."""
|
"""Resolve runtime provider credentials for agent execution."""
|
||||||
requested_provider = resolve_requested_provider(requested)
|
requested_provider = resolve_requested_provider(requested)
|
||||||
|
|
||||||
|
custom_runtime = _resolve_named_custom_runtime(
|
||||||
|
requested_provider=requested_provider,
|
||||||
|
explicit_api_key=explicit_api_key,
|
||||||
|
explicit_base_url=explicit_base_url,
|
||||||
|
)
|
||||||
|
if custom_runtime:
|
||||||
|
custom_runtime["requested_provider"] = requested_provider
|
||||||
|
return custom_runtime
|
||||||
|
|
||||||
provider = resolve_provider(
|
provider = resolve_provider(
|
||||||
requested_provider,
|
requested_provider,
|
||||||
explicit_api_key=explicit_api_key,
|
explicit_api_key=explicit_api_key,
|
||||||
|
|
|
||||||
|
|
@ -150,7 +150,7 @@ def test_custom_endpoint_auto_provider_prefers_openai_key(monkeypatch):
|
||||||
monkeypatch.setenv("OPENAI_BASE_URL", "https://my-vllm-server.example.com/v1")
|
monkeypatch.setenv("OPENAI_BASE_URL", "https://my-vllm-server.example.com/v1")
|
||||||
monkeypatch.delenv("OPENROUTER_BASE_URL", raising=False)
|
monkeypatch.delenv("OPENROUTER_BASE_URL", raising=False)
|
||||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-vllm-key")
|
monkeypatch.setenv("OPENAI_API_KEY", "sk-vllm-key")
|
||||||
monkeypatch.setenv("OPENROUTER_API_KEY", "sk-or-should-not-leak")
|
monkeypatch.setenv("OPENROUTER_API_KEY", "sk-or-...leak")
|
||||||
|
|
||||||
resolved = rp.resolve_runtime_provider(requested="auto")
|
resolved = rp.resolve_runtime_provider(requested="auto")
|
||||||
|
|
||||||
|
|
@ -158,6 +158,97 @@ def test_custom_endpoint_auto_provider_prefers_openai_key(monkeypatch):
|
||||||
assert resolved["api_key"] == "sk-vllm-key"
|
assert resolved["api_key"] == "sk-vllm-key"
|
||||||
|
|
||||||
|
|
||||||
|
def test_named_custom_provider_uses_saved_credentials(monkeypatch):
|
||||||
|
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||||
|
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
rp,
|
||||||
|
"load_config",
|
||||||
|
lambda: {
|
||||||
|
"custom_providers": [
|
||||||
|
{
|
||||||
|
"name": "Local",
|
||||||
|
"base_url": "http://1.2.3.4:1234/v1",
|
||||||
|
"api_key": "local-provider-key",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
rp,
|
||||||
|
"resolve_provider",
|
||||||
|
lambda *a, **k: (_ for _ in ()).throw(
|
||||||
|
AssertionError(
|
||||||
|
"resolve_provider should not be called for named custom providers"
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved = rp.resolve_runtime_provider(requested="local")
|
||||||
|
|
||||||
|
assert resolved["provider"] == "openrouter"
|
||||||
|
assert resolved["api_mode"] == "chat_completions"
|
||||||
|
assert resolved["base_url"] == "http://1.2.3.4:1234/v1"
|
||||||
|
assert resolved["api_key"] == "local-provider-key"
|
||||||
|
assert resolved["requested_provider"] == "local"
|
||||||
|
assert resolved["source"] == "custom_provider:Local"
|
||||||
|
|
||||||
|
|
||||||
|
def test_named_custom_provider_falls_back_to_openai_api_key(monkeypatch):
|
||||||
|
monkeypatch.setenv("OPENAI_API_KEY", "env-openai-key")
|
||||||
|
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
rp,
|
||||||
|
"load_config",
|
||||||
|
lambda: {
|
||||||
|
"custom_providers": [
|
||||||
|
{
|
||||||
|
"name": "Local LLM",
|
||||||
|
"base_url": "http://localhost:1234/v1",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
rp,
|
||||||
|
"resolve_provider",
|
||||||
|
lambda *a, **k: (_ for _ in ()).throw(
|
||||||
|
AssertionError(
|
||||||
|
"resolve_provider should not be called for named custom providers"
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved = rp.resolve_runtime_provider(requested="custom:local-llm")
|
||||||
|
|
||||||
|
assert resolved["base_url"] == "http://localhost:1234/v1"
|
||||||
|
assert resolved["api_key"] == "env-openai-key"
|
||||||
|
assert resolved["requested_provider"] == "custom:local-llm"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_runtime_provider_nous_api(monkeypatch):
|
||||||
|
"""Nous Portal API key provider resolves via the api_key path."""
|
||||||
|
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "nous-api")
|
||||||
|
monkeypatch.setattr(
|
||||||
|
rp,
|
||||||
|
"resolve_api_key_provider_credentials",
|
||||||
|
lambda pid: {
|
||||||
|
"provider": "nous-api",
|
||||||
|
"api_key": "nous-test-key",
|
||||||
|
"base_url": "https://inference-api.nousresearch.com/v1",
|
||||||
|
"source": "NOUS_API_KEY",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved = rp.resolve_runtime_provider(requested="nous-api")
|
||||||
|
|
||||||
|
assert resolved["provider"] == "nous-api"
|
||||||
|
assert resolved["api_mode"] == "chat_completions"
|
||||||
|
assert resolved["base_url"] == "https://inference-api.nousresearch.com/v1"
|
||||||
|
assert resolved["api_key"] == "nous-test-key"
|
||||||
|
assert resolved["requested_provider"] == "nous-api"
|
||||||
|
|
||||||
|
|
||||||
def test_explicit_openrouter_skips_openai_base_url(monkeypatch):
|
def test_explicit_openrouter_skips_openai_base_url(monkeypatch):
|
||||||
"""When the user explicitly requests openrouter, OPENAI_BASE_URL
|
"""When the user explicitly requests openrouter, OPENAI_BASE_URL
|
||||||
(which may point to a custom endpoint) must not override the
|
(which may point to a custom endpoint) must not override the
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue