feat: use endpoint metadata for custom model context and pricing (#1906)
* perf: cache base_url.lower() via property, consolidate triple load_config(), hoist set constant

  run_agent.py:
  - Add base_url property that auto-caches _base_url_lower on every assignment, eliminating 12+ redundant .lower() calls per API cycle across __init__, _build_api_kwargs, _supports_reasoning_extra_body, and the main conversation loop
  - Consolidate three separate load_config() disk reads in __init__ (memory, skills, compression) into a single call, reusing the result dict for all three config sections

  model_tools.py:
  - Hoist _READ_SEARCH_TOOLS set to module level (was rebuilt inside handle_function_call on every tool invocation)

* Use endpoint metadata for custom model context and pricing

---------

Co-authored-by: kshitij <82637225+kshitijk4poor@users.noreply.github.com>
This commit is contained in:
parent
11f029c311
commit
a2440f72f6
7 changed files with 375 additions and 49 deletions
|
|
@ -45,16 +45,18 @@ class ContextCompressor:
|
||||||
quiet_mode: bool = False,
|
quiet_mode: bool = False,
|
||||||
summary_model_override: str = None,
|
summary_model_override: str = None,
|
||||||
base_url: str = "",
|
base_url: str = "",
|
||||||
|
api_key: str = "",
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
|
self.api_key = api_key
|
||||||
self.threshold_percent = threshold_percent
|
self.threshold_percent = threshold_percent
|
||||||
self.protect_first_n = protect_first_n
|
self.protect_first_n = protect_first_n
|
||||||
self.protect_last_n = protect_last_n
|
self.protect_last_n = protect_last_n
|
||||||
self.summary_target_tokens = summary_target_tokens
|
self.summary_target_tokens = summary_target_tokens
|
||||||
self.quiet_mode = quiet_mode
|
self.quiet_mode = quiet_mode
|
||||||
|
|
||||||
self.context_length = get_model_context_length(model, base_url=base_url)
|
self.context_length = get_model_context_length(model, base_url=base_url, api_key=api_key)
|
||||||
self.threshold_tokens = int(self.context_length * threshold_percent)
|
self.threshold_tokens = int(self.context_length * threshold_percent)
|
||||||
self.compression_count = 0
|
self.compression_count = 0
|
||||||
self._context_probed = False # True after a step-down from context error
|
self._context_probed = False # True after a step-down from context error
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ import re
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import yaml
|
import yaml
|
||||||
|
|
@ -21,6 +22,9 @@ logger = logging.getLogger(__name__)
|
||||||
_model_metadata_cache: Dict[str, Dict[str, Any]] = {}
|
_model_metadata_cache: Dict[str, Dict[str, Any]] = {}
|
||||||
_model_metadata_cache_time: float = 0
|
_model_metadata_cache_time: float = 0
|
||||||
_MODEL_CACHE_TTL = 3600
|
_MODEL_CACHE_TTL = 3600
|
||||||
|
_endpoint_model_metadata_cache: Dict[str, Dict[str, Dict[str, Any]]] = {}
|
||||||
|
_endpoint_model_metadata_cache_time: Dict[str, float] = {}
|
||||||
|
_ENDPOINT_MODEL_CACHE_TTL = 300
|
||||||
|
|
||||||
# Descending tiers for context length probing when the model is unknown.
|
# Descending tiers for context length probing when the model is unknown.
|
||||||
# We start high and step down on context-length errors until one works.
|
# We start high and step down on context-length errors until one works.
|
||||||
|
|
@ -123,6 +127,128 @@ DEFAULT_CONTEXT_LENGTHS = {
|
||||||
"qwen-vl-max": 32768,
|
"qwen-vl-max": 32768,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Field names that _extract_context_length() searches for (case-insensitively,
# at any nesting depth) when reading a model entry from an endpoint's /models
# response.
_CONTEXT_LENGTH_KEYS = (
    "context_length",
    "context_window",
    "max_context_length",
    "max_position_embeddings",
    "max_model_len",
    "max_input_tokens",
    "max_sequence_length",
    "max_seq_len",
)

# Field names that _extract_max_completion_tokens() searches for in the same
# way to find a model's output-token limit.
_MAX_COMPLETION_KEYS = (
    "max_completion_tokens",
    "max_output_tokens",
    "max_tokens",
)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_base_url(base_url: str) -> str:
|
||||||
|
return (base_url or "").strip().rstrip("/")
|
||||||
|
|
||||||
|
|
||||||
|
def _is_openrouter_base_url(base_url: str) -> bool:
    """Report whether the normalized URL contains ``openrouter.ai`` (case-insensitive)."""
    lowered = _normalize_base_url(base_url).lower()
    return "openrouter.ai" in lowered
|
||||||
|
|
||||||
|
|
||||||
|
def _is_custom_endpoint(base_url: str) -> bool:
    """Report whether ``base_url`` names an explicit, non-OpenRouter endpoint.

    An empty or missing URL is not a custom endpoint.
    """
    cleaned = _normalize_base_url(base_url)
    if not cleaned:
        return False
    return not _is_openrouter_base_url(cleaned)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_known_provider_base_url(base_url: str) -> bool:
    """Report whether ``base_url`` points at one of the recognised hosted providers.

    Matching is done on the URL's hostname at DNS-label boundaries, so a
    look-alike such as ``api.openai.com.evil.example`` or ``notapi.z.ai`` is
    not mistaken for the real provider. Userinfo and port are ignored.

    Args:
        base_url: Endpoint URL, with or without a scheme.

    Returns:
        ``True`` when the host is a known provider host (or a subdomain of
        one); ``False`` for empty, malformed, or unrecognised URLs.
    """
    normalized = _normalize_base_url(base_url)
    if not normalized:
        return False

    try:
        # A scheme is required for urlparse() to populate the host part.
        parsed = urlparse(normalized if "://" in normalized else f"https://{normalized}")
        # Fall back to the first path segment when no host could be parsed.
        host = (parsed.hostname or parsed.path.split("/", 1)[0]).lower()
    except ValueError:
        # urlparse() rejects e.g. unbalanced IPv6 brackets; treat as unknown.
        return False
    if not host:
        return False

    # Hosts matched exactly or as a parent domain of ``host``.
    known_hosts = (
        "api.openai.com",
        "chatgpt.com",
        "api.anthropic.com",
        "api.z.ai",
        "api.moonshot.ai",
        "api.kimi.com",
    )
    # Host stems whose registrable domain varies (e.g. api.minimax.<tld>).
    known_prefixes = (
        "api.minimax",
    )

    if any(host == known or host.endswith("." + known) for known in known_hosts):
        return True
    return any(
        host.startswith(prefix) or ("." + prefix) in host
        for prefix in known_prefixes
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _iter_nested_dicts(value: Any):
|
||||||
|
if isinstance(value, dict):
|
||||||
|
yield value
|
||||||
|
for nested in value.values():
|
||||||
|
yield from _iter_nested_dicts(nested)
|
||||||
|
elif isinstance(value, list):
|
||||||
|
for item in value:
|
||||||
|
yield from _iter_nested_dicts(item)
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_reasonable_int(value: Any, minimum: int = 1024, maximum: int = 10_000_000) -> Optional[int]:
|
||||||
|
try:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return None
|
||||||
|
if isinstance(value, str):
|
||||||
|
value = value.strip().replace(",", "")
|
||||||
|
result = int(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
if minimum <= result <= maximum:
|
||||||
|
return result
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_first_int(payload: Dict[str, Any], keys: tuple[str, ...]) -> Optional[int]:
    """Return the first plausible integer stored under any of ``keys`` in ``payload``.

    Every nested dict is searched in traversal order. Key names are compared
    case-insensitively, and values that do not coerce to a reasonable integer
    are skipped. Returns ``None`` when nothing qualifies.
    """
    wanted = {name.lower() for name in keys}
    for nested in _iter_nested_dicts(payload):
        for raw_key, raw_value in nested.items():
            if str(raw_key).lower() in wanted:
                candidate = _coerce_reasonable_int(raw_value)
                if candidate is not None:
                    return candidate
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_context_length(payload: Dict[str, Any]) -> Optional[int]:
    """Return the context-window size found under any ``_CONTEXT_LENGTH_KEYS`` name, if any."""
    context_length = _extract_first_int(payload, _CONTEXT_LENGTH_KEYS)
    return context_length
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_max_completion_tokens(payload: Dict[str, Any]) -> Optional[int]:
    """Return the output-token limit found under any ``_MAX_COMPLETION_KEYS`` name, if any."""
    max_completion = _extract_first_int(payload, _MAX_COMPLETION_KEYS)
    return max_completion
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_pricing(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Return pricing fields from the first nested dict that carries any.

    Each nested dict is inspected with its keys lowercased. For every
    canonical field (``prompt``, ``completion``, ``request``, ``cache_read``,
    ``cache_write``) the first alias holding a value other than ``None`` or
    ``""`` is taken. The first dict that yields at least one field wins; an
    empty dict is returned when none does. Values are passed through as-is.
    """
    alias_map = {
        "prompt": ("prompt", "input", "input_cost_per_token", "prompt_token_cost"),
        "completion": ("completion", "output", "output_cost_per_token", "completion_token_cost"),
        "request": ("request", "request_cost"),
        "cache_read": ("cache_read", "cached_prompt", "input_cache_read", "cache_read_cost_per_token"),
        "cache_write": ("cache_write", "cache_creation", "input_cache_write", "cache_write_cost_per_token"),
    }
    for nested in _iter_nested_dicts(payload):
        lowered = {str(name).lower(): raw for name, raw in nested.items()}
        found: Dict[str, Any] = {}
        for field, candidates in alias_map.items():
            # None is a safe "no match" marker: None values are never accepted.
            hit = next(
                (
                    lowered[alias]
                    for alias in candidates
                    if alias in lowered and lowered[alias] not in (None, "")
                ),
                None,
            )
            if hit is not None:
                found[field] = hit
        if found:
            return found
    return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _add_model_aliases(cache: Dict[str, Dict[str, Any]], model_id: str, entry: Dict[str, Any]) -> None:
|
||||||
|
cache[model_id] = entry
|
||||||
|
if "/" in model_id:
|
||||||
|
bare_model = model_id.split("/", 1)[1]
|
||||||
|
cache.setdefault(bare_model, entry)
|
||||||
|
|
||||||
|
|
||||||
def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any]]:
|
def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any]]:
|
||||||
"""Fetch model metadata from OpenRouter (cached for 1 hour)."""
|
"""Fetch model metadata from OpenRouter (cached for 1 hour)."""
|
||||||
|
|
@ -139,15 +265,16 @@ def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any
|
||||||
cache = {}
|
cache = {}
|
||||||
for model in data.get("data", []):
|
for model in data.get("data", []):
|
||||||
model_id = model.get("id", "")
|
model_id = model.get("id", "")
|
||||||
cache[model_id] = {
|
entry = {
|
||||||
"context_length": model.get("context_length", 128000),
|
"context_length": model.get("context_length", 128000),
|
||||||
"max_completion_tokens": model.get("top_provider", {}).get("max_completion_tokens", 4096),
|
"max_completion_tokens": model.get("top_provider", {}).get("max_completion_tokens", 4096),
|
||||||
"name": model.get("name", model_id),
|
"name": model.get("name", model_id),
|
||||||
"pricing": model.get("pricing", {}),
|
"pricing": model.get("pricing", {}),
|
||||||
}
|
}
|
||||||
|
_add_model_aliases(cache, model_id, entry)
|
||||||
canonical = model.get("canonical_slug", "")
|
canonical = model.get("canonical_slug", "")
|
||||||
if canonical and canonical != model_id:
|
if canonical and canonical != model_id:
|
||||||
cache[canonical] = cache[model_id]
|
_add_model_aliases(cache, canonical, entry)
|
||||||
|
|
||||||
_model_metadata_cache = cache
|
_model_metadata_cache = cache
|
||||||
_model_metadata_cache_time = time.time()
|
_model_metadata_cache_time = time.time()
|
||||||
|
|
@ -159,6 +286,75 @@ def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any
|
||||||
return _model_metadata_cache or {}
|
return _model_metadata_cache or {}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_endpoint_models(payload: Any) -> Dict[str, Dict[str, Any]]:
    """Build the per-model metadata mapping from a decoded ``/models`` body."""
    # The usual shape is {"data": [...]}; a top-level JSON array is accepted
    # as well instead of being treated as a failed fetch.
    models = payload if isinstance(payload, list) else payload.get("data", [])
    parsed: Dict[str, Dict[str, Any]] = {}
    for model in models:
        if not isinstance(model, dict):
            continue
        model_id = model.get("id")
        # Skip a malformed id rather than letting it abort the whole listing.
        if not model_id or not isinstance(model_id, str):
            continue
        entry: Dict[str, Any] = {"name": model.get("name", model_id)}
        context_length = _extract_context_length(model)
        if context_length is not None:
            entry["context_length"] = context_length
        max_completion_tokens = _extract_max_completion_tokens(model)
        if max_completion_tokens is not None:
            entry["max_completion_tokens"] = max_completion_tokens
        pricing = _extract_pricing(model)
        if pricing:
            entry["pricing"] = pricing
        _add_model_aliases(parsed, model_id, entry)
    return parsed


def fetch_endpoint_model_metadata(
    base_url: str,
    api_key: str = "",
    force_refresh: bool = False,
) -> Dict[str, Dict[str, Any]]:
    """Fetch model metadata from an OpenAI-compatible ``/models`` endpoint.

    This is used for explicit custom endpoints where hardcoded global model-name
    defaults are unreliable. Results are cached in memory per base URL.

    Args:
        base_url: Endpoint base URL. Empty and OpenRouter URLs return ``{}``
            without any request being made.
        api_key: Sent as a ``Bearer`` token when non-empty.
        force_refresh: Skip the in-memory cache and query the endpoint again.

    Returns:
        Mapping of model id (and its provider-less alias) to a dict that has
        ``name`` and, when the endpoint reports them, ``context_length``,
        ``max_completion_tokens`` and ``pricing``. An empty dict is returned
        when no candidate URL yields a usable response; that empty result is
        cached too, so a failing endpoint is not re-queried until the TTL
        expires.
    """
    normalized = _normalize_base_url(base_url)
    if not normalized or _is_openrouter_base_url(normalized):
        return {}

    if not force_refresh:
        cached = _endpoint_model_metadata_cache.get(normalized)
        cached_at = _endpoint_model_metadata_cache_time.get(normalized, 0)
        if cached is not None and (time.time() - cached_at) < _ENDPOINT_MODEL_CACHE_TTL:
            return cached

    # Try the URL as given, then the same URL with "/v1" toggled.
    candidates = [normalized]
    if normalized.endswith("/v1"):
        alternate = normalized[:-3].rstrip("/")
    else:
        alternate = normalized + "/v1"
    if alternate and alternate not in candidates:
        candidates.append(alternate)

    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    last_error: Optional[Exception] = None

    for candidate in candidates:
        url = candidate.rstrip("/") + "/models"
        try:
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()
            cache = _parse_endpoint_models(response.json())
        except Exception as exc:
            # Deliberately broad: metadata lookup is best-effort, so any
            # failure just moves on to the next candidate.
            last_error = exc
            continue
        _endpoint_model_metadata_cache[normalized] = cache
        _endpoint_model_metadata_cache_time[normalized] = time.time()
        return cache

    if last_error is not None:
        logger.debug("Failed to fetch model metadata from %s/models: %s", normalized, last_error)
    # Cache the failure as an empty result so it is not retried on every call.
    _endpoint_model_metadata_cache[normalized] = {}
    _endpoint_model_metadata_cache_time[normalized] = time.time()
    return {}
|
||||||
|
|
||||||
|
|
||||||
def _get_context_cache_path() -> Path:
|
def _get_context_cache_path() -> Path:
|
||||||
"""Return path to the persistent context length cache file."""
|
"""Return path to the persistent context length cache file."""
|
||||||
hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||||
|
|
@ -243,14 +439,15 @@ def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_model_context_length(model: str, base_url: str = "") -> int:
|
def get_model_context_length(model: str, base_url: str = "", api_key: str = "") -> int:
|
||||||
"""Get the context length for a model.
|
"""Get the context length for a model.
|
||||||
|
|
||||||
Resolution order:
|
Resolution order:
|
||||||
1. Persistent cache (previously discovered via probing)
|
1. Persistent cache (previously discovered via probing)
|
||||||
2. OpenRouter API metadata
|
2. Active endpoint metadata (/models for explicit custom endpoints)
|
||||||
3. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match)
|
3. OpenRouter API metadata
|
||||||
4. First probe tier (2M) — will be narrowed on first context error
|
4. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match for hosted routes only)
|
||||||
|
5. First probe tier (2M) — will be narrowed on first context error
|
||||||
"""
|
"""
|
||||||
# 1. Check persistent cache (model+provider)
|
# 1. Check persistent cache (model+provider)
|
||||||
if base_url:
|
if base_url:
|
||||||
|
|
@ -258,19 +455,31 @@ def get_model_context_length(model: str, base_url: str = "") -> int:
|
||||||
if cached is not None:
|
if cached is not None:
|
||||||
return cached
|
return cached
|
||||||
|
|
||||||
# 2. OpenRouter API metadata
|
# 2. Active endpoint metadata for explicit custom routes
|
||||||
|
if _is_custom_endpoint(base_url):
|
||||||
|
endpoint_metadata = fetch_endpoint_model_metadata(base_url, api_key=api_key)
|
||||||
|
if model in endpoint_metadata:
|
||||||
|
context_length = endpoint_metadata[model].get("context_length")
|
||||||
|
if isinstance(context_length, int):
|
||||||
|
return context_length
|
||||||
|
if not _is_known_provider_base_url(base_url):
|
||||||
|
# Explicit third-party endpoints should not borrow fuzzy global
|
||||||
|
# defaults from unrelated providers with similarly named models.
|
||||||
|
return CONTEXT_PROBE_TIERS[0]
|
||||||
|
|
||||||
|
# 3. OpenRouter API metadata
|
||||||
metadata = fetch_model_metadata()
|
metadata = fetch_model_metadata()
|
||||||
if model in metadata:
|
if model in metadata:
|
||||||
return metadata[model].get("context_length", 128000)
|
return metadata[model].get("context_length", 128000)
|
||||||
|
|
||||||
# 3. Hardcoded defaults (fuzzy match — longest key first for specificity)
|
# 4. Hardcoded defaults (fuzzy match — longest key first for specificity)
|
||||||
for default_model, length in sorted(
|
for default_model, length in sorted(
|
||||||
DEFAULT_CONTEXT_LENGTHS.items(), key=lambda x: len(x[0]), reverse=True
|
DEFAULT_CONTEXT_LENGTHS.items(), key=lambda x: len(x[0]), reverse=True
|
||||||
):
|
):
|
||||||
if default_model in model or model in default_model:
|
if default_model in model or model in default_model:
|
||||||
return length
|
return length
|
||||||
|
|
||||||
# 4. Unknown model — start at highest probe tier
|
# 5. Unknown model — start at highest probe tier
|
||||||
return CONTEXT_PROBE_TIERS[0]
|
return CONTEXT_PROBE_TIERS[0]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from datetime import datetime, timezone
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import Any, Dict, Literal, Optional
|
from typing import Any, Dict, Literal, Optional
|
||||||
|
|
||||||
from agent.model_metadata import fetch_model_metadata
|
from agent.model_metadata import fetch_endpoint_model_metadata, fetch_model_metadata
|
||||||
|
|
||||||
DEFAULT_PRICING = {"input": 0.0, "output": 0.0}
|
DEFAULT_PRICING = {"input": 0.0, "output": 0.0}
|
||||||
|
|
||||||
|
|
@ -335,8 +335,21 @@ def _lookup_official_docs_pricing(route: BillingRoute) -> Optional[PricingEntry]
|
||||||
|
|
||||||
|
|
||||||
def _openrouter_pricing_entry(route: BillingRoute) -> Optional[PricingEntry]:
|
def _openrouter_pricing_entry(route: BillingRoute) -> Optional[PricingEntry]:
|
||||||
metadata = fetch_model_metadata()
|
return _pricing_entry_from_metadata(
|
||||||
model_id = route.model
|
fetch_model_metadata(),
|
||||||
|
route.model,
|
||||||
|
source_url="https://openrouter.ai/docs/api/api-reference/models/get-models",
|
||||||
|
pricing_version="openrouter-models-api",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _pricing_entry_from_metadata(
|
||||||
|
metadata: Dict[str, Dict[str, Any]],
|
||||||
|
model_id: str,
|
||||||
|
*,
|
||||||
|
source_url: str,
|
||||||
|
pricing_version: str,
|
||||||
|
) -> Optional[PricingEntry]:
|
||||||
if model_id not in metadata:
|
if model_id not in metadata:
|
||||||
return None
|
return None
|
||||||
pricing = metadata[model_id].get("pricing") or {}
|
pricing = metadata[model_id].get("pricing") or {}
|
||||||
|
|
@ -355,6 +368,7 @@ def _openrouter_pricing_entry(route: BillingRoute) -> Optional[PricingEntry]:
|
||||||
)
|
)
|
||||||
if prompt is None and completion is None and request is None:
|
if prompt is None and completion is None and request is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _per_token_to_per_million(value: Optional[Decimal]) -> Optional[Decimal]:
|
def _per_token_to_per_million(value: Optional[Decimal]) -> Optional[Decimal]:
|
||||||
if value is None:
|
if value is None:
|
||||||
return None
|
return None
|
||||||
|
|
@ -367,8 +381,8 @@ def _openrouter_pricing_entry(route: BillingRoute) -> Optional[PricingEntry]:
|
||||||
cache_write_cost_per_million=_per_token_to_per_million(cache_write),
|
cache_write_cost_per_million=_per_token_to_per_million(cache_write),
|
||||||
request_cost=request,
|
request_cost=request,
|
||||||
source="provider_models_api",
|
source="provider_models_api",
|
||||||
source_url="https://openrouter.ai/docs/api/api-reference/models/get-models",
|
source_url=source_url,
|
||||||
pricing_version="openrouter-models-api",
|
pricing_version=pricing_version,
|
||||||
fetched_at=_UTC_NOW(),
|
fetched_at=_UTC_NOW(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -377,6 +391,7 @@ def get_pricing_entry(
|
||||||
model_name: str,
|
model_name: str,
|
||||||
provider: Optional[str] = None,
|
provider: Optional[str] = None,
|
||||||
base_url: Optional[str] = None,
|
base_url: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
) -> Optional[PricingEntry]:
|
) -> Optional[PricingEntry]:
|
||||||
route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
|
route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
|
||||||
if route.billing_mode == "subscription_included":
|
if route.billing_mode == "subscription_included":
|
||||||
|
|
@ -390,6 +405,15 @@ def get_pricing_entry(
|
||||||
)
|
)
|
||||||
if route.provider == "openrouter":
|
if route.provider == "openrouter":
|
||||||
return _openrouter_pricing_entry(route)
|
return _openrouter_pricing_entry(route)
|
||||||
|
if route.base_url:
|
||||||
|
entry = _pricing_entry_from_metadata(
|
||||||
|
fetch_endpoint_model_metadata(route.base_url, api_key=api_key or ""),
|
||||||
|
route.model,
|
||||||
|
source_url=f"{route.base_url.rstrip('/')}/models",
|
||||||
|
pricing_version="openai-compatible-models-api",
|
||||||
|
)
|
||||||
|
if entry:
|
||||||
|
return entry
|
||||||
return _lookup_official_docs_pricing(route)
|
return _lookup_official_docs_pricing(route)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -460,6 +484,7 @@ def estimate_usage_cost(
|
||||||
*,
|
*,
|
||||||
provider: Optional[str] = None,
|
provider: Optional[str] = None,
|
||||||
base_url: Optional[str] = None,
|
base_url: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
) -> CostResult:
|
) -> CostResult:
|
||||||
route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
|
route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
|
||||||
if route.billing_mode == "subscription_included":
|
if route.billing_mode == "subscription_included":
|
||||||
|
|
@ -471,7 +496,7 @@ def estimate_usage_cost(
|
||||||
pricing_version="included-route",
|
pricing_version="included-route",
|
||||||
)
|
)
|
||||||
|
|
||||||
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url)
|
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key)
|
||||||
if not entry:
|
if not entry:
|
||||||
return CostResult(amount_usd=None, status="unknown", source="none", label="n/a")
|
return CostResult(amount_usd=None, status="unknown", source="none", label="n/a")
|
||||||
|
|
||||||
|
|
@ -536,6 +561,7 @@ def has_known_pricing(
|
||||||
model_name: str,
|
model_name: str,
|
||||||
provider: Optional[str] = None,
|
provider: Optional[str] = None,
|
||||||
base_url: Optional[str] = None,
|
base_url: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check whether we have pricing data for this model+route.
|
"""Check whether we have pricing data for this model+route.
|
||||||
|
|
||||||
|
|
@ -545,7 +571,7 @@ def has_known_pricing(
|
||||||
route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
|
route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
|
||||||
if route.billing_mode == "subscription_included":
|
if route.billing_mode == "subscription_included":
|
||||||
return True
|
return True
|
||||||
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url)
|
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key)
|
||||||
return entry is not None
|
return entry is not None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -553,13 +579,14 @@ def get_pricing(
|
||||||
model_name: str,
|
model_name: str,
|
||||||
provider: Optional[str] = None,
|
provider: Optional[str] = None,
|
||||||
base_url: Optional[str] = None,
|
base_url: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
) -> Dict[str, float]:
|
) -> Dict[str, float]:
|
||||||
"""Backward-compatible thin wrapper for legacy callers.
|
"""Backward-compatible thin wrapper for legacy callers.
|
||||||
|
|
||||||
Returns only non-cache input/output fields when a pricing entry exists.
|
Returns only non-cache input/output fields when a pricing entry exists.
|
||||||
Unknown routes return zeroes.
|
Unknown routes return zeroes.
|
||||||
"""
|
"""
|
||||||
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url)
|
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key)
|
||||||
if not entry:
|
if not entry:
|
||||||
return {"input": 0.0, "output": 0.0}
|
return {"input": 0.0, "output": 0.0}
|
||||||
return {
|
return {
|
||||||
|
|
@ -575,6 +602,7 @@ def estimate_cost_usd(
|
||||||
*,
|
*,
|
||||||
provider: Optional[str] = None,
|
provider: Optional[str] = None,
|
||||||
base_url: Optional[str] = None,
|
base_url: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
) -> float:
|
) -> float:
|
||||||
"""Backward-compatible helper for legacy callers.
|
"""Backward-compatible helper for legacy callers.
|
||||||
|
|
||||||
|
|
@ -586,6 +614,7 @@ def estimate_cost_usd(
|
||||||
CanonicalUsage(input_tokens=input_tokens, output_tokens=output_tokens),
|
CanonicalUsage(input_tokens=input_tokens, output_tokens=output_tokens),
|
||||||
provider=provider,
|
provider=provider,
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
|
api_key=api_key,
|
||||||
)
|
)
|
||||||
return float(result.amount_usd or _ZERO)
|
return float(result.amount_usd or _ZERO)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -276,6 +276,7 @@ def get_tool_definitions(
|
||||||
# The registry still holds their schemas; dispatch just returns a stub error
|
# The registry still holds their schemas; dispatch just returns a stub error
|
||||||
# so if something slips through, the LLM sees a sensible message.
|
# so if something slips through, the LLM sees a sensible message.
|
||||||
_AGENT_LOOP_TOOLS = {"todo", "memory", "session_search", "delegate_task"}
|
_AGENT_LOOP_TOOLS = {"todo", "memory", "session_search", "delegate_task"}
|
||||||
|
_READ_SEARCH_TOOLS = {"read_file", "search_files"}
|
||||||
|
|
||||||
|
|
||||||
def handle_function_call(
|
def handle_function_call(
|
||||||
|
|
@ -305,7 +306,6 @@ def handle_function_call(
|
||||||
"""
|
"""
|
||||||
# Notify the read-loop tracker when a non-read/search tool runs,
|
# Notify the read-loop tracker when a non-read/search tool runs,
|
||||||
# so the *consecutive* counter resets (reads after other work are fine).
|
# so the *consecutive* counter resets (reads after other work are fine).
|
||||||
_READ_SEARCH_TOOLS = {"read_file", "search_files"}
|
|
||||||
if function_name not in _READ_SEARCH_TOOLS:
|
if function_name not in _READ_SEARCH_TOOLS:
|
||||||
try:
|
try:
|
||||||
from tools.file_tools import notify_other_tool_call
|
from tools.file_tools import notify_other_tool_call
|
||||||
|
|
|
||||||
73
run_agent.py
73
run_agent.py
|
|
@ -263,11 +263,20 @@ def _inject_honcho_turn_context(content, turn_context: str):
|
||||||
class AIAgent:
|
class AIAgent:
|
||||||
"""
|
"""
|
||||||
AI Agent with tool calling capabilities.
|
AI Agent with tool calling capabilities.
|
||||||
|
|
||||||
This class manages the conversation flow, tool execution, and response handling
|
This class manages the conversation flow, tool execution, and response handling
|
||||||
for AI models that support function calling.
|
for AI models that support function calling.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@property
def base_url(self) -> str:
    """Return the endpoint base URL most recently assigned to this agent."""
    return self._base_url

@base_url.setter
def base_url(self, value: str) -> None:
    """Store ``value`` and refresh the cached lowercase copy.

    ``_base_url_lower`` is what the provider-detection substring checks read;
    it is ``""`` whenever ``value`` is empty or ``None``.
    """
    self._base_url = value
    if value:
        self._base_url_lower = value.lower()
    else:
        self._base_url_lower = ""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
base_url: str = None,
|
base_url: str = None,
|
||||||
|
|
@ -383,10 +392,10 @@ class AIAgent:
|
||||||
self.api_mode = api_mode
|
self.api_mode = api_mode
|
||||||
elif self.provider == "openai-codex":
|
elif self.provider == "openai-codex":
|
||||||
self.api_mode = "codex_responses"
|
self.api_mode = "codex_responses"
|
||||||
elif (provider_name is None) and "chatgpt.com/backend-api/codex" in self.base_url.lower():
|
elif (provider_name is None) and "chatgpt.com/backend-api/codex" in self._base_url_lower:
|
||||||
self.api_mode = "codex_responses"
|
self.api_mode = "codex_responses"
|
||||||
self.provider = "openai-codex"
|
self.provider = "openai-codex"
|
||||||
elif self.provider == "anthropic" or (provider_name is None and "api.anthropic.com" in self.base_url.lower()):
|
elif self.provider == "anthropic" or (provider_name is None and "api.anthropic.com" in self._base_url_lower):
|
||||||
self.api_mode = "anthropic_messages"
|
self.api_mode = "anthropic_messages"
|
||||||
self.provider = "anthropic"
|
self.provider = "anthropic"
|
||||||
else:
|
else:
|
||||||
|
|
@ -395,7 +404,7 @@ class AIAgent:
|
||||||
# Pre-warm OpenRouter model metadata cache in a background thread.
|
# Pre-warm OpenRouter model metadata cache in a background thread.
|
||||||
# fetch_model_metadata() is cached for 1 hour; this avoids a blocking
|
# fetch_model_metadata() is cached for 1 hour; this avoids a blocking
|
||||||
# HTTP request on the first API response when pricing is estimated.
|
# HTTP request on the first API response when pricing is estimated.
|
||||||
if self.provider == "openrouter" or "openrouter" in self.base_url.lower():
|
if self.provider == "openrouter" or "openrouter" in self._base_url_lower:
|
||||||
threading.Thread(
|
threading.Thread(
|
||||||
target=lambda: fetch_model_metadata(),
|
target=lambda: fetch_model_metadata(),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
|
|
@ -439,7 +448,7 @@ class AIAgent:
|
||||||
# Anthropic prompt caching: auto-enabled for Claude models via OpenRouter.
|
# Anthropic prompt caching: auto-enabled for Claude models via OpenRouter.
|
||||||
# Reduces input costs by ~75% on multi-turn conversations by caching the
|
# Reduces input costs by ~75% on multi-turn conversations by caching the
|
||||||
# conversation prefix. Uses system_and_3 strategy (4 breakpoints).
|
# conversation prefix. Uses system_and_3 strategy (4 breakpoints).
|
||||||
is_openrouter = "openrouter" in self.base_url.lower()
|
is_openrouter = "openrouter" in self._base_url_lower
|
||||||
is_claude = "claude" in self.model.lower()
|
is_claude = "claude" in self.model.lower()
|
||||||
is_native_anthropic = self.api_mode == "anthropic_messages"
|
is_native_anthropic = self.api_mode == "anthropic_messages"
|
||||||
self._use_prompt_caching = (is_openrouter and is_claude) or is_native_anthropic
|
self._use_prompt_caching = (is_openrouter and is_claude) or is_native_anthropic
|
||||||
|
|
@ -555,6 +564,7 @@ class AIAgent:
|
||||||
if self.api_mode == "anthropic_messages":
|
if self.api_mode == "anthropic_messages":
|
||||||
from agent.anthropic_adapter import build_anthropic_client, resolve_anthropic_token
|
from agent.anthropic_adapter import build_anthropic_client, resolve_anthropic_token
|
||||||
effective_key = api_key or resolve_anthropic_token() or ""
|
effective_key = api_key or resolve_anthropic_token() or ""
|
||||||
|
self.api_key = effective_key
|
||||||
self._anthropic_api_key = effective_key
|
self._anthropic_api_key = effective_key
|
||||||
self._anthropic_base_url = base_url
|
self._anthropic_base_url = base_url
|
||||||
from agent.anthropic_adapter import _is_oauth_token as _is_oat
|
from agent.anthropic_adapter import _is_oauth_token as _is_oat
|
||||||
|
|
@ -609,6 +619,7 @@ class AIAgent:
|
||||||
}
|
}
|
||||||
|
|
||||||
self._client_kwargs = client_kwargs # stored for rebuilding after interrupt
|
self._client_kwargs = client_kwargs # stored for rebuilding after interrupt
|
||||||
|
self.api_key = client_kwargs.get("api_key", "")
|
||||||
try:
|
try:
|
||||||
self.client = self._create_openai_client(client_kwargs, reason="agent_init", shared=True)
|
self.client = self._create_openai_client(client_kwargs, reason="agent_init", shared=True)
|
||||||
if not self.quiet_mode:
|
if not self.quiet_mode:
|
||||||
|
|
@ -732,6 +743,13 @@ class AIAgent:
|
||||||
from tools.todo_tool import TodoStore
|
from tools.todo_tool import TodoStore
|
||||||
self._todo_store = TodoStore()
|
self._todo_store = TodoStore()
|
||||||
|
|
||||||
|
# Load config once for memory, skills, and compression sections
|
||||||
|
try:
|
||||||
|
from hermes_cli.config import load_config as _load_agent_config
|
||||||
|
_agent_cfg = _load_agent_config()
|
||||||
|
except Exception:
|
||||||
|
_agent_cfg = {}
|
||||||
|
|
||||||
# Persistent memory (MEMORY.md + USER.md) -- loaded from disk
|
# Persistent memory (MEMORY.md + USER.md) -- loaded from disk
|
||||||
self._memory_store = None
|
self._memory_store = None
|
||||||
self._memory_enabled = False
|
self._memory_enabled = False
|
||||||
|
|
@ -742,8 +760,7 @@ class AIAgent:
|
||||||
self._iters_since_skill = 0
|
self._iters_since_skill = 0
|
||||||
if not skip_memory:
|
if not skip_memory:
|
||||||
try:
|
try:
|
||||||
from hermes_cli.config import load_config as _load_mem_config
|
mem_config = _agent_cfg.get("memory", {})
|
||||||
mem_config = _load_mem_config().get("memory", {})
|
|
||||||
self._memory_enabled = mem_config.get("memory_enabled", False)
|
self._memory_enabled = mem_config.get("memory_enabled", False)
|
||||||
self._user_profile_enabled = mem_config.get("user_profile_enabled", False)
|
self._user_profile_enabled = mem_config.get("user_profile_enabled", False)
|
||||||
self._memory_nudge_interval = int(mem_config.get("nudge_interval", 10))
|
self._memory_nudge_interval = int(mem_config.get("nudge_interval", 10))
|
||||||
|
|
@ -831,21 +848,16 @@ class AIAgent:
|
||||||
# Skills config: nudge interval for skill creation reminders
|
# Skills config: nudge interval for skill creation reminders
|
||||||
self._skill_nudge_interval = 10
|
self._skill_nudge_interval = 10
|
||||||
try:
|
try:
|
||||||
from hermes_cli.config import load_config as _load_skills_config
|
skills_config = _agent_cfg.get("skills", {})
|
||||||
skills_config = _load_skills_config().get("skills", {})
|
|
||||||
self._skill_nudge_interval = int(skills_config.get("creation_nudge_interval", 15))
|
self._skill_nudge_interval = int(skills_config.get("creation_nudge_interval", 15))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Initialize context compressor for automatic context management
|
# Initialize context compressor for automatic context management
|
||||||
# Compresses conversation when approaching model's context limit
|
# Compresses conversation when approaching model's context limit
|
||||||
# Configuration via config.yaml (compression section)
|
# Configuration via config.yaml (compression section)
|
||||||
try:
|
_compression_cfg = _agent_cfg.get("compression", {})
|
||||||
from hermes_cli.config import load_config as _load_compression_config
|
if not isinstance(_compression_cfg, dict):
|
||||||
_compression_cfg = _load_compression_config().get("compression", {})
|
|
||||||
if not isinstance(_compression_cfg, dict):
|
|
||||||
_compression_cfg = {}
|
|
||||||
except ImportError:
|
|
||||||
_compression_cfg = {}
|
_compression_cfg = {}
|
||||||
compression_threshold = float(_compression_cfg.get("threshold", 0.50))
|
compression_threshold = float(_compression_cfg.get("threshold", 0.50))
|
||||||
compression_enabled = str(_compression_cfg.get("enabled", True)).lower() in ("true", "1", "yes")
|
compression_enabled = str(_compression_cfg.get("enabled", True)).lower() in ("true", "1", "yes")
|
||||||
|
|
@ -860,6 +872,7 @@ class AIAgent:
|
||||||
summary_model_override=compression_summary_model,
|
summary_model_override=compression_summary_model,
|
||||||
quiet_mode=self.quiet_mode,
|
quiet_mode=self.quiet_mode,
|
||||||
base_url=self.base_url,
|
base_url=self.base_url,
|
||||||
|
api_key=getattr(self, "api_key", ""),
|
||||||
)
|
)
|
||||||
self.compression_enabled = compression_enabled
|
self.compression_enabled = compression_enabled
|
||||||
self._user_turn_count = 0
|
self._user_turn_count = 0
|
||||||
|
|
@ -915,8 +928,8 @@ class AIAgent:
|
||||||
OpenAI models use 'max_tokens'.
|
OpenAI models use 'max_tokens'.
|
||||||
"""
|
"""
|
||||||
_is_direct_openai = (
|
_is_direct_openai = (
|
||||||
"api.openai.com" in self.base_url.lower()
|
"api.openai.com" in self._base_url_lower
|
||||||
and "openrouter" not in self.base_url.lower()
|
and "openrouter" not in self._base_url_lower
|
||||||
)
|
)
|
||||||
if _is_direct_openai:
|
if _is_direct_openai:
|
||||||
return {"max_completion_tokens": value}
|
return {"max_completion_tokens": value}
|
||||||
|
|
@ -3643,7 +3656,7 @@ class AIAgent:
|
||||||
|
|
||||||
extra_body = {}
|
extra_body = {}
|
||||||
|
|
||||||
_is_openrouter = "openrouter" in self.base_url.lower()
|
_is_openrouter = "openrouter" in self._base_url_lower
|
||||||
|
|
||||||
# Provider preferences (only, ignore, order, sort) are OpenRouter-
|
# Provider preferences (only, ignore, order, sort) are OpenRouter-
|
||||||
# specific. Only send to OpenRouter-compatible endpoints.
|
# specific. Only send to OpenRouter-compatible endpoints.
|
||||||
|
|
@ -3651,7 +3664,7 @@ class AIAgent:
|
||||||
# for _is_nous when their backend is updated.
|
# for _is_nous when their backend is updated.
|
||||||
if provider_preferences and _is_openrouter:
|
if provider_preferences and _is_openrouter:
|
||||||
extra_body["provider"] = provider_preferences
|
extra_body["provider"] = provider_preferences
|
||||||
_is_nous = "nousresearch" in self.base_url.lower()
|
_is_nous = "nousresearch" in self._base_url_lower
|
||||||
|
|
||||||
if self._supports_reasoning_extra_body():
|
if self._supports_reasoning_extra_body():
|
||||||
if self.reasoning_config is not None:
|
if self.reasoning_config is not None:
|
||||||
|
|
@ -3684,14 +3697,13 @@ class AIAgent:
|
||||||
Some providers/routes reject `reasoning` with 400s, so gate it to
|
Some providers/routes reject `reasoning` with 400s, so gate it to
|
||||||
known reasoning-capable model families and direct Nous Portal.
|
known reasoning-capable model families and direct Nous Portal.
|
||||||
"""
|
"""
|
||||||
base_url = (self.base_url or "").lower()
|
if "nousresearch" in self._base_url_lower:
|
||||||
if "nousresearch" in base_url:
|
|
||||||
return True
|
return True
|
||||||
if "ai-gateway.vercel.sh" in base_url:
|
if "ai-gateway.vercel.sh" in self._base_url_lower:
|
||||||
return True
|
return True
|
||||||
if "openrouter" not in base_url:
|
if "openrouter" not in self._base_url_lower:
|
||||||
return False
|
return False
|
||||||
if "api.mistral.ai" in base_url:
|
if "api.mistral.ai" in self._base_url_lower:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
model = (self.model or "").lower()
|
model = (self.model or "").lower()
|
||||||
|
|
@ -3877,7 +3889,7 @@ class AIAgent:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Build API messages for the flush call
|
# Build API messages for the flush call
|
||||||
_is_strict_api = "api.mistral.ai" in self.base_url.lower()
|
_is_strict_api = "api.mistral.ai" in self._base_url_lower
|
||||||
api_messages = []
|
api_messages = []
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
api_msg = msg.copy()
|
api_msg = msg.copy()
|
||||||
|
|
@ -4653,7 +4665,7 @@ class AIAgent:
|
||||||
try:
|
try:
|
||||||
# Build API messages, stripping internal-only fields
|
# Build API messages, stripping internal-only fields
|
||||||
# (finish_reason, reasoning) that strict APIs like Mistral reject with 422
|
# (finish_reason, reasoning) that strict APIs like Mistral reject with 422
|
||||||
_is_strict_api = "api.mistral.ai" in self.base_url.lower()
|
_is_strict_api = "api.mistral.ai" in self._base_url_lower
|
||||||
api_messages = []
|
api_messages = []
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
api_msg = msg.copy()
|
api_msg = msg.copy()
|
||||||
|
|
@ -4674,7 +4686,7 @@ class AIAgent:
|
||||||
api_messages.insert(sys_offset + idx, pfm.copy())
|
api_messages.insert(sys_offset + idx, pfm.copy())
|
||||||
|
|
||||||
summary_extra_body = {}
|
summary_extra_body = {}
|
||||||
_is_nous = "nousresearch" in self.base_url.lower()
|
_is_nous = "nousresearch" in self._base_url_lower
|
||||||
if self._supports_reasoning_extra_body():
|
if self._supports_reasoning_extra_body():
|
||||||
if self.reasoning_config is not None:
|
if self.reasoning_config is not None:
|
||||||
summary_extra_body["reasoning"] = self.reasoning_config
|
summary_extra_body["reasoning"] = self.reasoning_config
|
||||||
|
|
@ -5092,7 +5104,7 @@ class AIAgent:
|
||||||
# strict providers like Mistral that reject unknown fields with 422.
|
# strict providers like Mistral that reject unknown fields with 422.
|
||||||
# Uses new dicts so the internal messages list retains the fields
|
# Uses new dicts so the internal messages list retains the fields
|
||||||
# for Codex Responses compatibility.
|
# for Codex Responses compatibility.
|
||||||
if "api.mistral.ai" in self.base_url.lower():
|
if "api.mistral.ai" in self._base_url_lower:
|
||||||
self._sanitize_tool_calls_for_strict_api(api_msg)
|
self._sanitize_tool_calls_for_strict_api(api_msg)
|
||||||
# Keep 'reasoning_details' - OpenRouter uses this for multi-turn reasoning context
|
# Keep 'reasoning_details' - OpenRouter uses this for multi-turn reasoning context
|
||||||
# The signature field helps maintain reasoning continuity
|
# The signature field helps maintain reasoning continuity
|
||||||
|
|
@ -5464,6 +5476,7 @@ class AIAgent:
|
||||||
canonical_usage,
|
canonical_usage,
|
||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
base_url=self.base_url,
|
base_url=self.base_url,
|
||||||
|
api_key=getattr(self, "api_key", ""),
|
||||||
)
|
)
|
||||||
if cost_result.amount_usd is not None:
|
if cost_result.amount_usd is not None:
|
||||||
self.session_estimated_cost_usd += float(cost_result.amount_usd)
|
self.session_estimated_cost_usd += float(cost_result.amount_usd)
|
||||||
|
|
|
||||||
|
|
@ -188,6 +188,36 @@ class TestGetModelContextLength:
|
||||||
result = get_model_context_length("custom/model")
|
result = get_model_context_length("custom/model")
|
||||||
assert result == CONTEXT_PROBE_TIERS[0]
|
assert result == CONTEXT_PROBE_TIERS[0]
|
||||||
|
|
||||||
|
@patch("agent.model_metadata.fetch_model_metadata")
|
||||||
|
@patch("agent.model_metadata.fetch_endpoint_model_metadata")
|
||||||
|
def test_custom_endpoint_metadata_beats_fuzzy_default(self, mock_endpoint_fetch, mock_fetch):
|
||||||
|
mock_fetch.return_value = {}
|
||||||
|
mock_endpoint_fetch.return_value = {
|
||||||
|
"zai-org/GLM-5-TEE": {"context_length": 65536}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = get_model_context_length(
|
||||||
|
"zai-org/GLM-5-TEE",
|
||||||
|
base_url="https://llm.chutes.ai/v1",
|
||||||
|
api_key="test-key",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == 65536
|
||||||
|
|
||||||
|
@patch("agent.model_metadata.fetch_model_metadata")
|
||||||
|
@patch("agent.model_metadata.fetch_endpoint_model_metadata")
|
||||||
|
def test_custom_endpoint_without_metadata_skips_name_based_default(self, mock_endpoint_fetch, mock_fetch):
|
||||||
|
mock_fetch.return_value = {}
|
||||||
|
mock_endpoint_fetch.return_value = {}
|
||||||
|
|
||||||
|
result = get_model_context_length(
|
||||||
|
"zai-org/GLM-5-TEE",
|
||||||
|
base_url="https://llm.chutes.ai/v1",
|
||||||
|
api_key="test-key",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == CONTEXT_PROBE_TIERS[0]
|
||||||
|
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# fetch_model_metadata — caching, TTL, slugs, failures
|
# fetch_model_metadata — caching, TTL, slugs, failures
|
||||||
|
|
@ -258,6 +288,25 @@ class TestFetchModelMetadata:
|
||||||
assert "anthropic/claude-3.5-sonnet" in result
|
assert "anthropic/claude-3.5-sonnet" in result
|
||||||
assert result["anthropic/claude-3.5-sonnet"]["context_length"] == 200000
|
assert result["anthropic/claude-3.5-sonnet"]["context_length"] == 200000
|
||||||
|
|
||||||
|
@patch("agent.model_metadata.requests.get")
|
||||||
|
def test_provider_prefixed_models_get_bare_aliases(self, mock_get):
|
||||||
|
self._reset_cache()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"data": [{
|
||||||
|
"id": "provider/test-model",
|
||||||
|
"context_length": 123456,
|
||||||
|
"name": "Provider: Test Model",
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
|
result = fetch_model_metadata(force_refresh=True)
|
||||||
|
|
||||||
|
assert result["provider/test-model"]["context_length"] == 123456
|
||||||
|
assert result["test-model"]["context_length"] == 123456
|
||||||
|
|
||||||
@patch("agent.model_metadata.requests.get")
|
@patch("agent.model_metadata.requests.get")
|
||||||
def test_ttl_expiry_triggers_refetch(self, mock_get):
|
def test_ttl_expiry_triggers_refetch(self, mock_get):
|
||||||
"""Cache expires after _MODEL_CACHE_TTL seconds."""
|
"""Cache expires after _MODEL_CACHE_TTL seconds."""
|
||||||
|
|
|
||||||
|
|
@ -99,3 +99,27 @@ def test_estimate_usage_cost_refuses_cache_pricing_without_official_cache_rate(m
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.status == "unknown"
|
assert result.status == "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_endpoint_models_api_pricing_is_supported(monkeypatch):
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"agent.usage_pricing.fetch_endpoint_model_metadata",
|
||||||
|
lambda base_url, api_key=None: {
|
||||||
|
"zai-org/GLM-5-TEE": {
|
||||||
|
"pricing": {
|
||||||
|
"prompt": "0.0000005",
|
||||||
|
"completion": "0.000002",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
entry = get_pricing_entry(
|
||||||
|
"zai-org/GLM-5-TEE",
|
||||||
|
provider="custom",
|
||||||
|
base_url="https://llm.chutes.ai/v1",
|
||||||
|
api_key="test-key",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert float(entry.input_cost_per_million) == 0.5
|
||||||
|
assert float(entry.output_cost_per_million) == 2.0
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue