refactor(model): extract shared switch_model() from CLI and gateway handlers

Phase 4 of the /model command overhaul.

Both the CLI (cli.py) and gateway (gateway/run.py) /model handlers
had ~50 lines of duplicated core logic: parsing, provider detection,
credential resolution, and model validation. This extracts that
pipeline into hermes_cli/model_switch.py.

New module exports:
- ModelSwitchResult: dataclass with all fields both handlers need
- CustomAutoResult: dataclass for bare '/model custom' results
- switch_model(): core pipeline — parse → detect → resolve → validate
- switch_to_custom_provider(): resolve endpoint + auto-detect model

The shared functions are pure (no I/O side effects). Each caller
handles its own platform-specific concerns:
- CLI: sets self.model/provider/etc, calls save_config_value(), prints
- Gateway: writes config.yaml directly, sets env vars, returns markdown

Net result: -244 lines from handlers, +234 lines in shared module.
The handlers are now ~80 lines each (down from ~150+) and can't drift
apart on core logic.
This commit is contained in:
Teknium 2026-03-24 07:08:07 -07:00 committed by GitHub
parent ce39f9cc44
commit 2e524272b1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 359 additions and 258 deletions

View file

@@ -2854,117 +2854,10 @@ class GatewayRunner:
# Handle bare "/model custom" — switch to custom provider
# and auto-detect the model from the endpoint.
if args.strip().lower() == "custom":
from hermes_cli.runtime_provider import (
resolve_runtime_provider as _rtp_custom,
_auto_detect_local_model,
)
try:
runtime = _rtp_custom(requested="custom")
cust_base = runtime.get("base_url", "")
if not cust_base or "openrouter.ai" in cust_base:
return (
"⚠️ No custom endpoint configured.\n"
"Set `model.base_url` in config.yaml, or `OPENAI_BASE_URL` in .env,\n"
"or run: `hermes setup` → Custom OpenAI-compatible endpoint"
)
detected_model = _auto_detect_local_model(cust_base)
if detected_model:
try:
user_config = {}
if config_path.exists():
with open(config_path, encoding="utf-8") as f:
user_config = yaml.safe_load(f) or {}
if "model" not in user_config or not isinstance(user_config["model"], dict):
user_config["model"] = {}
user_config["model"]["default"] = detected_model
user_config["model"]["provider"] = "custom"
user_config["model"]["base_url"] = cust_base
with open(config_path, 'w', encoding="utf-8") as f:
yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
except Exception as e:
return f"⚠️ Failed to save model change: {e}"
os.environ["HERMES_MODEL"] = detected_model
os.environ["HERMES_INFERENCE_PROVIDER"] = "custom"
self._effective_model = None
self._effective_provider = None
return (
f"🤖 Model changed to `{detected_model}` (saved to config)\n"
f"**Provider:** Custom\n"
f"**Endpoint:** `{cust_base}`\n"
f"_Model auto-detected from endpoint. Takes effect on next message._"
)
else:
return (
f"⚠️ Custom endpoint at `{cust_base}` is reachable but no single model was auto-detected.\n"
f"Specify the model explicitly: `/model custom:<model-name>`"
)
except Exception as e:
return f"⚠️ Could not resolve custom endpoint: {e}"
# Parse provider:model syntax
target_provider, new_model = parse_model_input(args, current_provider)
# Detect custom/local provider — skip auto-detection to prevent
# silently accepting an OpenRouter model name on a localhost endpoint.
# Users must use explicit provider:model syntax to switch away.
_resolved_base = ""
try:
from hermes_cli.runtime_provider import resolve_runtime_provider as _rtp
_resolved_base = _rtp(requested=current_provider).get("base_url", "")
except Exception:
pass
is_custom = current_provider == "custom" or (
"localhost" in _resolved_base or "127.0.0.1" in _resolved_base
)
# Auto-detect provider when no explicit provider:model syntax was used
if target_provider == current_provider and not is_custom:
from hermes_cli.models import detect_provider_for_model
detected = detect_provider_for_model(new_model, current_provider)
if detected:
target_provider, new_model = detected
provider_changed = target_provider != current_provider
# Resolve credentials for the target provider (for API probe)
api_key = os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
base_url = "https://openrouter.ai/api/v1"
if provider_changed:
try:
from hermes_cli.runtime_provider import resolve_runtime_provider
runtime = resolve_runtime_provider(requested=target_provider)
api_key = runtime.get("api_key", "")
base_url = runtime.get("base_url", "")
except Exception as e:
provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
return f"⚠️ Could not resolve credentials for provider '{provider_label}': {e}"
else:
# Use current provider's base_url from config or registry
try:
from hermes_cli.runtime_provider import resolve_runtime_provider
runtime = resolve_runtime_provider(requested=current_provider)
api_key = runtime.get("api_key", "")
base_url = runtime.get("base_url", "")
except Exception:
pass
# Validate the model against the live API
try:
validation = validate_requested_model(
new_model,
target_provider,
api_key=api_key,
base_url=base_url,
)
except Exception:
validation = {"accepted": True, "persist": True, "recognized": False, "message": None}
if not validation.get("accepted"):
msg = validation.get("message", "Invalid model")
tip = "\n\nUse `/model` to see available models, `/provider` to see providers" if "Did you mean" not in msg else ""
return f"⚠️ {msg}{tip}"
# Persist to config only if validation approves
if validation.get("persist"):
from hermes_cli.model_switch import switch_to_custom_provider
cust_result = switch_to_custom_provider()
if not cust_result.success:
return f"⚠️ {cust_result.error_message}"
try:
user_config = {}
if config_path.exists():
@@ -2972,14 +2865,63 @@ class GatewayRunner:
user_config = yaml.safe_load(f) or {}
if "model" not in user_config or not isinstance(user_config["model"], dict):
user_config["model"] = {}
user_config["model"]["default"] = new_model
if provider_changed:
user_config["model"]["provider"] = target_provider
# Persist base_url for custom endpoints so it survives
# restart; clear it when switching away from custom to
# prevent stale URLs leaking (#2562 Phase 2).
if base_url and "openrouter.ai" not in (base_url or ""):
user_config["model"]["base_url"] = base_url
user_config["model"]["default"] = cust_result.model
user_config["model"]["provider"] = "custom"
user_config["model"]["base_url"] = cust_result.base_url
with open(config_path, 'w', encoding="utf-8") as f:
yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
except Exception as e:
return f"⚠️ Failed to save model change: {e}"
os.environ["HERMES_MODEL"] = cust_result.model
os.environ["HERMES_INFERENCE_PROVIDER"] = "custom"
self._effective_model = None
self._effective_provider = None
return (
f"🤖 Model changed to `{cust_result.model}` (saved to config)\n"
f"**Provider:** Custom\n"
f"**Endpoint:** `{cust_result.base_url}`\n"
f"_Model auto-detected from endpoint. Takes effect on next message._"
)
# Core model-switching pipeline (shared with CLI)
from hermes_cli.model_switch import switch_model
# Resolve current base_url for is_custom detection
_resolved_base = ""
try:
from hermes_cli.runtime_provider import resolve_runtime_provider as _rtp
_resolved_base = _rtp(requested=current_provider).get("base_url", "")
except Exception:
pass
result = switch_model(
args,
current_provider,
current_base_url=_resolved_base,
current_api_key=os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY") or "",
)
if not result.success:
msg = result.error_message
tip = "\n\nUse `/model` to see available models, `/provider` to see providers" if "Did you mean" not in msg else ""
return f"⚠️ {msg}{tip}"
# Persist to config only if validation approves
if result.persist:
try:
user_config = {}
if config_path.exists():
with open(config_path, encoding="utf-8") as f:
user_config = yaml.safe_load(f) or {}
if "model" not in user_config or not isinstance(user_config["model"], dict):
user_config["model"] = {}
user_config["model"]["default"] = result.new_model
if result.provider_changed:
user_config["model"]["provider"] = result.target_provider
# Persist base_url for custom endpoints; clear when
# switching away from custom (#2562 Phase 2).
if result.base_url and "openrouter.ai" not in (result.base_url or ""):
user_config["model"]["base_url"] = result.base_url
else:
user_config["model"].pop("base_url", None)
with open(config_path, 'w', encoding="utf-8") as f:
@@ -2988,41 +2930,34 @@ class GatewayRunner:
return f"⚠️ Failed to save model change: {e}"
# Set env vars so the next agent run picks up the change
os.environ["HERMES_MODEL"] = new_model
if provider_changed:
os.environ["HERMES_INFERENCE_PROVIDER"] = target_provider
os.environ["HERMES_MODEL"] = result.new_model
if result.provider_changed:
os.environ["HERMES_INFERENCE_PROVIDER"] = result.target_provider
provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
provider_note = f"\n**Provider:** {provider_label}" if provider_changed else ""
provider_note = f"\n**Provider:** {result.provider_label}" if result.provider_changed else ""
warning = ""
if validation.get("message"):
warning = f"\n⚠️ {validation['message']}"
if result.warning_message:
warning = f"\n⚠️ {result.warning_message}"
persist_note = "saved to config" if result.persist else "this session only — will revert on restart"
if validation.get("persist"):
persist_note = "saved to config"
else:
persist_note = "this session only — will revert on restart"
# Clear fallback state since user explicitly chose a model
self._effective_model = None
self._effective_provider = None
# Show endpoint info for custom providers
_target_is_custom = target_provider == "custom" or (
base_url and "openrouter.ai" not in (base_url or "")
and ("localhost" in (base_url or "") or "127.0.0.1" in (base_url or ""))
)
custom_hint = ""
if _target_is_custom or (is_custom and not provider_changed):
endpoint = base_url or _resolved_base or "custom endpoint"
if result.is_custom_target:
endpoint = result.base_url or _resolved_base or "custom endpoint"
custom_hint = f"\n**Endpoint:** `{endpoint}`"
if not provider_changed:
if not result.provider_changed:
custom_hint += (
"\n_To switch providers, use_ `/model provider:model`"
"\n_e.g._ `/model openrouter:anthropic/claude-sonnet-4`"
)
return f"🤖 Model changed to `{new_model}` ({persist_note}){provider_note}{warning}{custom_hint}\n_(takes effect on next message)_"
return f"🤖 Model changed to `{result.new_model}` ({persist_note}){provider_note}{warning}{custom_hint}\n_(takes effect on next message)_"
async def _handle_provider_command(self, event: MessageEvent) -> str:
"""Handle /provider command - show available providers."""