Merge pull request #881 from NousResearch/hermes/hermes-b0162f8d
fix: provider selection not persisting when switching via hermes model
This commit is contained in:
commit
ac53bf1d71
5 changed files with 151 additions and 18 deletions
|
|
@ -1671,11 +1671,11 @@ def _save_model_choice(model_id: str) -> None:
|
||||||
from hermes_cli.config import save_config, load_config, save_env_value
|
from hermes_cli.config import save_config, load_config, save_env_value
|
||||||
|
|
||||||
config = load_config()
|
config = load_config()
|
||||||
# Handle both string and dict model formats
|
# Always use dict format so provider/base_url can be stored alongside
|
||||||
if isinstance(config.get("model"), dict):
|
if isinstance(config.get("model"), dict):
|
||||||
config["model"]["default"] = model_id
|
config["model"]["default"] = model_id
|
||||||
else:
|
else:
|
||||||
config["model"] = model_id
|
config["model"] = {"default": model_id}
|
||||||
save_config(config)
|
save_config(config)
|
||||||
save_env_value("LLM_MODEL", model_id)
|
save_env_value("LLM_MODEL", model_id)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -906,9 +906,11 @@ def _model_flow_openrouter(config, current_model=""):
|
||||||
from hermes_cli.config import load_config, save_config
|
from hermes_cli.config import load_config, save_config
|
||||||
cfg = load_config()
|
cfg = load_config()
|
||||||
model = cfg.get("model")
|
model = cfg.get("model")
|
||||||
if isinstance(model, dict):
|
if not isinstance(model, dict):
|
||||||
model["provider"] = "openrouter"
|
model = {"default": model} if model else {}
|
||||||
model["base_url"] = OPENROUTER_BASE_URL
|
cfg["model"] = model
|
||||||
|
model["provider"] = "openrouter"
|
||||||
|
model["base_url"] = OPENROUTER_BASE_URL
|
||||||
save_config(cfg)
|
save_config(cfg)
|
||||||
deactivate_provider()
|
deactivate_provider()
|
||||||
print(f"Default model set to: {selected} (via OpenRouter)")
|
print(f"Default model set to: {selected} (via OpenRouter)")
|
||||||
|
|
@ -1090,9 +1092,11 @@ def _model_flow_custom(config):
|
||||||
# Update config and deactivate any OAuth provider
|
# Update config and deactivate any OAuth provider
|
||||||
cfg = load_config()
|
cfg = load_config()
|
||||||
model = cfg.get("model")
|
model = cfg.get("model")
|
||||||
if isinstance(model, dict):
|
if not isinstance(model, dict):
|
||||||
model["provider"] = "custom"
|
model = {"default": model} if model else {}
|
||||||
model["base_url"] = effective_url
|
cfg["model"] = model
|
||||||
|
model["provider"] = "custom"
|
||||||
|
model["base_url"] = effective_url
|
||||||
save_config(cfg)
|
save_config(cfg)
|
||||||
deactivate_provider()
|
deactivate_provider()
|
||||||
|
|
||||||
|
|
@ -1235,9 +1239,11 @@ def _model_flow_named_custom(config, provider_info):
|
||||||
|
|
||||||
cfg = load_config()
|
cfg = load_config()
|
||||||
model = cfg.get("model")
|
model = cfg.get("model")
|
||||||
if isinstance(model, dict):
|
if not isinstance(model, dict):
|
||||||
model["provider"] = "custom"
|
model = {"default": model} if model else {}
|
||||||
model["base_url"] = base_url
|
cfg["model"] = model
|
||||||
|
model["provider"] = "custom"
|
||||||
|
model["base_url"] = base_url
|
||||||
save_config(cfg)
|
save_config(cfg)
|
||||||
deactivate_provider()
|
deactivate_provider()
|
||||||
|
|
||||||
|
|
@ -1307,9 +1313,11 @@ def _model_flow_named_custom(config, provider_info):
|
||||||
|
|
||||||
cfg = load_config()
|
cfg = load_config()
|
||||||
model = cfg.get("model")
|
model = cfg.get("model")
|
||||||
if isinstance(model, dict):
|
if not isinstance(model, dict):
|
||||||
model["provider"] = "custom"
|
model = {"default": model} if model else {}
|
||||||
model["base_url"] = base_url
|
cfg["model"] = model
|
||||||
|
model["provider"] = "custom"
|
||||||
|
model["base_url"] = base_url
|
||||||
save_config(cfg)
|
save_config(cfg)
|
||||||
deactivate_provider()
|
deactivate_provider()
|
||||||
|
|
||||||
|
|
@ -1420,9 +1428,11 @@ def _model_flow_api_key_provider(config, provider_id, current_model=""):
|
||||||
# Update config with provider and base URL
|
# Update config with provider and base URL
|
||||||
cfg = load_config()
|
cfg = load_config()
|
||||||
model = cfg.get("model")
|
model = cfg.get("model")
|
||||||
if isinstance(model, dict):
|
if not isinstance(model, dict):
|
||||||
model["provider"] = provider_id
|
model = {"default": model} if model else {}
|
||||||
model["base_url"] = effective_base
|
cfg["model"] = model
|
||||||
|
model["provider"] = provider_id
|
||||||
|
model["base_url"] = effective_base
|
||||||
save_config(cfg)
|
save_config(cfg)
|
||||||
deactivate_provider()
|
deactivate_provider()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -66,9 +66,14 @@ def _resolve_openrouter_runtime(
|
||||||
if not cfg_provider or cfg_provider == "auto":
|
if not cfg_provider or cfg_provider == "auto":
|
||||||
use_config_base_url = True
|
use_config_base_url = True
|
||||||
|
|
||||||
|
# When the user explicitly requested the openrouter provider, skip
|
||||||
|
# OPENAI_BASE_URL — it typically points to a custom / non-OpenRouter
|
||||||
|
# endpoint and would prevent switching back to OpenRouter (#874).
|
||||||
|
skip_openai_base = requested_norm == "openrouter"
|
||||||
|
|
||||||
base_url = (
|
base_url = (
|
||||||
(explicit_base_url or "").strip()
|
(explicit_base_url or "").strip()
|
||||||
or env_openai_base_url
|
or ("" if skip_openai_base else env_openai_base_url)
|
||||||
or (cfg_base_url.strip() if use_config_base_url else "")
|
or (cfg_base_url.strip() if use_config_base_url else "")
|
||||||
or env_openrouter_base_url
|
or env_openrouter_base_url
|
||||||
or OPENROUTER_BASE_URL
|
or OPENROUTER_BASE_URL
|
||||||
|
|
|
||||||
99
tests/test_model_provider_persistence.py
Normal file
99
tests/test_model_provider_persistence.py
Normal file
|
|
@ -0,0 +1,99 @@
|
||||||
|
"""Tests that provider selection via `hermes model` always persists correctly.
|
||||||
|
|
||||||
|
Regression tests for the bug where _save_model_choice could save config.model
|
||||||
|
as a plain string, causing subsequent provider writes (which check
|
||||||
|
isinstance(model, dict)) to silently fail — leaving the provider unset and
|
||||||
|
falling back to auto-detection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config_home(tmp_path, monkeypatch):
|
||||||
|
"""Isolated HERMES_HOME with a minimal string-format config."""
|
||||||
|
home = tmp_path / "hermes"
|
||||||
|
home.mkdir()
|
||||||
|
config_yaml = home / "config.yaml"
|
||||||
|
# Start with model as a plain string — the format that triggered the bug
|
||||||
|
config_yaml.write_text("model: some-old-model\n")
|
||||||
|
env_file = home / ".env"
|
||||||
|
env_file.write_text("")
|
||||||
|
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||||
|
# Clear env vars that could interfere
|
||||||
|
monkeypatch.delenv("HERMES_MODEL", raising=False)
|
||||||
|
monkeypatch.delenv("LLM_MODEL", raising=False)
|
||||||
|
monkeypatch.delenv("HERMES_INFERENCE_PROVIDER", raising=False)
|
||||||
|
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||||
|
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||||
|
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||||
|
return home
|
||||||
|
|
||||||
|
|
||||||
|
class TestSaveModelChoiceAlwaysDict:
|
||||||
|
def test_string_model_becomes_dict(self, config_home):
|
||||||
|
"""When config.model is a plain string, _save_model_choice must
|
||||||
|
convert it to a dict so provider can be set afterwards."""
|
||||||
|
from hermes_cli.auth import _save_model_choice
|
||||||
|
|
||||||
|
_save_model_choice("kimi-k2.5")
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
config = yaml.safe_load((config_home / "config.yaml").read_text()) or {}
|
||||||
|
model = config.get("model")
|
||||||
|
assert isinstance(model, dict), (
|
||||||
|
f"Expected model to be a dict after save, got {type(model)}: {model}"
|
||||||
|
)
|
||||||
|
assert model["default"] == "kimi-k2.5"
|
||||||
|
|
||||||
|
def test_dict_model_stays_dict(self, config_home):
|
||||||
|
"""When config.model is already a dict, _save_model_choice preserves it."""
|
||||||
|
import yaml
|
||||||
|
(config_home / "config.yaml").write_text(
|
||||||
|
"model:\n default: old-model\n provider: openrouter\n"
|
||||||
|
)
|
||||||
|
from hermes_cli.auth import _save_model_choice
|
||||||
|
|
||||||
|
_save_model_choice("new-model")
|
||||||
|
|
||||||
|
config = yaml.safe_load((config_home / "config.yaml").read_text()) or {}
|
||||||
|
model = config.get("model")
|
||||||
|
assert isinstance(model, dict)
|
||||||
|
assert model["default"] == "new-model"
|
||||||
|
assert model["provider"] == "openrouter" # preserved
|
||||||
|
|
||||||
|
|
||||||
|
class TestProviderPersistsAfterModelSave:
|
||||||
|
def test_api_key_provider_saved_when_model_was_string(self, config_home, monkeypatch):
|
||||||
|
"""_model_flow_api_key_provider must persist the provider even when
|
||||||
|
config.model started as a plain string."""
|
||||||
|
from hermes_cli.auth import PROVIDER_REGISTRY
|
||||||
|
|
||||||
|
pconfig = PROVIDER_REGISTRY.get("kimi-coding")
|
||||||
|
if not pconfig:
|
||||||
|
pytest.skip("kimi-coding not in PROVIDER_REGISTRY")
|
||||||
|
|
||||||
|
# Simulate: user has a Kimi API key, model was a string
|
||||||
|
monkeypatch.setenv("KIMI_API_KEY", "sk-kimi-test-key")
|
||||||
|
|
||||||
|
from hermes_cli.main import _model_flow_api_key_provider
|
||||||
|
from hermes_cli.config import load_config
|
||||||
|
|
||||||
|
# Mock the model selection prompt to return "kimi-k2.5"
|
||||||
|
# Also mock input() for the base URL prompt and builtins.input
|
||||||
|
with patch("hermes_cli.auth._prompt_model_selection", return_value="kimi-k2.5"), \
|
||||||
|
patch("hermes_cli.auth.deactivate_provider"), \
|
||||||
|
patch("builtins.input", return_value=""):
|
||||||
|
_model_flow_api_key_provider(load_config(), "kimi-coding", "old-model")
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
config = yaml.safe_load((config_home / "config.yaml").read_text()) or {}
|
||||||
|
model = config.get("model")
|
||||||
|
assert isinstance(model, dict), f"model should be dict, got {type(model)}"
|
||||||
|
assert model.get("provider") == "kimi-coding", (
|
||||||
|
f"provider should be 'kimi-coding', got {model.get('provider')}"
|
||||||
|
)
|
||||||
|
assert model.get("default") == "kimi-k2.5"
|
||||||
|
|
@ -181,6 +181,25 @@ def test_resolve_runtime_provider_nous_api(monkeypatch):
|
||||||
assert resolved["requested_provider"] == "nous-api"
|
assert resolved["requested_provider"] == "nous-api"
|
||||||
|
|
||||||
|
|
||||||
|
def test_explicit_openrouter_skips_openai_base_url(monkeypatch):
|
||||||
|
"""When the user explicitly requests openrouter, OPENAI_BASE_URL
|
||||||
|
(which may point to a custom endpoint) must not override the
|
||||||
|
OpenRouter base URL. Regression test for #874."""
|
||||||
|
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openrouter")
|
||||||
|
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
|
||||||
|
monkeypatch.setenv("OPENAI_BASE_URL", "https://my-custom-llm.example.com/v1")
|
||||||
|
monkeypatch.setenv("OPENROUTER_API_KEY", "or-test-key")
|
||||||
|
monkeypatch.delenv("OPENROUTER_BASE_URL", raising=False)
|
||||||
|
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||||
|
|
||||||
|
resolved = rp.resolve_runtime_provider(requested="openrouter")
|
||||||
|
|
||||||
|
assert resolved["provider"] == "openrouter"
|
||||||
|
assert "openrouter.ai" in resolved["base_url"]
|
||||||
|
assert "my-custom-llm" not in resolved["base_url"]
|
||||||
|
assert resolved["api_key"] == "or-test-key"
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_requested_provider_precedence(monkeypatch):
|
def test_resolve_requested_provider_precedence(monkeypatch):
|
||||||
monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "nous")
|
monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "nous")
|
||||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {"provider": "openai-codex"})
|
monkeypatch.setattr(rp, "_get_model_config", lambda: {"provider": "openai-codex"})
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue