feat: smart context length probing with persistent caching + banner display

Replaces the unsafe 128K fallback for unknown models with a descending
probe strategy (2M → 1M → 512K → 200K → 128K → 64K → 32K). When a
context-length error occurs, the agent steps down tiers and retries.
The discovered limit is cached per model+provider combo in
~/.hermes/context_length_cache.yaml so subsequent sessions skip probing.

Also parses API error messages to extract the actual context limit
(e.g. 'maximum context length is 32768 tokens') for instant resolution.

The CLI banner now displays the context window size next to the model
name (e.g. 'claude-opus-4 · 200K context · Nous Research').

Changes:
- agent/model_metadata.py: CONTEXT_PROBE_TIERS, persistent cache
  (save/load/get), parse_context_limit_from_error(), get_next_probe_tier()
- agent/context_compressor.py: accepts base_url, passes to metadata
- run_agent.py: step-down logic in context error handler, caches on success
- cli.py + hermes_cli/banner.py: context length in welcome banner
- tests: 22 new tests for probing, parsing, and caching

Addresses #132. PR #319's approach (8K default) rejected — too conservative.
This commit is contained in:
teknium1 2026-03-05 16:09:57 -08:00
parent 55b173dd03
commit c886333d32
6 changed files with 324 additions and 16 deletions

View file

@ -34,17 +34,20 @@ class ContextCompressor:
summary_target_tokens: int = 2500,
quiet_mode: bool = False,
summary_model_override: str = None,
base_url: str = "",
):
self.model = model
self.base_url = base_url
self.threshold_percent = threshold_percent
self.protect_first_n = protect_first_n
self.protect_last_n = protect_last_n
self.summary_target_tokens = summary_target_tokens
self.quiet_mode = quiet_mode
self.context_length = get_model_context_length(model)
self.context_length = get_model_context_length(model, base_url=base_url)
self.threshold_tokens = int(self.context_length * threshold_percent)
self.compression_count = 0
self._context_probed = False # True after a step-down from context error
self.last_prompt_tokens = 0
self.last_completion_tokens = 0

View file

@ -5,10 +5,14 @@ and run_agent.py for pre-flight context checks.
"""
import logging
import os
import re
import time
from typing import Any, Dict, List
from pathlib import Path
from typing import Any, Dict, List, Optional
import requests
import yaml
from hermes_constants import OPENROUTER_MODELS_URL
@ -18,6 +22,18 @@ _model_metadata_cache: Dict[str, Dict[str, Any]] = {}
_model_metadata_cache_time: float = 0
_MODEL_CACHE_TTL = 3600
# Descending tiers for context length probing when the model is unknown.
# We start high and step down on context-length errors until one works.
CONTEXT_PROBE_TIERS = [
2_000_000,
1_000_000,
512_000,
200_000,
128_000,
64_000,
32_000,
]
DEFAULT_CONTEXT_LENGTHS = {
"anthropic/claude-opus-4": 200000,
"anthropic/claude-opus-4.5": 200000,
@ -71,17 +87,115 @@ def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any
return _model_metadata_cache or {}
def get_model_context_length(model: str) -> int:
"""Get the context length for a model (API first, then fallback defaults)."""
def _get_context_cache_path() -> Path:
"""Return path to the persistent context length cache file."""
hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
return hermes_home / "context_length_cache.yaml"
def _load_context_cache() -> Dict[str, int]:
"""Load the model+provider → context_length cache from disk."""
path = _get_context_cache_path()
if not path.exists():
return {}
try:
with open(path) as f:
data = yaml.safe_load(f) or {}
return data.get("context_lengths", {})
except Exception as e:
logger.debug("Failed to load context length cache: %s", e)
return {}
def save_context_length(model: str, base_url: str, length: int) -> None:
"""Persist a discovered context length for a model+provider combo.
Cache key is ``model@base_url`` so the same model name served from
different providers can have different limits.
"""
key = f"{model}@{base_url}"
cache = _load_context_cache()
if cache.get(key) == length:
return # already stored
cache[key] = length
path = _get_context_cache_path()
try:
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w") as f:
yaml.dump({"context_lengths": cache}, f, default_flow_style=False)
logger.info("Cached context length %s%s tokens", key, f"{length:,}")
except Exception as e:
logger.debug("Failed to save context length cache: %s", e)
def get_cached_context_length(model: str, base_url: str) -> Optional[int]:
"""Look up a previously discovered context length for model+provider."""
key = f"{model}@{base_url}"
cache = _load_context_cache()
return cache.get(key)
def get_next_probe_tier(current_length: int) -> Optional[int]:
"""Return the next lower probe tier, or None if already at minimum."""
for tier in CONTEXT_PROBE_TIERS:
if tier < current_length:
return tier
return None
def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
"""Try to extract the actual context limit from an API error message.
Many providers include the limit in their error text, e.g.:
- "maximum context length is 32768 tokens"
- "context_length_exceeded: 131072"
- "Maximum context size 32768 exceeded"
- "model's max context length is 65536"
"""
error_lower = error_msg.lower()
# Pattern: look for numbers near context-related keywords
patterns = [
r'(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})',
r'context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})',
r'(\d{4,})\s*(?:token)?\s*(?:context|limit)',
]
for pattern in patterns:
match = re.search(pattern, error_lower)
if match:
limit = int(match.group(1))
# Sanity check: must be a reasonable context length
if 1024 <= limit <= 10_000_000:
return limit
return None
def get_model_context_length(model: str, base_url: str = "") -> int:
"""Get the context length for a model.
Resolution order:
1. Persistent cache (previously discovered via probing)
2. OpenRouter API metadata
3. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match)
4. First probe tier (2M) will be narrowed on first context error
"""
# 1. Check persistent cache (model+provider)
if base_url:
cached = get_cached_context_length(model, base_url)
if cached is not None:
return cached
# 2. OpenRouter API metadata
metadata = fetch_model_metadata()
if model in metadata:
return metadata[model].get("context_length", 128000)
# 3. Hardcoded defaults (fuzzy match)
for default_model, length in DEFAULT_CONTEXT_LENGTHS.items():
if default_model in model or model in default_model:
return length
return 128000
# 4. Unknown model — start at highest probe tier
return CONTEXT_PROBE_TIERS[0]
def estimate_tokens_rough(text: str) -> int:

23
cli.py
View file

@ -508,7 +508,18 @@ def _get_available_skills() -> Dict[str, List[str]]:
return skills_by_category
def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dict] = None, enabled_toolsets: List[str] = None, session_id: str = None):
def _format_context_length(tokens: int) -> str:
"""Format a token count for display (e.g. 128000 → '128K', 1048576 → '1M')."""
if tokens >= 1_000_000:
val = tokens / 1_000_000
return f"{val:g}M"
elif tokens >= 1_000:
val = tokens / 1_000
return f"{val:g}K"
return str(tokens)
def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dict] = None, enabled_toolsets: List[str] = None, session_id: str = None, context_length: int = None):
"""
Build and print a Claude Code-style welcome banner with caduceus on left and info on right.
@ -519,6 +530,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dic
tools: List of tool definitions
enabled_toolsets: List of enabled toolset names
session_id: Unique session identifier for logging
context_length: Model's context window size in tokens
"""
from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS
@ -544,7 +556,8 @@ def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dic
if len(model_short) > 28:
model_short = model_short[:25] + "..."
left_lines.append(f"[#FFBF00]{model_short}[/] [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]")
ctx_str = f" [dim #B8860B]·[/] [dim #B8860B]{_format_context_length(context_length)} context[/]" if context_length else ""
left_lines.append(f"[#FFBF00]{model_short}[/]{ctx_str} [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]")
left_lines.append(f"[dim #B8860B]{cwd}[/]")
# Add session ID if provided
@ -1079,6 +1092,11 @@ class HermesCLI:
# Get terminal working directory (where commands will execute)
cwd = os.getenv("TERMINAL_CWD", os.getcwd())
# Get context length for display
ctx_len = None
if hasattr(self, 'agent') and self.agent and hasattr(self.agent, 'context_compressor'):
ctx_len = self.agent.context_compressor.context_length
# Build and display the banner
build_welcome_banner(
console=self.console,
@ -1087,6 +1105,7 @@ class HermesCLI:
tools=tools,
enabled_toolsets=self.enabled_toolsets,
session_id=self.session_id,
context_length=ctx_len,
)
# Show tool availability warnings if any tools are disabled

View file

@ -99,11 +99,23 @@ def get_available_skills() -> Dict[str, List[str]]:
# Welcome banner
# =========================================================================
def _format_context_length(tokens: int) -> str:
"""Format a token count for display (e.g. 128000 → '128K', 1048576 → '1M')."""
if tokens >= 1_000_000:
val = tokens / 1_000_000
return f"{val:g}M"
elif tokens >= 1_000:
val = tokens / 1_000
return f"{val:g}K"
return str(tokens)
def build_welcome_banner(console: Console, model: str, cwd: str,
tools: List[dict] = None,
enabled_toolsets: List[str] = None,
session_id: str = None,
get_toolset_for_tool=None):
get_toolset_for_tool=None,
context_length: int = None):
"""Build and print a welcome banner with caduceus on left and info on right.
Args:
@ -114,6 +126,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
enabled_toolsets: List of enabled toolset names.
session_id: Session identifier.
get_toolset_for_tool: Callable to map tool name -> toolset name.
context_length: Model's context window size in tokens.
"""
from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS
if get_toolset_for_tool is None:
@ -135,7 +148,8 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
model_short = model.split("/")[-1] if "/" in model else model
if len(model_short) > 28:
model_short = model_short[:25] + "..."
left_lines.append(f"[#FFBF00]{model_short}[/] [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]")
ctx_str = f" [dim #B8860B]·[/] [dim #B8860B]{_format_context_length(context_length)} context[/]" if context_length else ""
left_lines.append(f"[#FFBF00]{model_short}[/]{ctx_str} [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]")
left_lines.append(f"[dim #B8860B]{cwd}[/]")
if session_id:
left_lines.append(f"[dim #8B8682]Session: {session_id}[/]")

View file

@ -82,6 +82,8 @@ from agent.prompt_builder import (
from agent.model_metadata import (
fetch_model_metadata, get_model_context_length,
estimate_tokens_rough, estimate_messages_tokens_rough,
get_next_probe_tier, parse_context_limit_from_error,
save_context_length,
)
from agent.context_compressor import ContextCompressor
from agent.prompt_caching import apply_anthropic_cache_control
@ -536,6 +538,7 @@ class AIAgent:
summary_target_tokens=500,
summary_model_override=compression_summary_model,
quiet_mode=self.quiet_mode,
base_url=self.base_url,
)
self.compression_enabled = compression_enabled
self._user_turn_count = 0
@ -3236,6 +3239,13 @@ class AIAgent:
}
self.context_compressor.update_from_response(usage_dict)
# Cache discovered context length after successful call
if self.context_compressor._context_probed:
ctx = self.context_compressor.context_length
save_context_length(self.model, self.base_url, ctx)
print(f"{self.log_prefix}💾 Cached context length: {ctx:,} tokens for {self.model}")
self.context_compressor._context_probed = False
self.session_prompt_tokens += prompt_tokens
self.session_completion_tokens += completion_tokens
self.session_total_tokens += total_tokens
@ -3355,18 +3365,37 @@ class AIAgent:
])
if is_context_length_error:
print(f"{self.log_prefix}⚠️ Context length exceeded - attempting compression...")
compressor = self.context_compressor
old_ctx = compressor.context_length
# Try to parse the actual limit from the error message
parsed_limit = parse_context_limit_from_error(error_msg)
if parsed_limit and parsed_limit < old_ctx:
new_ctx = parsed_limit
print(f"{self.log_prefix}⚠️ Context limit detected from API: {new_ctx:,} tokens (was {old_ctx:,})")
else:
# Step down to the next probe tier
new_ctx = get_next_probe_tier(old_ctx)
if new_ctx and new_ctx < old_ctx:
compressor.context_length = new_ctx
compressor.threshold_tokens = int(new_ctx * compressor.threshold_percent)
compressor._context_probed = True
print(f"{self.log_prefix}⚠️ Context length exceeded — stepping down: {old_ctx:,}{new_ctx:,} tokens")
else:
print(f"{self.log_prefix}⚠️ Context length exceeded at minimum tier — attempting compression...")
original_len = len(messages)
messages, active_system_prompt = self._compress_context(
messages, system_message, approx_tokens=approx_tokens
)
if len(messages) < original_len:
print(f"{self.log_prefix} 🗜️ Compressed {original_len}{len(messages)} messages, retrying...")
continue # Retry with compressed messages
if len(messages) < original_len or new_ctx and new_ctx < old_ctx:
if len(messages) < original_len:
print(f"{self.log_prefix} 🗜️ Compressed {original_len}{len(messages)} messages, retrying...")
continue # Retry with compressed messages or new tier
else:
# Can't compress further
# Can't compress further and already at minimum tier
print(f"{self.log_prefix}❌ Context length exceeded and cannot compress further.")
print(f"{self.log_prefix} 💡 The conversation has accumulated too much content.")
logging.error(f"{self.log_prefix}Context length exceeded: {approx_tokens:,} tokens. Cannot compress further.")

View file

@ -1,13 +1,22 @@
"""Tests for agent/model_metadata.py — token estimation and context lengths."""
import os
import tempfile
import pytest
import yaml
from unittest.mock import patch, MagicMock
from agent.model_metadata import (
CONTEXT_PROBE_TIERS,
DEFAULT_CONTEXT_LENGTHS,
estimate_tokens_rough,
estimate_messages_tokens_rough,
get_model_context_length,
get_next_probe_tier,
get_cached_context_length,
parse_context_limit_from_error,
save_context_length,
fetch_model_metadata,
_MODEL_CACHE_TTL,
)
@ -101,10 +110,10 @@ class TestGetModelContextLength:
assert result == 200000
@patch("agent.model_metadata.fetch_model_metadata")
def test_unknown_model_returns_128k(self, mock_fetch):
def test_unknown_model_returns_first_probe_tier(self, mock_fetch):
mock_fetch.return_value = {}
result = get_model_context_length("unknown/never-heard-of-this")
assert result == 128000
assert result == CONTEXT_PROBE_TIERS[0] # 2M — will be narrowed on context error
@patch("agent.model_metadata.fetch_model_metadata")
def test_partial_match_in_defaults(self, mock_fetch):
@ -154,3 +163,123 @@ class TestFetchModelMetadata:
mock_get.side_effect = Exception("Network error")
result = fetch_model_metadata(force_refresh=True)
assert result == {}
# =========================================================================
# Context probe tiers
# =========================================================================
class TestContextProbeTiers:
def test_tiers_descending(self):
for i in range(len(CONTEXT_PROBE_TIERS) - 1):
assert CONTEXT_PROBE_TIERS[i] > CONTEXT_PROBE_TIERS[i + 1]
def test_first_tier_is_2m(self):
assert CONTEXT_PROBE_TIERS[0] == 2_000_000
def test_last_tier_is_32k(self):
assert CONTEXT_PROBE_TIERS[-1] == 32_000
class TestGetNextProbeTier:
def test_from_2m(self):
assert get_next_probe_tier(2_000_000) == 1_000_000
def test_from_1m(self):
assert get_next_probe_tier(1_000_000) == 512_000
def test_from_128k(self):
assert get_next_probe_tier(128_000) == 64_000
def test_from_32k_returns_none(self):
assert get_next_probe_tier(32_000) is None
def test_from_below_min_returns_none(self):
assert get_next_probe_tier(16_000) is None
def test_from_arbitrary_value(self):
# 300K is between 512K and 200K, should return 200K
assert get_next_probe_tier(300_000) == 200_000
# =========================================================================
# Error message parsing
# =========================================================================
class TestParseContextLimitFromError:
def test_openai_format(self):
msg = "This model's maximum context length is 32768 tokens. However, your messages resulted in 45000 tokens."
assert parse_context_limit_from_error(msg) == 32768
def test_context_length_exceeded(self):
msg = "context_length_exceeded: maximum context length is 131072"
assert parse_context_limit_from_error(msg) == 131072
def test_context_size_exceeded(self):
msg = "Maximum context size 65536 exceeded"
assert parse_context_limit_from_error(msg) == 65536
def test_no_limit_in_message(self):
msg = "Something went wrong with the API"
assert parse_context_limit_from_error(msg) is None
def test_unreasonable_number_rejected(self):
msg = "context length is 42 tokens" # too small
assert parse_context_limit_from_error(msg) is None
def test_ollama_format(self):
msg = "Context size has been exceeded. Maximum context size is 32768"
assert parse_context_limit_from_error(msg) == 32768
# =========================================================================
# Persistent context length cache
# =========================================================================
class TestContextLengthCache:
def test_save_and_load(self, tmp_path):
cache_file = tmp_path / "context_length_cache.yaml"
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
save_context_length("test/model", "http://localhost:8080/v1", 32768)
result = get_cached_context_length("test/model", "http://localhost:8080/v1")
assert result == 32768
def test_missing_cache_returns_none(self, tmp_path):
cache_file = tmp_path / "nonexistent.yaml"
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
assert get_cached_context_length("test/model", "http://x") is None
def test_multiple_models_cached(self, tmp_path):
cache_file = tmp_path / "context_length_cache.yaml"
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
save_context_length("model-a", "http://a", 64000)
save_context_length("model-b", "http://b", 128000)
assert get_cached_context_length("model-a", "http://a") == 64000
assert get_cached_context_length("model-b", "http://b") == 128000
def test_same_model_different_providers(self, tmp_path):
cache_file = tmp_path / "context_length_cache.yaml"
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
save_context_length("llama-3", "http://local:8080", 32768)
save_context_length("llama-3", "https://openrouter.ai/api/v1", 131072)
assert get_cached_context_length("llama-3", "http://local:8080") == 32768
assert get_cached_context_length("llama-3", "https://openrouter.ai/api/v1") == 131072
def test_idempotent_save(self, tmp_path):
cache_file = tmp_path / "context_length_cache.yaml"
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
save_context_length("model", "http://x", 32768)
save_context_length("model", "http://x", 32768) # same value
with open(cache_file) as f:
data = yaml.safe_load(f)
assert len(data["context_lengths"]) == 1
@patch("agent.model_metadata.fetch_model_metadata")
def test_cached_value_takes_priority(self, mock_fetch, tmp_path):
"""Cached context length should be used before API or defaults."""
mock_fetch.return_value = {}
cache_file = tmp_path / "context_length_cache.yaml"
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
save_context_length("unknown/model", "http://local", 65536)
result = get_model_context_length("unknown/model", base_url="http://local")
assert result == 65536