feat: use endpoint metadata for custom model context and pricing (#1906)

* perf: cache base_url.lower() via property, consolidate triple load_config(), hoist set constant

run_agent.py:
- Add base_url property that auto-caches _base_url_lower on every
  assignment, eliminating 12+ redundant .lower() calls per API cycle
  across __init__, _build_api_kwargs, _supports_reasoning_extra_body,
  and the main conversation loop
- Consolidate three separate load_config() disk reads in __init__
  (memory, skills, compression) into a single call, reusing the
  result dict for all three config sections

model_tools.py:
- Hoist _READ_SEARCH_TOOLS set to module level (was rebuilt inside
  handle_function_call on every tool invocation)

* Use endpoint metadata for custom model context and pricing

---------

Co-authored-by: kshitij <82637225+kshitijk4poor@users.noreply.github.com>
This commit is contained in:
Teknium 2026-03-18 03:04:07 -07:00 committed by GitHub
parent 11f029c311
commit a2440f72f6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 375 additions and 49 deletions

View file

@ -45,16 +45,18 @@ class ContextCompressor:
quiet_mode: bool = False, quiet_mode: bool = False,
summary_model_override: str = None, summary_model_override: str = None,
base_url: str = "", base_url: str = "",
api_key: str = "",
): ):
self.model = model self.model = model
self.base_url = base_url self.base_url = base_url
self.api_key = api_key
self.threshold_percent = threshold_percent self.threshold_percent = threshold_percent
self.protect_first_n = protect_first_n self.protect_first_n = protect_first_n
self.protect_last_n = protect_last_n self.protect_last_n = protect_last_n
self.summary_target_tokens = summary_target_tokens self.summary_target_tokens = summary_target_tokens
self.quiet_mode = quiet_mode self.quiet_mode = quiet_mode
self.context_length = get_model_context_length(model, base_url=base_url) self.context_length = get_model_context_length(model, base_url=base_url, api_key=api_key)
self.threshold_tokens = int(self.context_length * threshold_percent) self.threshold_tokens = int(self.context_length * threshold_percent)
self.compression_count = 0 self.compression_count = 0
self._context_probed = False # True after a step-down from context error self._context_probed = False # True after a step-down from context error

View file

@ -10,6 +10,7 @@ import re
import time import time
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
import requests import requests
import yaml import yaml
@ -21,6 +22,9 @@ logger = logging.getLogger(__name__)
_model_metadata_cache: Dict[str, Dict[str, Any]] = {} _model_metadata_cache: Dict[str, Dict[str, Any]] = {}
_model_metadata_cache_time: float = 0 _model_metadata_cache_time: float = 0
_MODEL_CACHE_TTL = 3600 _MODEL_CACHE_TTL = 3600
_endpoint_model_metadata_cache: Dict[str, Dict[str, Dict[str, Any]]] = {}
_endpoint_model_metadata_cache_time: Dict[str, float] = {}
_ENDPOINT_MODEL_CACHE_TTL = 300
# Descending tiers for context length probing when the model is unknown. # Descending tiers for context length probing when the model is unknown.
# We start high and step down on context-length errors until one works. # We start high and step down on context-length errors until one works.
@ -123,6 +127,128 @@ DEFAULT_CONTEXT_LENGTHS = {
"qwen-vl-max": 32768, "qwen-vl-max": 32768,
} }
_CONTEXT_LENGTH_KEYS = (
"context_length",
"context_window",
"max_context_length",
"max_position_embeddings",
"max_model_len",
"max_input_tokens",
"max_sequence_length",
"max_seq_len",
)
_MAX_COMPLETION_KEYS = (
"max_completion_tokens",
"max_output_tokens",
"max_tokens",
)
def _normalize_base_url(base_url: str) -> str:
return (base_url or "").strip().rstrip("/")
def _is_openrouter_base_url(base_url: str) -> bool:
return "openrouter.ai" in _normalize_base_url(base_url).lower()
def _is_custom_endpoint(base_url: str) -> bool:
normalized = _normalize_base_url(base_url)
return bool(normalized) and not _is_openrouter_base_url(normalized)
def _is_known_provider_base_url(base_url: str) -> bool:
normalized = _normalize_base_url(base_url)
if not normalized:
return False
parsed = urlparse(normalized if "://" in normalized else f"https://{normalized}")
host = parsed.netloc.lower() or parsed.path.lower()
known_hosts = (
"api.openai.com",
"chatgpt.com",
"api.anthropic.com",
"api.z.ai",
"api.moonshot.ai",
"api.kimi.com",
"api.minimax",
)
return any(known_host in host for known_host in known_hosts)
def _iter_nested_dicts(value: Any):
if isinstance(value, dict):
yield value
for nested in value.values():
yield from _iter_nested_dicts(nested)
elif isinstance(value, list):
for item in value:
yield from _iter_nested_dicts(item)
def _coerce_reasonable_int(value: Any, minimum: int = 1024, maximum: int = 10_000_000) -> Optional[int]:
try:
if isinstance(value, bool):
return None
if isinstance(value, str):
value = value.strip().replace(",", "")
result = int(value)
except (TypeError, ValueError):
return None
if minimum <= result <= maximum:
return result
return None
def _extract_first_int(payload: Dict[str, Any], keys: tuple[str, ...]) -> Optional[int]:
keyset = {key.lower() for key in keys}
for mapping in _iter_nested_dicts(payload):
for key, value in mapping.items():
if str(key).lower() not in keyset:
continue
coerced = _coerce_reasonable_int(value)
if coerced is not None:
return coerced
return None
def _extract_context_length(payload: Dict[str, Any]) -> Optional[int]:
return _extract_first_int(payload, _CONTEXT_LENGTH_KEYS)
def _extract_max_completion_tokens(payload: Dict[str, Any]) -> Optional[int]:
return _extract_first_int(payload, _MAX_COMPLETION_KEYS)
def _extract_pricing(payload: Dict[str, Any]) -> Dict[str, Any]:
alias_map = {
"prompt": ("prompt", "input", "input_cost_per_token", "prompt_token_cost"),
"completion": ("completion", "output", "output_cost_per_token", "completion_token_cost"),
"request": ("request", "request_cost"),
"cache_read": ("cache_read", "cached_prompt", "input_cache_read", "cache_read_cost_per_token"),
"cache_write": ("cache_write", "cache_creation", "input_cache_write", "cache_write_cost_per_token"),
}
for mapping in _iter_nested_dicts(payload):
normalized = {str(key).lower(): value for key, value in mapping.items()}
if not any(any(alias in normalized for alias in aliases) for aliases in alias_map.values()):
continue
pricing: Dict[str, Any] = {}
for target, aliases in alias_map.items():
for alias in aliases:
if alias in normalized and normalized[alias] not in (None, ""):
pricing[target] = normalized[alias]
break
if pricing:
return pricing
return {}
def _add_model_aliases(cache: Dict[str, Dict[str, Any]], model_id: str, entry: Dict[str, Any]) -> None:
cache[model_id] = entry
if "/" in model_id:
bare_model = model_id.split("/", 1)[1]
cache.setdefault(bare_model, entry)
def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any]]: def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any]]:
"""Fetch model metadata from OpenRouter (cached for 1 hour).""" """Fetch model metadata from OpenRouter (cached for 1 hour)."""
@ -139,15 +265,16 @@ def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any
cache = {} cache = {}
for model in data.get("data", []): for model in data.get("data", []):
model_id = model.get("id", "") model_id = model.get("id", "")
cache[model_id] = { entry = {
"context_length": model.get("context_length", 128000), "context_length": model.get("context_length", 128000),
"max_completion_tokens": model.get("top_provider", {}).get("max_completion_tokens", 4096), "max_completion_tokens": model.get("top_provider", {}).get("max_completion_tokens", 4096),
"name": model.get("name", model_id), "name": model.get("name", model_id),
"pricing": model.get("pricing", {}), "pricing": model.get("pricing", {}),
} }
_add_model_aliases(cache, model_id, entry)
canonical = model.get("canonical_slug", "") canonical = model.get("canonical_slug", "")
if canonical and canonical != model_id: if canonical and canonical != model_id:
cache[canonical] = cache[model_id] _add_model_aliases(cache, canonical, entry)
_model_metadata_cache = cache _model_metadata_cache = cache
_model_metadata_cache_time = time.time() _model_metadata_cache_time = time.time()
@ -159,6 +286,75 @@ def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any
return _model_metadata_cache or {} return _model_metadata_cache or {}
def fetch_endpoint_model_metadata(
base_url: str,
api_key: str = "",
force_refresh: bool = False,
) -> Dict[str, Dict[str, Any]]:
"""Fetch model metadata from an OpenAI-compatible ``/models`` endpoint.
This is used for explicit custom endpoints where hardcoded global model-name
defaults are unreliable. Results are cached in memory per base URL.
"""
normalized = _normalize_base_url(base_url)
if not normalized or _is_openrouter_base_url(normalized):
return {}
if not force_refresh:
cached = _endpoint_model_metadata_cache.get(normalized)
cached_at = _endpoint_model_metadata_cache_time.get(normalized, 0)
if cached is not None and (time.time() - cached_at) < _ENDPOINT_MODEL_CACHE_TTL:
return cached
candidates = [normalized]
if normalized.endswith("/v1"):
alternate = normalized[:-3].rstrip("/")
else:
alternate = normalized + "/v1"
if alternate and alternate not in candidates:
candidates.append(alternate)
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
last_error: Optional[Exception] = None
for candidate in candidates:
url = candidate.rstrip("/") + "/models"
try:
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
payload = response.json()
cache: Dict[str, Dict[str, Any]] = {}
for model in payload.get("data", []):
if not isinstance(model, dict):
continue
model_id = model.get("id")
if not model_id:
continue
entry: Dict[str, Any] = {"name": model.get("name", model_id)}
context_length = _extract_context_length(model)
if context_length is not None:
entry["context_length"] = context_length
max_completion_tokens = _extract_max_completion_tokens(model)
if max_completion_tokens is not None:
entry["max_completion_tokens"] = max_completion_tokens
pricing = _extract_pricing(model)
if pricing:
entry["pricing"] = pricing
_add_model_aliases(cache, model_id, entry)
_endpoint_model_metadata_cache[normalized] = cache
_endpoint_model_metadata_cache_time[normalized] = time.time()
return cache
except Exception as exc:
last_error = exc
if last_error:
logger.debug("Failed to fetch model metadata from %s/models: %s", normalized, last_error)
_endpoint_model_metadata_cache[normalized] = {}
_endpoint_model_metadata_cache_time[normalized] = time.time()
return {}
def _get_context_cache_path() -> Path: def _get_context_cache_path() -> Path:
"""Return path to the persistent context length cache file.""" """Return path to the persistent context length cache file."""
hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
@ -243,14 +439,15 @@ def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
return None return None
def get_model_context_length(model: str, base_url: str = "") -> int: def get_model_context_length(model: str, base_url: str = "", api_key: str = "") -> int:
"""Get the context length for a model. """Get the context length for a model.
Resolution order: Resolution order:
1. Persistent cache (previously discovered via probing) 1. Persistent cache (previously discovered via probing)
2. OpenRouter API metadata 2. Active endpoint metadata (/models for explicit custom endpoints)
3. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match) 3. OpenRouter API metadata
4. First probe tier (2M) will be narrowed on first context error 4. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match for hosted routes only)
5. First probe tier (2M) will be narrowed on first context error
""" """
# 1. Check persistent cache (model+provider) # 1. Check persistent cache (model+provider)
if base_url: if base_url:
@ -258,19 +455,31 @@ def get_model_context_length(model: str, base_url: str = "") -> int:
if cached is not None: if cached is not None:
return cached return cached
# 2. OpenRouter API metadata # 2. Active endpoint metadata for explicit custom routes
if _is_custom_endpoint(base_url):
endpoint_metadata = fetch_endpoint_model_metadata(base_url, api_key=api_key)
if model in endpoint_metadata:
context_length = endpoint_metadata[model].get("context_length")
if isinstance(context_length, int):
return context_length
if not _is_known_provider_base_url(base_url):
# Explicit third-party endpoints should not borrow fuzzy global
# defaults from unrelated providers with similarly named models.
return CONTEXT_PROBE_TIERS[0]
# 3. OpenRouter API metadata
metadata = fetch_model_metadata() metadata = fetch_model_metadata()
if model in metadata: if model in metadata:
return metadata[model].get("context_length", 128000) return metadata[model].get("context_length", 128000)
# 3. Hardcoded defaults (fuzzy match — longest key first for specificity) # 4. Hardcoded defaults (fuzzy match — longest key first for specificity)
for default_model, length in sorted( for default_model, length in sorted(
DEFAULT_CONTEXT_LENGTHS.items(), key=lambda x: len(x[0]), reverse=True DEFAULT_CONTEXT_LENGTHS.items(), key=lambda x: len(x[0]), reverse=True
): ):
if default_model in model or model in default_model: if default_model in model or model in default_model:
return length return length
# 4. Unknown model — start at highest probe tier # 5. Unknown model — start at highest probe tier
return CONTEXT_PROBE_TIERS[0] return CONTEXT_PROBE_TIERS[0]

View file

@ -5,7 +5,7 @@ from datetime import datetime, timezone
from decimal import Decimal from decimal import Decimal
from typing import Any, Dict, Literal, Optional from typing import Any, Dict, Literal, Optional
from agent.model_metadata import fetch_model_metadata from agent.model_metadata import fetch_endpoint_model_metadata, fetch_model_metadata
DEFAULT_PRICING = {"input": 0.0, "output": 0.0} DEFAULT_PRICING = {"input": 0.0, "output": 0.0}
@ -335,8 +335,21 @@ def _lookup_official_docs_pricing(route: BillingRoute) -> Optional[PricingEntry]
def _openrouter_pricing_entry(route: BillingRoute) -> Optional[PricingEntry]: def _openrouter_pricing_entry(route: BillingRoute) -> Optional[PricingEntry]:
metadata = fetch_model_metadata() return _pricing_entry_from_metadata(
model_id = route.model fetch_model_metadata(),
route.model,
source_url="https://openrouter.ai/docs/api/api-reference/models/get-models",
pricing_version="openrouter-models-api",
)
def _pricing_entry_from_metadata(
metadata: Dict[str, Dict[str, Any]],
model_id: str,
*,
source_url: str,
pricing_version: str,
) -> Optional[PricingEntry]:
if model_id not in metadata: if model_id not in metadata:
return None return None
pricing = metadata[model_id].get("pricing") or {} pricing = metadata[model_id].get("pricing") or {}
@ -355,6 +368,7 @@ def _openrouter_pricing_entry(route: BillingRoute) -> Optional[PricingEntry]:
) )
if prompt is None and completion is None and request is None: if prompt is None and completion is None and request is None:
return None return None
def _per_token_to_per_million(value: Optional[Decimal]) -> Optional[Decimal]: def _per_token_to_per_million(value: Optional[Decimal]) -> Optional[Decimal]:
if value is None: if value is None:
return None return None
@ -367,8 +381,8 @@ def _openrouter_pricing_entry(route: BillingRoute) -> Optional[PricingEntry]:
cache_write_cost_per_million=_per_token_to_per_million(cache_write), cache_write_cost_per_million=_per_token_to_per_million(cache_write),
request_cost=request, request_cost=request,
source="provider_models_api", source="provider_models_api",
source_url="https://openrouter.ai/docs/api/api-reference/models/get-models", source_url=source_url,
pricing_version="openrouter-models-api", pricing_version=pricing_version,
fetched_at=_UTC_NOW(), fetched_at=_UTC_NOW(),
) )
@ -377,6 +391,7 @@ def get_pricing_entry(
model_name: str, model_name: str,
provider: Optional[str] = None, provider: Optional[str] = None,
base_url: Optional[str] = None, base_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> Optional[PricingEntry]: ) -> Optional[PricingEntry]:
route = resolve_billing_route(model_name, provider=provider, base_url=base_url) route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
if route.billing_mode == "subscription_included": if route.billing_mode == "subscription_included":
@ -390,6 +405,15 @@ def get_pricing_entry(
) )
if route.provider == "openrouter": if route.provider == "openrouter":
return _openrouter_pricing_entry(route) return _openrouter_pricing_entry(route)
if route.base_url:
entry = _pricing_entry_from_metadata(
fetch_endpoint_model_metadata(route.base_url, api_key=api_key or ""),
route.model,
source_url=f"{route.base_url.rstrip('/')}/models",
pricing_version="openai-compatible-models-api",
)
if entry:
return entry
return _lookup_official_docs_pricing(route) return _lookup_official_docs_pricing(route)
@ -460,6 +484,7 @@ def estimate_usage_cost(
*, *,
provider: Optional[str] = None, provider: Optional[str] = None,
base_url: Optional[str] = None, base_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> CostResult: ) -> CostResult:
route = resolve_billing_route(model_name, provider=provider, base_url=base_url) route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
if route.billing_mode == "subscription_included": if route.billing_mode == "subscription_included":
@ -471,7 +496,7 @@ def estimate_usage_cost(
pricing_version="included-route", pricing_version="included-route",
) )
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url) entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key)
if not entry: if not entry:
return CostResult(amount_usd=None, status="unknown", source="none", label="n/a") return CostResult(amount_usd=None, status="unknown", source="none", label="n/a")
@ -536,6 +561,7 @@ def has_known_pricing(
model_name: str, model_name: str,
provider: Optional[str] = None, provider: Optional[str] = None,
base_url: Optional[str] = None, base_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> bool: ) -> bool:
"""Check whether we have pricing data for this model+route. """Check whether we have pricing data for this model+route.
@ -545,7 +571,7 @@ def has_known_pricing(
route = resolve_billing_route(model_name, provider=provider, base_url=base_url) route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
if route.billing_mode == "subscription_included": if route.billing_mode == "subscription_included":
return True return True
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url) entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key)
return entry is not None return entry is not None
@ -553,13 +579,14 @@ def get_pricing(
model_name: str, model_name: str,
provider: Optional[str] = None, provider: Optional[str] = None,
base_url: Optional[str] = None, base_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> Dict[str, float]: ) -> Dict[str, float]:
"""Backward-compatible thin wrapper for legacy callers. """Backward-compatible thin wrapper for legacy callers.
Returns only non-cache input/output fields when a pricing entry exists. Returns only non-cache input/output fields when a pricing entry exists.
Unknown routes return zeroes. Unknown routes return zeroes.
""" """
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url) entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key)
if not entry: if not entry:
return {"input": 0.0, "output": 0.0} return {"input": 0.0, "output": 0.0}
return { return {
@ -575,6 +602,7 @@ def estimate_cost_usd(
*, *,
provider: Optional[str] = None, provider: Optional[str] = None,
base_url: Optional[str] = None, base_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> float: ) -> float:
"""Backward-compatible helper for legacy callers. """Backward-compatible helper for legacy callers.
@ -586,6 +614,7 @@ def estimate_cost_usd(
CanonicalUsage(input_tokens=input_tokens, output_tokens=output_tokens), CanonicalUsage(input_tokens=input_tokens, output_tokens=output_tokens),
provider=provider, provider=provider,
base_url=base_url, base_url=base_url,
api_key=api_key,
) )
return float(result.amount_usd or _ZERO) return float(result.amount_usd or _ZERO)

View file

@ -276,6 +276,7 @@ def get_tool_definitions(
# The registry still holds their schemas; dispatch just returns a stub error # The registry still holds their schemas; dispatch just returns a stub error
# so if something slips through, the LLM sees a sensible message. # so if something slips through, the LLM sees a sensible message.
_AGENT_LOOP_TOOLS = {"todo", "memory", "session_search", "delegate_task"} _AGENT_LOOP_TOOLS = {"todo", "memory", "session_search", "delegate_task"}
_READ_SEARCH_TOOLS = {"read_file", "search_files"}
def handle_function_call( def handle_function_call(
@ -305,7 +306,6 @@ def handle_function_call(
""" """
# Notify the read-loop tracker when a non-read/search tool runs, # Notify the read-loop tracker when a non-read/search tool runs,
# so the *consecutive* counter resets (reads after other work are fine). # so the *consecutive* counter resets (reads after other work are fine).
_READ_SEARCH_TOOLS = {"read_file", "search_files"}
if function_name not in _READ_SEARCH_TOOLS: if function_name not in _READ_SEARCH_TOOLS:
try: try:
from tools.file_tools import notify_other_tool_call from tools.file_tools import notify_other_tool_call

View file

@ -268,6 +268,15 @@ class AIAgent:
for AI models that support function calling. for AI models that support function calling.
""" """
@property
def base_url(self) -> str:
return self._base_url
@base_url.setter
def base_url(self, value: str) -> None:
self._base_url = value
self._base_url_lower = value.lower() if value else ""
def __init__( def __init__(
self, self,
base_url: str = None, base_url: str = None,
@ -383,10 +392,10 @@ class AIAgent:
self.api_mode = api_mode self.api_mode = api_mode
elif self.provider == "openai-codex": elif self.provider == "openai-codex":
self.api_mode = "codex_responses" self.api_mode = "codex_responses"
elif (provider_name is None) and "chatgpt.com/backend-api/codex" in self.base_url.lower(): elif (provider_name is None) and "chatgpt.com/backend-api/codex" in self._base_url_lower:
self.api_mode = "codex_responses" self.api_mode = "codex_responses"
self.provider = "openai-codex" self.provider = "openai-codex"
elif self.provider == "anthropic" or (provider_name is None and "api.anthropic.com" in self.base_url.lower()): elif self.provider == "anthropic" or (provider_name is None and "api.anthropic.com" in self._base_url_lower):
self.api_mode = "anthropic_messages" self.api_mode = "anthropic_messages"
self.provider = "anthropic" self.provider = "anthropic"
else: else:
@ -395,7 +404,7 @@ class AIAgent:
# Pre-warm OpenRouter model metadata cache in a background thread. # Pre-warm OpenRouter model metadata cache in a background thread.
# fetch_model_metadata() is cached for 1 hour; this avoids a blocking # fetch_model_metadata() is cached for 1 hour; this avoids a blocking
# HTTP request on the first API response when pricing is estimated. # HTTP request on the first API response when pricing is estimated.
if self.provider == "openrouter" or "openrouter" in self.base_url.lower(): if self.provider == "openrouter" or "openrouter" in self._base_url_lower:
threading.Thread( threading.Thread(
target=lambda: fetch_model_metadata(), target=lambda: fetch_model_metadata(),
daemon=True, daemon=True,
@ -439,7 +448,7 @@ class AIAgent:
# Anthropic prompt caching: auto-enabled for Claude models via OpenRouter. # Anthropic prompt caching: auto-enabled for Claude models via OpenRouter.
# Reduces input costs by ~75% on multi-turn conversations by caching the # Reduces input costs by ~75% on multi-turn conversations by caching the
# conversation prefix. Uses system_and_3 strategy (4 breakpoints). # conversation prefix. Uses system_and_3 strategy (4 breakpoints).
is_openrouter = "openrouter" in self.base_url.lower() is_openrouter = "openrouter" in self._base_url_lower
is_claude = "claude" in self.model.lower() is_claude = "claude" in self.model.lower()
is_native_anthropic = self.api_mode == "anthropic_messages" is_native_anthropic = self.api_mode == "anthropic_messages"
self._use_prompt_caching = (is_openrouter and is_claude) or is_native_anthropic self._use_prompt_caching = (is_openrouter and is_claude) or is_native_anthropic
@ -555,6 +564,7 @@ class AIAgent:
if self.api_mode == "anthropic_messages": if self.api_mode == "anthropic_messages":
from agent.anthropic_adapter import build_anthropic_client, resolve_anthropic_token from agent.anthropic_adapter import build_anthropic_client, resolve_anthropic_token
effective_key = api_key or resolve_anthropic_token() or "" effective_key = api_key or resolve_anthropic_token() or ""
self.api_key = effective_key
self._anthropic_api_key = effective_key self._anthropic_api_key = effective_key
self._anthropic_base_url = base_url self._anthropic_base_url = base_url
from agent.anthropic_adapter import _is_oauth_token as _is_oat from agent.anthropic_adapter import _is_oauth_token as _is_oat
@ -609,6 +619,7 @@ class AIAgent:
} }
self._client_kwargs = client_kwargs # stored for rebuilding after interrupt self._client_kwargs = client_kwargs # stored for rebuilding after interrupt
self.api_key = client_kwargs.get("api_key", "")
try: try:
self.client = self._create_openai_client(client_kwargs, reason="agent_init", shared=True) self.client = self._create_openai_client(client_kwargs, reason="agent_init", shared=True)
if not self.quiet_mode: if not self.quiet_mode:
@ -732,6 +743,13 @@ class AIAgent:
from tools.todo_tool import TodoStore from tools.todo_tool import TodoStore
self._todo_store = TodoStore() self._todo_store = TodoStore()
# Load config once for memory, skills, and compression sections
try:
from hermes_cli.config import load_config as _load_agent_config
_agent_cfg = _load_agent_config()
except Exception:
_agent_cfg = {}
# Persistent memory (MEMORY.md + USER.md) -- loaded from disk # Persistent memory (MEMORY.md + USER.md) -- loaded from disk
self._memory_store = None self._memory_store = None
self._memory_enabled = False self._memory_enabled = False
@ -742,8 +760,7 @@ class AIAgent:
self._iters_since_skill = 0 self._iters_since_skill = 0
if not skip_memory: if not skip_memory:
try: try:
from hermes_cli.config import load_config as _load_mem_config mem_config = _agent_cfg.get("memory", {})
mem_config = _load_mem_config().get("memory", {})
self._memory_enabled = mem_config.get("memory_enabled", False) self._memory_enabled = mem_config.get("memory_enabled", False)
self._user_profile_enabled = mem_config.get("user_profile_enabled", False) self._user_profile_enabled = mem_config.get("user_profile_enabled", False)
self._memory_nudge_interval = int(mem_config.get("nudge_interval", 10)) self._memory_nudge_interval = int(mem_config.get("nudge_interval", 10))
@ -831,8 +848,7 @@ class AIAgent:
# Skills config: nudge interval for skill creation reminders # Skills config: nudge interval for skill creation reminders
self._skill_nudge_interval = 10 self._skill_nudge_interval = 10
try: try:
from hermes_cli.config import load_config as _load_skills_config skills_config = _agent_cfg.get("skills", {})
skills_config = _load_skills_config().get("skills", {})
self._skill_nudge_interval = int(skills_config.get("creation_nudge_interval", 15)) self._skill_nudge_interval = int(skills_config.get("creation_nudge_interval", 15))
except Exception: except Exception:
pass pass
@ -840,12 +856,8 @@ class AIAgent:
# Initialize context compressor for automatic context management # Initialize context compressor for automatic context management
# Compresses conversation when approaching model's context limit # Compresses conversation when approaching model's context limit
# Configuration via config.yaml (compression section) # Configuration via config.yaml (compression section)
try: _compression_cfg = _agent_cfg.get("compression", {})
from hermes_cli.config import load_config as _load_compression_config if not isinstance(_compression_cfg, dict):
_compression_cfg = _load_compression_config().get("compression", {})
if not isinstance(_compression_cfg, dict):
_compression_cfg = {}
except ImportError:
_compression_cfg = {} _compression_cfg = {}
compression_threshold = float(_compression_cfg.get("threshold", 0.50)) compression_threshold = float(_compression_cfg.get("threshold", 0.50))
compression_enabled = str(_compression_cfg.get("enabled", True)).lower() in ("true", "1", "yes") compression_enabled = str(_compression_cfg.get("enabled", True)).lower() in ("true", "1", "yes")
@ -860,6 +872,7 @@ class AIAgent:
summary_model_override=compression_summary_model, summary_model_override=compression_summary_model,
quiet_mode=self.quiet_mode, quiet_mode=self.quiet_mode,
base_url=self.base_url, base_url=self.base_url,
api_key=getattr(self, "api_key", ""),
) )
self.compression_enabled = compression_enabled self.compression_enabled = compression_enabled
self._user_turn_count = 0 self._user_turn_count = 0
@ -915,8 +928,8 @@ class AIAgent:
OpenAI models use 'max_tokens'. OpenAI models use 'max_tokens'.
""" """
_is_direct_openai = ( _is_direct_openai = (
"api.openai.com" in self.base_url.lower() "api.openai.com" in self._base_url_lower
and "openrouter" not in self.base_url.lower() and "openrouter" not in self._base_url_lower
) )
if _is_direct_openai: if _is_direct_openai:
return {"max_completion_tokens": value} return {"max_completion_tokens": value}
@ -3643,7 +3656,7 @@ class AIAgent:
extra_body = {} extra_body = {}
_is_openrouter = "openrouter" in self.base_url.lower() _is_openrouter = "openrouter" in self._base_url_lower
# Provider preferences (only, ignore, order, sort) are OpenRouter- # Provider preferences (only, ignore, order, sort) are OpenRouter-
# specific. Only send to OpenRouter-compatible endpoints. # specific. Only send to OpenRouter-compatible endpoints.
@ -3651,7 +3664,7 @@ class AIAgent:
# for _is_nous when their backend is updated. # for _is_nous when their backend is updated.
if provider_preferences and _is_openrouter: if provider_preferences and _is_openrouter:
extra_body["provider"] = provider_preferences extra_body["provider"] = provider_preferences
_is_nous = "nousresearch" in self.base_url.lower() _is_nous = "nousresearch" in self._base_url_lower
if self._supports_reasoning_extra_body(): if self._supports_reasoning_extra_body():
if self.reasoning_config is not None: if self.reasoning_config is not None:
@ -3684,14 +3697,13 @@ class AIAgent:
Some providers/routes reject `reasoning` with 400s, so gate it to Some providers/routes reject `reasoning` with 400s, so gate it to
known reasoning-capable model families and direct Nous Portal. known reasoning-capable model families and direct Nous Portal.
""" """
base_url = (self.base_url or "").lower() if "nousresearch" in self._base_url_lower:
if "nousresearch" in base_url:
return True return True
if "ai-gateway.vercel.sh" in base_url: if "ai-gateway.vercel.sh" in self._base_url_lower:
return True return True
if "openrouter" not in base_url: if "openrouter" not in self._base_url_lower:
return False return False
if "api.mistral.ai" in base_url: if "api.mistral.ai" in self._base_url_lower:
return False return False
model = (self.model or "").lower() model = (self.model or "").lower()
@ -3877,7 +3889,7 @@ class AIAgent:
try: try:
# Build API messages for the flush call # Build API messages for the flush call
_is_strict_api = "api.mistral.ai" in self.base_url.lower() _is_strict_api = "api.mistral.ai" in self._base_url_lower
api_messages = [] api_messages = []
for msg in messages: for msg in messages:
api_msg = msg.copy() api_msg = msg.copy()
@ -4653,7 +4665,7 @@ class AIAgent:
try: try:
# Build API messages, stripping internal-only fields # Build API messages, stripping internal-only fields
# (finish_reason, reasoning) that strict APIs like Mistral reject with 422 # (finish_reason, reasoning) that strict APIs like Mistral reject with 422
_is_strict_api = "api.mistral.ai" in self.base_url.lower() _is_strict_api = "api.mistral.ai" in self._base_url_lower
api_messages = [] api_messages = []
for msg in messages: for msg in messages:
api_msg = msg.copy() api_msg = msg.copy()
@ -4674,7 +4686,7 @@ class AIAgent:
api_messages.insert(sys_offset + idx, pfm.copy()) api_messages.insert(sys_offset + idx, pfm.copy())
summary_extra_body = {} summary_extra_body = {}
_is_nous = "nousresearch" in self.base_url.lower() _is_nous = "nousresearch" in self._base_url_lower
if self._supports_reasoning_extra_body(): if self._supports_reasoning_extra_body():
if self.reasoning_config is not None: if self.reasoning_config is not None:
summary_extra_body["reasoning"] = self.reasoning_config summary_extra_body["reasoning"] = self.reasoning_config
@ -5092,7 +5104,7 @@ class AIAgent:
# strict providers like Mistral that reject unknown fields with 422. # strict providers like Mistral that reject unknown fields with 422.
# Uses new dicts so the internal messages list retains the fields # Uses new dicts so the internal messages list retains the fields
# for Codex Responses compatibility. # for Codex Responses compatibility.
if "api.mistral.ai" in self.base_url.lower(): if "api.mistral.ai" in self._base_url_lower:
self._sanitize_tool_calls_for_strict_api(api_msg) self._sanitize_tool_calls_for_strict_api(api_msg)
# Keep 'reasoning_details' - OpenRouter uses this for multi-turn reasoning context # Keep 'reasoning_details' - OpenRouter uses this for multi-turn reasoning context
# The signature field helps maintain reasoning continuity # The signature field helps maintain reasoning continuity
@ -5464,6 +5476,7 @@ class AIAgent:
canonical_usage, canonical_usage,
provider=self.provider, provider=self.provider,
base_url=self.base_url, base_url=self.base_url,
api_key=getattr(self, "api_key", ""),
) )
if cost_result.amount_usd is not None: if cost_result.amount_usd is not None:
self.session_estimated_cost_usd += float(cost_result.amount_usd) self.session_estimated_cost_usd += float(cost_result.amount_usd)

View file

@ -188,6 +188,36 @@ class TestGetModelContextLength:
result = get_model_context_length("custom/model") result = get_model_context_length("custom/model")
assert result == CONTEXT_PROBE_TIERS[0] assert result == CONTEXT_PROBE_TIERS[0]
@patch("agent.model_metadata.fetch_model_metadata")
@patch("agent.model_metadata.fetch_endpoint_model_metadata")
def test_custom_endpoint_metadata_beats_fuzzy_default(self, mock_endpoint_fetch, mock_fetch):
mock_fetch.return_value = {}
mock_endpoint_fetch.return_value = {
"zai-org/GLM-5-TEE": {"context_length": 65536}
}
result = get_model_context_length(
"zai-org/GLM-5-TEE",
base_url="https://llm.chutes.ai/v1",
api_key="test-key",
)
assert result == 65536
@patch("agent.model_metadata.fetch_model_metadata")
@patch("agent.model_metadata.fetch_endpoint_model_metadata")
def test_custom_endpoint_without_metadata_skips_name_based_default(self, mock_endpoint_fetch, mock_fetch):
mock_fetch.return_value = {}
mock_endpoint_fetch.return_value = {}
result = get_model_context_length(
"zai-org/GLM-5-TEE",
base_url="https://llm.chutes.ai/v1",
api_key="test-key",
)
assert result == CONTEXT_PROBE_TIERS[0]
# ========================================================================= # =========================================================================
# fetch_model_metadata — caching, TTL, slugs, failures # fetch_model_metadata — caching, TTL, slugs, failures
@ -258,6 +288,25 @@ class TestFetchModelMetadata:
assert "anthropic/claude-3.5-sonnet" in result assert "anthropic/claude-3.5-sonnet" in result
assert result["anthropic/claude-3.5-sonnet"]["context_length"] == 200000 assert result["anthropic/claude-3.5-sonnet"]["context_length"] == 200000
@patch("agent.model_metadata.requests.get")
def test_provider_prefixed_models_get_bare_aliases(self, mock_get):
self._reset_cache()
mock_response = MagicMock()
mock_response.json.return_value = {
"data": [{
"id": "provider/test-model",
"context_length": 123456,
"name": "Provider: Test Model",
}]
}
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
result = fetch_model_metadata(force_refresh=True)
assert result["provider/test-model"]["context_length"] == 123456
assert result["test-model"]["context_length"] == 123456
@patch("agent.model_metadata.requests.get") @patch("agent.model_metadata.requests.get")
def test_ttl_expiry_triggers_refetch(self, mock_get): def test_ttl_expiry_triggers_refetch(self, mock_get):
"""Cache expires after _MODEL_CACHE_TTL seconds.""" """Cache expires after _MODEL_CACHE_TTL seconds."""

View file

@ -99,3 +99,27 @@ def test_estimate_usage_cost_refuses_cache_pricing_without_official_cache_rate(m
) )
assert result.status == "unknown" assert result.status == "unknown"
def test_custom_endpoint_models_api_pricing_is_supported(monkeypatch):
monkeypatch.setattr(
"agent.usage_pricing.fetch_endpoint_model_metadata",
lambda base_url, api_key=None: {
"zai-org/GLM-5-TEE": {
"pricing": {
"prompt": "0.0000005",
"completion": "0.000002",
}
}
},
)
entry = get_pricing_entry(
"zai-org/GLM-5-TEE",
provider="custom",
base_url="https://llm.chutes.ai/v1",
api_key="test-key",
)
assert float(entry.input_cost_per_million) == 0.5
assert float(entry.output_cost_per_million) == 2.0