refactor(model): extract shared switch_model() from CLI and gateway handlers
Phase 4 of the /model command overhaul. Both the CLI (cli.py) and gateway (gateway/run.py) /model handlers had ~50 lines of duplicated core logic: parsing, provider detection, credential resolution, and model validation. This extracts that pipeline into hermes_cli/model_switch.py.

New module exports:
- ModelSwitchResult: dataclass with all fields both handlers need
- CustomAutoResult: dataclass for bare '/model custom' results
- switch_model(): core pipeline — parse → detect → resolve → validate
- switch_to_custom_provider(): resolve endpoint + auto-detect model

The shared functions are pure (no I/O side effects). Each caller handles its own platform-specific concerns:
- CLI: sets self.model/provider/etc, calls save_config_value(), prints
- Gateway: writes config.yaml directly, sets env vars, returns markdown

Net result: -244 lines from handlers, +234 lines in shared module. The handlers are now ~80 lines each (down from ~150+) and can't drift apart on core logic.
This commit is contained in:
parent
ce39f9cc44
commit
2e524272b1
3 changed files with 359 additions and 258 deletions
172
cli.py
172
cli.py
|
|
@ -3562,151 +3562,83 @@ class HermesCLI:
|
||||||
# Use original case so model names like "Anthropic/Claude-Opus-4" are preserved
|
# Use original case so model names like "Anthropic/Claude-Opus-4" are preserved
|
||||||
parts = cmd_original.split(maxsplit=1)
|
parts = cmd_original.split(maxsplit=1)
|
||||||
if len(parts) > 1:
|
if len(parts) > 1:
|
||||||
from hermes_cli.auth import resolve_provider
|
from hermes_cli.model_switch import switch_model, switch_to_custom_provider
|
||||||
from hermes_cli.models import (
|
|
||||||
parse_model_input,
|
|
||||||
validate_requested_model,
|
|
||||||
_PROVIDER_LABELS,
|
|
||||||
)
|
|
||||||
|
|
||||||
raw_input = parts[1].strip()
|
raw_input = parts[1].strip()
|
||||||
|
|
||||||
# Handle bare "/model custom" — switch to custom provider
|
# Handle bare "/model custom" — switch to custom provider
|
||||||
# and auto-detect the model from the endpoint.
|
# and auto-detect the model from the endpoint.
|
||||||
if raw_input.strip().lower() == "custom":
|
if raw_input.strip().lower() == "custom":
|
||||||
from hermes_cli.runtime_provider import (
|
result = switch_to_custom_provider()
|
||||||
resolve_runtime_provider,
|
if result.success:
|
||||||
_auto_detect_local_model,
|
self.model = result.model
|
||||||
)
|
self.requested_provider = "custom"
|
||||||
try:
|
self.provider = "custom"
|
||||||
runtime = resolve_runtime_provider(requested="custom")
|
self.api_key = result.api_key
|
||||||
cust_base = runtime.get("base_url", "")
|
self.base_url = result.base_url
|
||||||
cust_key = runtime.get("api_key", "")
|
self.agent = None
|
||||||
if not cust_base or "openrouter.ai" in cust_base:
|
save_config_value("model.default", result.model)
|
||||||
print("(>_<) No custom endpoint configured.")
|
save_config_value("model.provider", "custom")
|
||||||
print(" Set model.base_url in config.yaml, or set OPENAI_BASE_URL in .env,")
|
save_config_value("model.base_url", result.base_url)
|
||||||
print(" or run: hermes setup → Custom OpenAI-compatible endpoint")
|
print(f"(^_^)b Model changed to: {result.model} [provider: Custom]")
|
||||||
return True
|
print(f" Endpoint: {result.base_url}")
|
||||||
detected_model = _auto_detect_local_model(cust_base)
|
print(f" Status: connected (model auto-detected)")
|
||||||
if detected_model:
|
else:
|
||||||
self.model = detected_model
|
print(f"(>_<) {result.error_message}")
|
||||||
self.requested_provider = "custom"
|
|
||||||
self.provider = "custom"
|
|
||||||
self.api_key = cust_key
|
|
||||||
self.base_url = cust_base
|
|
||||||
self.agent = None
|
|
||||||
save_config_value("model.default", detected_model)
|
|
||||||
save_config_value("model.provider", "custom")
|
|
||||||
save_config_value("model.base_url", cust_base)
|
|
||||||
print(f"(^_^)b Model changed to: {detected_model} [provider: Custom]")
|
|
||||||
print(f" Endpoint: {cust_base}")
|
|
||||||
print(f" Status: connected (model auto-detected)")
|
|
||||||
else:
|
|
||||||
print(f"(>_<) Custom endpoint at {cust_base} is reachable but no single model was auto-detected.")
|
|
||||||
print(f" Specify the model explicitly: /model custom:<model-name>")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"(>_<) Could not resolve custom endpoint: {e}")
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Parse provider:model syntax (e.g. "openrouter:anthropic/claude-sonnet-4.5")
|
# Core model-switching pipeline (shared with gateway)
|
||||||
current_provider = self.provider or self.requested_provider or "openrouter"
|
current_provider = self.provider or self.requested_provider or "openrouter"
|
||||||
target_provider, new_model = parse_model_input(raw_input, current_provider)
|
result = switch_model(
|
||||||
# Auto-detect provider when no explicit provider:model syntax was used.
|
raw_input,
|
||||||
# Skip auto-detection for custom providers — the model name might
|
current_provider,
|
||||||
# coincidentally match a known provider's catalog, but the user
|
current_base_url=self.base_url or "",
|
||||||
# intends to use it on their custom endpoint. Require explicit
|
current_api_key=self.api_key or "",
|
||||||
# provider:model syntax (e.g. /model openai-codex:gpt-5.2-codex)
|
|
||||||
# to switch away from a custom endpoint.
|
|
||||||
_base = self.base_url or ""
|
|
||||||
is_custom = current_provider == "custom" or (
|
|
||||||
"localhost" in _base or "127.0.0.1" in _base
|
|
||||||
)
|
)
|
||||||
if target_provider == current_provider and not is_custom:
|
|
||||||
from hermes_cli.models import detect_provider_for_model
|
|
||||||
detected = detect_provider_for_model(new_model, current_provider)
|
|
||||||
if detected:
|
|
||||||
target_provider, new_model = detected
|
|
||||||
provider_changed = target_provider != current_provider
|
|
||||||
|
|
||||||
# If provider is changing, re-resolve credentials for the new provider
|
if not result.success:
|
||||||
api_key_for_probe = self.api_key
|
print(f"(>_<) {result.error_message}")
|
||||||
base_url_for_probe = self.base_url
|
if "Did you mean" not in result.error_message:
|
||||||
if provider_changed:
|
print(f" Model unchanged: {self.model}")
|
||||||
try:
|
if "credentials" not in result.error_message.lower():
|
||||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
print(" Tip: Use /model to see available models, /provider to see providers")
|
||||||
runtime = resolve_runtime_provider(requested=target_provider)
|
|
||||||
api_key_for_probe = runtime.get("api_key", "")
|
|
||||||
base_url_for_probe = runtime.get("base_url", "")
|
|
||||||
except Exception as e:
|
|
||||||
provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
|
|
||||||
if target_provider == "custom":
|
|
||||||
print(f"(>_<) Custom endpoint not configured. Set OPENAI_BASE_URL and OPENAI_API_KEY,")
|
|
||||||
print(f" or run: hermes setup → Custom OpenAI-compatible endpoint")
|
|
||||||
else:
|
|
||||||
print(f"(>_<) Could not resolve credentials for provider '{provider_label}': {e}")
|
|
||||||
print(f"(^_^) Current model unchanged: {self.model}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
try:
|
|
||||||
validation = validate_requested_model(
|
|
||||||
new_model,
|
|
||||||
target_provider,
|
|
||||||
api_key=api_key_for_probe,
|
|
||||||
base_url=base_url_for_probe,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
validation = {"accepted": True, "persist": True, "recognized": False, "message": None}
|
|
||||||
|
|
||||||
if not validation.get("accepted"):
|
|
||||||
print(f"(>_<) {validation.get('message')}")
|
|
||||||
print(f" Model unchanged: {self.model}")
|
|
||||||
if "Did you mean" not in (validation.get("message") or ""):
|
|
||||||
print(" Tip: Use /model to see available models, /provider to see providers")
|
|
||||||
else:
|
else:
|
||||||
self.model = new_model
|
self.model = result.new_model
|
||||||
self.agent = None # Force re-init
|
self.agent = None # Force re-init
|
||||||
|
|
||||||
if provider_changed:
|
if result.provider_changed:
|
||||||
self.requested_provider = target_provider
|
self.requested_provider = result.target_provider
|
||||||
self.provider = target_provider
|
self.provider = result.target_provider
|
||||||
self.api_key = api_key_for_probe
|
self.api_key = result.api_key
|
||||||
self.base_url = base_url_for_probe
|
self.base_url = result.base_url
|
||||||
|
|
||||||
provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
|
provider_note = f" [provider: {result.provider_label}]" if result.provider_changed else ""
|
||||||
provider_note = f" [provider: {provider_label}]" if provider_changed else ""
|
|
||||||
|
|
||||||
if validation.get("persist"):
|
if result.persist:
|
||||||
saved_model = save_config_value("model.default", new_model)
|
saved_model = save_config_value("model.default", result.new_model)
|
||||||
if provider_changed:
|
if result.provider_changed:
|
||||||
save_config_value("model.provider", target_provider)
|
save_config_value("model.provider", result.target_provider)
|
||||||
# Persist base_url for custom endpoints so it
|
# Persist base_url for custom endpoints; clear
|
||||||
# survives restart; clear it when switching away
|
# when switching away from custom (#2562 Phase 2).
|
||||||
# from custom to prevent stale URLs leaking into
|
if result.base_url and "openrouter.ai" not in (result.base_url or ""):
|
||||||
# the new provider's resolution (#2562 Phase 2).
|
save_config_value("model.base_url", result.base_url)
|
||||||
if base_url_for_probe and "openrouter.ai" not in (base_url_for_probe or ""):
|
|
||||||
save_config_value("model.base_url", base_url_for_probe)
|
|
||||||
else:
|
else:
|
||||||
save_config_value("model.base_url", None)
|
save_config_value("model.base_url", None)
|
||||||
if saved_model:
|
if saved_model:
|
||||||
print(f"(^_^)b Model changed to: {new_model}{provider_note} (saved to config)")
|
print(f"(^_^)b Model changed to: {result.new_model}{provider_note} (saved to config)")
|
||||||
else:
|
else:
|
||||||
print(f"(^_^) Model changed to: {new_model}{provider_note} (this session only)")
|
print(f"(^_^) Model changed to: {result.new_model}{provider_note} (this session only)")
|
||||||
else:
|
else:
|
||||||
message = validation.get("message") or ""
|
print(f"(^_^) Model changed to: {result.new_model}{provider_note} (this session only)")
|
||||||
print(f"(^_^) Model changed to: {new_model}{provider_note} (this session only)")
|
if result.warning_message:
|
||||||
if message:
|
print(f" Reason: {result.warning_message}")
|
||||||
print(f" Reason: {message}")
|
|
||||||
print(" Note: Model will revert on restart. Use a verified model to save to config.")
|
print(" Note: Model will revert on restart. Use a verified model to save to config.")
|
||||||
|
|
||||||
# Show endpoint info for custom providers
|
# Show endpoint info for custom providers
|
||||||
_target_is_custom = target_provider == "custom" or (
|
if result.is_custom_target:
|
||||||
base_url_for_probe and "openrouter.ai" not in (base_url_for_probe or "")
|
endpoint = result.base_url or self.base_url or "custom endpoint"
|
||||||
and ("localhost" in (base_url_for_probe or "") or "127.0.0.1" in (base_url_for_probe or ""))
|
|
||||||
)
|
|
||||||
if _target_is_custom or (is_custom and not provider_changed):
|
|
||||||
endpoint = base_url_for_probe or self.base_url or "custom endpoint"
|
|
||||||
print(f" Endpoint: {endpoint}")
|
print(f" Endpoint: {endpoint}")
|
||||||
if not provider_changed:
|
if not result.provider_changed:
|
||||||
print(f" Tip: To switch providers, use /model provider:model")
|
print(f" Tip: To switch providers, use /model provider:model")
|
||||||
print(f" e.g. /model openai-codex:gpt-5.2-codex")
|
print(f" e.g. /model openai-codex:gpt-5.2-codex")
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
211
gateway/run.py
211
gateway/run.py
|
|
@ -2854,117 +2854,10 @@ class GatewayRunner:
|
||||||
# Handle bare "/model custom" — switch to custom provider
|
# Handle bare "/model custom" — switch to custom provider
|
||||||
# and auto-detect the model from the endpoint.
|
# and auto-detect the model from the endpoint.
|
||||||
if args.strip().lower() == "custom":
|
if args.strip().lower() == "custom":
|
||||||
from hermes_cli.runtime_provider import (
|
from hermes_cli.model_switch import switch_to_custom_provider
|
||||||
resolve_runtime_provider as _rtp_custom,
|
cust_result = switch_to_custom_provider()
|
||||||
_auto_detect_local_model,
|
if not cust_result.success:
|
||||||
)
|
return f"⚠️ {cust_result.error_message}"
|
||||||
try:
|
|
||||||
runtime = _rtp_custom(requested="custom")
|
|
||||||
cust_base = runtime.get("base_url", "")
|
|
||||||
if not cust_base or "openrouter.ai" in cust_base:
|
|
||||||
return (
|
|
||||||
"⚠️ No custom endpoint configured.\n"
|
|
||||||
"Set `model.base_url` in config.yaml, or `OPENAI_BASE_URL` in .env,\n"
|
|
||||||
"or run: `hermes setup` → Custom OpenAI-compatible endpoint"
|
|
||||||
)
|
|
||||||
detected_model = _auto_detect_local_model(cust_base)
|
|
||||||
if detected_model:
|
|
||||||
try:
|
|
||||||
user_config = {}
|
|
||||||
if config_path.exists():
|
|
||||||
with open(config_path, encoding="utf-8") as f:
|
|
||||||
user_config = yaml.safe_load(f) or {}
|
|
||||||
if "model" not in user_config or not isinstance(user_config["model"], dict):
|
|
||||||
user_config["model"] = {}
|
|
||||||
user_config["model"]["default"] = detected_model
|
|
||||||
user_config["model"]["provider"] = "custom"
|
|
||||||
user_config["model"]["base_url"] = cust_base
|
|
||||||
with open(config_path, 'w', encoding="utf-8") as f:
|
|
||||||
yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
|
|
||||||
except Exception as e:
|
|
||||||
return f"⚠️ Failed to save model change: {e}"
|
|
||||||
os.environ["HERMES_MODEL"] = detected_model
|
|
||||||
os.environ["HERMES_INFERENCE_PROVIDER"] = "custom"
|
|
||||||
self._effective_model = None
|
|
||||||
self._effective_provider = None
|
|
||||||
return (
|
|
||||||
f"🤖 Model changed to `{detected_model}` (saved to config)\n"
|
|
||||||
f"**Provider:** Custom\n"
|
|
||||||
f"**Endpoint:** `{cust_base}`\n"
|
|
||||||
f"_Model auto-detected from endpoint. Takes effect on next message._"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return (
|
|
||||||
f"⚠️ Custom endpoint at `{cust_base}` is reachable but no single model was auto-detected.\n"
|
|
||||||
f"Specify the model explicitly: `/model custom:<model-name>`"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
return f"⚠️ Could not resolve custom endpoint: {e}"
|
|
||||||
|
|
||||||
# Parse provider:model syntax
|
|
||||||
target_provider, new_model = parse_model_input(args, current_provider)
|
|
||||||
|
|
||||||
# Detect custom/local provider — skip auto-detection to prevent
|
|
||||||
# silently accepting an OpenRouter model name on a localhost endpoint.
|
|
||||||
# Users must use explicit provider:model syntax to switch away.
|
|
||||||
_resolved_base = ""
|
|
||||||
try:
|
|
||||||
from hermes_cli.runtime_provider import resolve_runtime_provider as _rtp
|
|
||||||
_resolved_base = _rtp(requested=current_provider).get("base_url", "")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
is_custom = current_provider == "custom" or (
|
|
||||||
"localhost" in _resolved_base or "127.0.0.1" in _resolved_base
|
|
||||||
)
|
|
||||||
|
|
||||||
# Auto-detect provider when no explicit provider:model syntax was used
|
|
||||||
if target_provider == current_provider and not is_custom:
|
|
||||||
from hermes_cli.models import detect_provider_for_model
|
|
||||||
detected = detect_provider_for_model(new_model, current_provider)
|
|
||||||
if detected:
|
|
||||||
target_provider, new_model = detected
|
|
||||||
provider_changed = target_provider != current_provider
|
|
||||||
|
|
||||||
# Resolve credentials for the target provider (for API probe)
|
|
||||||
api_key = os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
|
|
||||||
base_url = "https://openrouter.ai/api/v1"
|
|
||||||
if provider_changed:
|
|
||||||
try:
|
|
||||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
|
||||||
runtime = resolve_runtime_provider(requested=target_provider)
|
|
||||||
api_key = runtime.get("api_key", "")
|
|
||||||
base_url = runtime.get("base_url", "")
|
|
||||||
except Exception as e:
|
|
||||||
provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
|
|
||||||
return f"⚠️ Could not resolve credentials for provider '{provider_label}': {e}"
|
|
||||||
else:
|
|
||||||
# Use current provider's base_url from config or registry
|
|
||||||
try:
|
|
||||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
|
||||||
runtime = resolve_runtime_provider(requested=current_provider)
|
|
||||||
api_key = runtime.get("api_key", "")
|
|
||||||
base_url = runtime.get("base_url", "")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Validate the model against the live API
|
|
||||||
try:
|
|
||||||
validation = validate_requested_model(
|
|
||||||
new_model,
|
|
||||||
target_provider,
|
|
||||||
api_key=api_key,
|
|
||||||
base_url=base_url,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
validation = {"accepted": True, "persist": True, "recognized": False, "message": None}
|
|
||||||
|
|
||||||
if not validation.get("accepted"):
|
|
||||||
msg = validation.get("message", "Invalid model")
|
|
||||||
tip = "\n\nUse `/model` to see available models, `/provider` to see providers" if "Did you mean" not in msg else ""
|
|
||||||
return f"⚠️ {msg}{tip}"
|
|
||||||
|
|
||||||
# Persist to config only if validation approves
|
|
||||||
if validation.get("persist"):
|
|
||||||
try:
|
try:
|
||||||
user_config = {}
|
user_config = {}
|
||||||
if config_path.exists():
|
if config_path.exists():
|
||||||
|
|
@ -2972,14 +2865,63 @@ class GatewayRunner:
|
||||||
user_config = yaml.safe_load(f) or {}
|
user_config = yaml.safe_load(f) or {}
|
||||||
if "model" not in user_config or not isinstance(user_config["model"], dict):
|
if "model" not in user_config or not isinstance(user_config["model"], dict):
|
||||||
user_config["model"] = {}
|
user_config["model"] = {}
|
||||||
user_config["model"]["default"] = new_model
|
user_config["model"]["default"] = cust_result.model
|
||||||
if provider_changed:
|
user_config["model"]["provider"] = "custom"
|
||||||
user_config["model"]["provider"] = target_provider
|
user_config["model"]["base_url"] = cust_result.base_url
|
||||||
# Persist base_url for custom endpoints so it survives
|
with open(config_path, 'w', encoding="utf-8") as f:
|
||||||
# restart; clear it when switching away from custom to
|
yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
|
||||||
# prevent stale URLs leaking (#2562 Phase 2).
|
except Exception as e:
|
||||||
if base_url and "openrouter.ai" not in (base_url or ""):
|
return f"⚠️ Failed to save model change: {e}"
|
||||||
user_config["model"]["base_url"] = base_url
|
os.environ["HERMES_MODEL"] = cust_result.model
|
||||||
|
os.environ["HERMES_INFERENCE_PROVIDER"] = "custom"
|
||||||
|
self._effective_model = None
|
||||||
|
self._effective_provider = None
|
||||||
|
return (
|
||||||
|
f"🤖 Model changed to `{cust_result.model}` (saved to config)\n"
|
||||||
|
f"**Provider:** Custom\n"
|
||||||
|
f"**Endpoint:** `{cust_result.base_url}`\n"
|
||||||
|
f"_Model auto-detected from endpoint. Takes effect on next message._"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Core model-switching pipeline (shared with CLI)
|
||||||
|
from hermes_cli.model_switch import switch_model
|
||||||
|
|
||||||
|
# Resolve current base_url for is_custom detection
|
||||||
|
_resolved_base = ""
|
||||||
|
try:
|
||||||
|
from hermes_cli.runtime_provider import resolve_runtime_provider as _rtp
|
||||||
|
_resolved_base = _rtp(requested=current_provider).get("base_url", "")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
result = switch_model(
|
||||||
|
args,
|
||||||
|
current_provider,
|
||||||
|
current_base_url=_resolved_base,
|
||||||
|
current_api_key=os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY") or "",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result.success:
|
||||||
|
msg = result.error_message
|
||||||
|
tip = "\n\nUse `/model` to see available models, `/provider` to see providers" if "Did you mean" not in msg else ""
|
||||||
|
return f"⚠️ {msg}{tip}"
|
||||||
|
|
||||||
|
# Persist to config only if validation approves
|
||||||
|
if result.persist:
|
||||||
|
try:
|
||||||
|
user_config = {}
|
||||||
|
if config_path.exists():
|
||||||
|
with open(config_path, encoding="utf-8") as f:
|
||||||
|
user_config = yaml.safe_load(f) or {}
|
||||||
|
if "model" not in user_config or not isinstance(user_config["model"], dict):
|
||||||
|
user_config["model"] = {}
|
||||||
|
user_config["model"]["default"] = result.new_model
|
||||||
|
if result.provider_changed:
|
||||||
|
user_config["model"]["provider"] = result.target_provider
|
||||||
|
# Persist base_url for custom endpoints; clear when
|
||||||
|
# switching away from custom (#2562 Phase 2).
|
||||||
|
if result.base_url and "openrouter.ai" not in (result.base_url or ""):
|
||||||
|
user_config["model"]["base_url"] = result.base_url
|
||||||
else:
|
else:
|
||||||
user_config["model"].pop("base_url", None)
|
user_config["model"].pop("base_url", None)
|
||||||
with open(config_path, 'w', encoding="utf-8") as f:
|
with open(config_path, 'w', encoding="utf-8") as f:
|
||||||
|
|
@ -2988,41 +2930,34 @@ class GatewayRunner:
|
||||||
return f"⚠️ Failed to save model change: {e}"
|
return f"⚠️ Failed to save model change: {e}"
|
||||||
|
|
||||||
# Set env vars so the next agent run picks up the change
|
# Set env vars so the next agent run picks up the change
|
||||||
os.environ["HERMES_MODEL"] = new_model
|
os.environ["HERMES_MODEL"] = result.new_model
|
||||||
if provider_changed:
|
if result.provider_changed:
|
||||||
os.environ["HERMES_INFERENCE_PROVIDER"] = target_provider
|
os.environ["HERMES_INFERENCE_PROVIDER"] = result.target_provider
|
||||||
|
|
||||||
provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
|
provider_note = f"\n**Provider:** {result.provider_label}" if result.provider_changed else ""
|
||||||
provider_note = f"\n**Provider:** {provider_label}" if provider_changed else ""
|
|
||||||
|
|
||||||
warning = ""
|
warning = ""
|
||||||
if validation.get("message"):
|
if result.warning_message:
|
||||||
warning = f"\n⚠️ {validation['message']}"
|
warning = f"\n⚠️ {result.warning_message}"
|
||||||
|
|
||||||
|
persist_note = "saved to config" if result.persist else "this session only — will revert on restart"
|
||||||
|
|
||||||
if validation.get("persist"):
|
|
||||||
persist_note = "saved to config"
|
|
||||||
else:
|
|
||||||
persist_note = "this session only — will revert on restart"
|
|
||||||
# Clear fallback state since user explicitly chose a model
|
# Clear fallback state since user explicitly chose a model
|
||||||
self._effective_model = None
|
self._effective_model = None
|
||||||
self._effective_provider = None
|
self._effective_provider = None
|
||||||
|
|
||||||
# Show endpoint info for custom providers
|
# Show endpoint info for custom providers
|
||||||
_target_is_custom = target_provider == "custom" or (
|
|
||||||
base_url and "openrouter.ai" not in (base_url or "")
|
|
||||||
and ("localhost" in (base_url or "") or "127.0.0.1" in (base_url or ""))
|
|
||||||
)
|
|
||||||
custom_hint = ""
|
custom_hint = ""
|
||||||
if _target_is_custom or (is_custom and not provider_changed):
|
if result.is_custom_target:
|
||||||
endpoint = base_url or _resolved_base or "custom endpoint"
|
endpoint = result.base_url or _resolved_base or "custom endpoint"
|
||||||
custom_hint = f"\n**Endpoint:** `{endpoint}`"
|
custom_hint = f"\n**Endpoint:** `{endpoint}`"
|
||||||
if not provider_changed:
|
if not result.provider_changed:
|
||||||
custom_hint += (
|
custom_hint += (
|
||||||
"\n_To switch providers, use_ `/model provider:model`"
|
"\n_To switch providers, use_ `/model provider:model`"
|
||||||
"\n_e.g._ `/model openrouter:anthropic/claude-sonnet-4`"
|
"\n_e.g._ `/model openrouter:anthropic/claude-sonnet-4`"
|
||||||
)
|
)
|
||||||
|
|
||||||
return f"🤖 Model changed to `{new_model}` ({persist_note}){provider_note}{warning}{custom_hint}\n_(takes effect on next message)_"
|
return f"🤖 Model changed to `{result.new_model}` ({persist_note}){provider_note}{warning}{custom_hint}\n_(takes effect on next message)_"
|
||||||
|
|
||||||
async def _handle_provider_command(self, event: MessageEvent) -> str:
|
async def _handle_provider_command(self, event: MessageEvent) -> str:
|
||||||
"""Handle /provider command - show available providers."""
|
"""Handle /provider command - show available providers."""
|
||||||
|
|
|
||||||
234
hermes_cli/model_switch.py
Normal file
234
hermes_cli/model_switch.py
Normal file
|
|
@ -0,0 +1,234 @@
|
||||||
|
"""Shared model-switching logic for CLI and gateway /model commands.
|
||||||
|
|
||||||
|
Both the CLI (cli.py) and gateway (gateway/run.py) /model handlers
|
||||||
|
share the same core pipeline:
|
||||||
|
|
||||||
|
parse_model_input → is_custom detection → auto-detect provider
|
||||||
|
→ credential resolution → validate model → return result
|
||||||
|
|
||||||
|
This module extracts that shared pipeline into pure functions that
|
||||||
|
return result objects. The callers handle all platform-specific
|
||||||
|
concerns: state mutation, config persistence, output formatting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelSwitchResult:
|
||||||
|
"""Result of a model switch attempt."""
|
||||||
|
|
||||||
|
success: bool
|
||||||
|
new_model: str = ""
|
||||||
|
target_provider: str = ""
|
||||||
|
provider_changed: bool = False
|
||||||
|
api_key: str = ""
|
||||||
|
base_url: str = ""
|
||||||
|
persist: bool = False
|
||||||
|
error_message: str = ""
|
||||||
|
warning_message: str = ""
|
||||||
|
is_custom_target: bool = False
|
||||||
|
provider_label: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CustomAutoResult:
|
||||||
|
"""Result of switching to bare 'custom' provider with auto-detect."""
|
||||||
|
|
||||||
|
success: bool
|
||||||
|
model: str = ""
|
||||||
|
base_url: str = ""
|
||||||
|
api_key: str = ""
|
||||||
|
error_message: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
def switch_model(
|
||||||
|
raw_input: str,
|
||||||
|
current_provider: str,
|
||||||
|
current_base_url: str = "",
|
||||||
|
current_api_key: str = "",
|
||||||
|
) -> ModelSwitchResult:
|
||||||
|
"""Core model-switching pipeline shared between CLI and gateway.
|
||||||
|
|
||||||
|
Handles parsing, provider detection, credential resolution, and
|
||||||
|
model validation. Does NOT handle config persistence, state
|
||||||
|
mutation, or output formatting — those are caller responsibilities.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_input: The user's model input (e.g. "claude-sonnet-4",
|
||||||
|
"zai:glm-5", "custom:local:qwen").
|
||||||
|
current_provider: The currently active provider.
|
||||||
|
current_base_url: The currently active base URL (used for
|
||||||
|
is_custom detection).
|
||||||
|
current_api_key: The currently active API key.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelSwitchResult with all information the caller needs to
|
||||||
|
apply the switch and format output.
|
||||||
|
"""
|
||||||
|
from hermes_cli.models import (
|
||||||
|
parse_model_input,
|
||||||
|
detect_provider_for_model,
|
||||||
|
validate_requested_model,
|
||||||
|
_PROVIDER_LABELS,
|
||||||
|
)
|
||||||
|
from hermes_cli.runtime_provider import resolve_runtime_provider
|
||||||
|
|
||||||
|
# Step 1: Parse provider:model syntax
|
||||||
|
target_provider, new_model = parse_model_input(raw_input, current_provider)
|
||||||
|
|
||||||
|
# Step 2: Detect if we're currently on a custom endpoint
|
||||||
|
_base = current_base_url or ""
|
||||||
|
is_custom = current_provider == "custom" or (
|
||||||
|
"localhost" in _base or "127.0.0.1" in _base
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 3: Auto-detect provider when no explicit provider:model syntax
|
||||||
|
# was used. Skip for custom providers — the model name might
|
||||||
|
# coincidentally match a known provider's catalog.
|
||||||
|
if target_provider == current_provider and not is_custom:
|
||||||
|
detected = detect_provider_for_model(new_model, current_provider)
|
||||||
|
if detected:
|
||||||
|
target_provider, new_model = detected
|
||||||
|
|
||||||
|
provider_changed = target_provider != current_provider
|
||||||
|
|
||||||
|
# Step 4: Resolve credentials for target provider
|
||||||
|
api_key = current_api_key
|
||||||
|
base_url = current_base_url
|
||||||
|
if provider_changed:
|
||||||
|
try:
|
||||||
|
runtime = resolve_runtime_provider(requested=target_provider)
|
||||||
|
api_key = runtime.get("api_key", "")
|
||||||
|
base_url = runtime.get("base_url", "")
|
||||||
|
except Exception as e:
|
||||||
|
provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
|
||||||
|
if target_provider == "custom":
|
||||||
|
return ModelSwitchResult(
|
||||||
|
success=False,
|
||||||
|
target_provider=target_provider,
|
||||||
|
error_message=(
|
||||||
|
"No custom endpoint configured. Set model.base_url "
|
||||||
|
"in config.yaml, or set OPENAI_BASE_URL in .env, "
|
||||||
|
"or run: hermes setup → Custom OpenAI-compatible endpoint"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return ModelSwitchResult(
|
||||||
|
success=False,
|
||||||
|
target_provider=target_provider,
|
||||||
|
error_message=(
|
||||||
|
f"Could not resolve credentials for provider "
|
||||||
|
f"'{provider_label}': {e}"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Gateway also resolves for unchanged provider to get accurate
|
||||||
|
# base_url for validation probing.
|
||||||
|
try:
|
||||||
|
runtime = resolve_runtime_provider(requested=current_provider)
|
||||||
|
api_key = runtime.get("api_key", "")
|
||||||
|
base_url = runtime.get("base_url", "")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Step 5: Validate the model
|
||||||
|
try:
|
||||||
|
validation = validate_requested_model(
|
||||||
|
new_model,
|
||||||
|
target_provider,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
validation = {
|
||||||
|
"accepted": True,
|
||||||
|
"persist": True,
|
||||||
|
"recognized": False,
|
||||||
|
"message": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
if not validation.get("accepted"):
|
||||||
|
msg = validation.get("message", "Invalid model")
|
||||||
|
return ModelSwitchResult(
|
||||||
|
success=False,
|
||||||
|
new_model=new_model,
|
||||||
|
target_provider=target_provider,
|
||||||
|
error_message=msg,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 6: Build result
|
||||||
|
provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
|
||||||
|
is_custom_target = target_provider == "custom" or (
|
||||||
|
base_url
|
||||||
|
and "openrouter.ai" not in (base_url or "")
|
||||||
|
and ("localhost" in (base_url or "") or "127.0.0.1" in (base_url or ""))
|
||||||
|
)
|
||||||
|
|
||||||
|
return ModelSwitchResult(
|
||||||
|
success=True,
|
||||||
|
new_model=new_model,
|
||||||
|
target_provider=target_provider,
|
||||||
|
provider_changed=provider_changed,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
persist=bool(validation.get("persist")),
|
||||||
|
warning_message=validation.get("message") or "",
|
||||||
|
is_custom_target=is_custom_target,
|
||||||
|
provider_label=provider_label,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def switch_to_custom_provider() -> CustomAutoResult:
|
||||||
|
"""Handle bare '/model custom' — resolve endpoint and auto-detect model.
|
||||||
|
|
||||||
|
Returns a result object; the caller handles persistence and output.
|
||||||
|
"""
|
||||||
|
from hermes_cli.runtime_provider import (
|
||||||
|
resolve_runtime_provider,
|
||||||
|
_auto_detect_local_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
runtime = resolve_runtime_provider(requested="custom")
|
||||||
|
except Exception as e:
|
||||||
|
return CustomAutoResult(
|
||||||
|
success=False,
|
||||||
|
error_message=f"Could not resolve custom endpoint: {e}",
|
||||||
|
)
|
||||||
|
|
||||||
|
cust_base = runtime.get("base_url", "")
|
||||||
|
cust_key = runtime.get("api_key", "")
|
||||||
|
|
||||||
|
if not cust_base or "openrouter.ai" in cust_base:
|
||||||
|
return CustomAutoResult(
|
||||||
|
success=False,
|
||||||
|
error_message=(
|
||||||
|
"No custom endpoint configured. "
|
||||||
|
"Set model.base_url in config.yaml, or set OPENAI_BASE_URL "
|
||||||
|
"in .env, or run: hermes setup → Custom OpenAI-compatible endpoint"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
detected_model = _auto_detect_local_model(cust_base)
|
||||||
|
if not detected_model:
|
||||||
|
return CustomAutoResult(
|
||||||
|
success=False,
|
||||||
|
base_url=cust_base,
|
||||||
|
api_key=cust_key,
|
||||||
|
error_message=(
|
||||||
|
f"Custom endpoint at {cust_base} is reachable but no single "
|
||||||
|
f"model was auto-detected. Specify the model explicitly: "
|
||||||
|
f"/model custom:<model-name>"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return CustomAutoResult(
|
||||||
|
success=True,
|
||||||
|
model=detected_model,
|
||||||
|
base_url=cust_base,
|
||||||
|
api_key=cust_key,
|
||||||
|
)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue