From c886333d3218e3b23402111169199ca81486111b Mon Sep 17 00:00:00 2001 From: teknium1 Date: Thu, 5 Mar 2026 16:09:57 -0800 Subject: [PATCH] feat: smart context length probing with persistent caching + banner display MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the unsafe 128K fallback for unknown models with a descending probe strategy (2M → 1M → 512K → 200K → 128K → 64K → 32K). When a context-length error occurs, the agent steps down tiers and retries. The discovered limit is cached per model+provider combo in ~/.hermes/context_length_cache.yaml so subsequent sessions skip probing. Also parses API error messages to extract the actual context limit (e.g. 'maximum context length is 32768 tokens') for instant resolution. The CLI banner now displays the context window size next to the model name (e.g. 'claude-opus-4 · 200K context · Nous Research'). Changes: - agent/model_metadata.py: CONTEXT_PROBE_TIERS, persistent cache (save/load/get), parse_context_limit_from_error(), get_next_probe_tier() - agent/context_compressor.py: accepts base_url, passes to metadata - run_agent.py: step-down logic in context error handler, caches on success - cli.py + hermes_cli/banner.py: context length in welcome banner - tests: 22 new tests for probing, parsing, and caching Addresses #132. PR #319's approach (8K default) rejected — too conservative. --- agent/context_compressor.py | 5 +- agent/model_metadata.py | 122 +++++++++++++++++++++++++- cli.py | 23 ++++- hermes_cli/banner.py | 18 +++- run_agent.py | 39 +++++++-- tests/agent/test_model_metadata.py | 133 ++++++++++++++++++++++++++++- 6 files changed, 324 insertions(+), 16 deletions(-) diff --git a/agent/context_compressor.py b/agent/context_compressor.py index 1e8129f2..9c601a1b 100644 --- a/agent/context_compressor.py +++ b/agent/context_compressor.py @@ -34,17 +34,20 @@ class ContextCompressor: summary_target_tokens: int = 2500, quiet_mode: bool = False, summary_model_override: str = None, + base_url: str = "", ): self.model = model + self.base_url = base_url self.threshold_percent = threshold_percent self.protect_first_n = protect_first_n self.protect_last_n = protect_last_n self.summary_target_tokens = summary_target_tokens self.quiet_mode = quiet_mode - self.context_length = get_model_context_length(model) + self.context_length = get_model_context_length(model, base_url=base_url) self.threshold_tokens = int(self.context_length * threshold_percent) self.compression_count = 0 + self._context_probed = False # True after a step-down from context error self.last_prompt_tokens = 0 self.last_completion_tokens = 0 diff --git a/agent/model_metadata.py b/agent/model_metadata.py index d5eebd07..cf379979 100644 --- a/agent/model_metadata.py +++ b/agent/model_metadata.py @@ -5,10 +5,14 @@ and run_agent.py for pre-flight context checks. """ import logging +import os +import re import time -from typing import Any, Dict, List +from pathlib import Path +from typing import Any, Dict, List, Optional import requests +import yaml from hermes_constants import OPENROUTER_MODELS_URL @@ -18,6 +22,18 @@ _model_metadata_cache: Dict[str, Dict[str, Any]] = {} _model_metadata_cache_time: float = 0 _MODEL_CACHE_TTL = 3600 +# Descending tiers for context length probing when the model is unknown. +# We start high and step down on context-length errors until one works. +CONTEXT_PROBE_TIERS = [ + 2_000_000, + 1_000_000, + 512_000, + 200_000, + 128_000, + 64_000, + 32_000, +] + DEFAULT_CONTEXT_LENGTHS = { "anthropic/claude-opus-4": 200000, "anthropic/claude-opus-4.5": 200000, @@ -71,17 +87,115 @@ def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any return _model_metadata_cache or {} -def get_model_context_length(model: str) -> int: - """Get the context length for a model (API first, then fallback defaults).""" +def _get_context_cache_path() -> Path: + """Return path to the persistent context length cache file.""" + hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) + return hermes_home / "context_length_cache.yaml" + + +def _load_context_cache() -> Dict[str, int]: + """Load the model+provider → context_length cache from disk.""" + path = _get_context_cache_path() + if not path.exists(): + return {} + try: + with open(path) as f: + data = yaml.safe_load(f) or {} + return data.get("context_lengths", {}) + except Exception as e: + logger.debug("Failed to load context length cache: %s", e) + return {} + + +def save_context_length(model: str, base_url: str, length: int) -> None: + """Persist a discovered context length for a model+provider combo. + + Cache key is ``model@base_url`` so the same model name served from + different providers can have different limits. + """ + key = f"{model}@{base_url}" + cache = _load_context_cache() + if cache.get(key) == length: + return # already stored + cache[key] = length + path = _get_context_cache_path() + try: + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as f: + yaml.dump({"context_lengths": cache}, f, default_flow_style=False) + logger.info("Cached context length %s → %s tokens", key, f"{length:,}") + except Exception as e: + logger.debug("Failed to save context length cache: %s", e) + + +def get_cached_context_length(model: str, base_url: str) -> Optional[int]: + """Look up a previously discovered context length for model+provider.""" + key = f"{model}@{base_url}" + cache = _load_context_cache() + return cache.get(key) + + +def get_next_probe_tier(current_length: int) -> Optional[int]: + """Return the next lower probe tier, or None if already at minimum.""" + for tier in CONTEXT_PROBE_TIERS: + if tier < current_length: + return tier + return None + + +def parse_context_limit_from_error(error_msg: str) -> Optional[int]: + """Try to extract the actual context limit from an API error message. + + Many providers include the limit in their error text, e.g.: + - "maximum context length is 32768 tokens" + - "context_length_exceeded: 131072" + - "Maximum context size 32768 exceeded" + - "model's max context length is 65536" + """ + error_lower = error_msg.lower() + # Pattern: look for numbers near context-related keywords + patterns = [ + r'(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})', + r'context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})', + r'(\d{4,})\s*(?:token)?\s*(?:context|limit)', + ] + for pattern in patterns: + match = re.search(pattern, error_lower) + if match: + limit = int(match.group(1)) + # Sanity check: must be a reasonable context length + if 1024 <= limit <= 10_000_000: + return limit + return None + + +def get_model_context_length(model: str, base_url: str = "") -> int: + """Get the context length for a model. + + Resolution order: + 1. Persistent cache (previously discovered via probing) + 2. OpenRouter API metadata + 3. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match) + 4. First probe tier (2M) — will be narrowed on first context error + """ + # 1. Check persistent cache (model+provider) + if base_url: + cached = get_cached_context_length(model, base_url) + if cached is not None: + return cached + + # 2. OpenRouter API metadata metadata = fetch_model_metadata() if model in metadata: return metadata[model].get("context_length", 128000) + # 3. Hardcoded defaults (fuzzy match) for default_model, length in DEFAULT_CONTEXT_LENGTHS.items(): if default_model in model or model in default_model: return length - return 128000 + # 4. Unknown model — start at highest probe tier + return CONTEXT_PROBE_TIERS[0] def estimate_tokens_rough(text: str) -> int: diff --git a/cli.py b/cli.py index 591487e5..3c7c26c2 100755 --- a/cli.py +++ b/cli.py @@ -508,7 +508,18 @@ def _get_available_skills() -> Dict[str, List[str]]: return skills_by_category -def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dict] = None, enabled_toolsets: List[str] = None, session_id: str = None): +def _format_context_length(tokens: int) -> str: + """Format a token count for display (e.g. 128000 → '128K', 1048576 → '1M').""" + if tokens >= 1_000_000: + val = tokens / 1_000_000 + return f"{val:g}M" + elif tokens >= 1_000: + val = tokens / 1_000 + return f"{val:g}K" + return str(tokens) + + +def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dict] = None, enabled_toolsets: List[str] = None, session_id: str = None, context_length: int = None): """ Build and print a Claude Code-style welcome banner with caduceus on left and info on right. @@ -519,6 +530,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dic tools: List of tool definitions enabled_toolsets: List of enabled toolset names session_id: Unique session identifier for logging + context_length: Model's context window size in tokens """ from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS @@ -544,7 +556,8 @@ def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dic if len(model_short) > 28: model_short = model_short[:25] + "..." - left_lines.append(f"[#FFBF00]{model_short}[/] [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]") + ctx_str = f" [dim #B8860B]·[/] [dim #B8860B]{_format_context_length(context_length)} context[/]" if context_length else "" + left_lines.append(f"[#FFBF00]{model_short}[/]{ctx_str} [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]") left_lines.append(f"[dim #B8860B]{cwd}[/]") # Add session ID if provided @@ -1079,6 +1092,11 @@ class HermesCLI: # Get terminal working directory (where commands will execute) cwd = os.getenv("TERMINAL_CWD", os.getcwd()) + # Get context length for display + ctx_len = None + if hasattr(self, 'agent') and self.agent and hasattr(self.agent, 'context_compressor'): + ctx_len = self.agent.context_compressor.context_length + # Build and display the banner build_welcome_banner( console=self.console, @@ -1087,6 +1105,7 @@ class HermesCLI: tools=tools, enabled_toolsets=self.enabled_toolsets, session_id=self.session_id, + context_length=ctx_len, ) # Show tool availability warnings if any tools are disabled diff --git a/hermes_cli/banner.py b/hermes_cli/banner.py index be1b3a95..127208e4 100644 --- a/hermes_cli/banner.py +++ b/hermes_cli/banner.py @@ -99,11 +99,23 @@ def get_available_skills() -> Dict[str, List[str]]: # Welcome banner # ========================================================================= +def _format_context_length(tokens: int) -> str: + """Format a token count for display (e.g. 128000 → '128K', 1048576 → '1M').""" + if tokens >= 1_000_000: + val = tokens / 1_000_000 + return f"{val:g}M" + elif tokens >= 1_000: + val = tokens / 1_000 + return f"{val:g}K" + return str(tokens) + + def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dict] = None, enabled_toolsets: List[str] = None, session_id: str = None, - get_toolset_for_tool=None): + get_toolset_for_tool=None, + context_length: int = None): """Build and print a welcome banner with caduceus on left and info on right. Args: @@ -114,6 +126,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str, enabled_toolsets: List of enabled toolset names. session_id: Session identifier. get_toolset_for_tool: Callable to map tool name -> toolset name. + context_length: Model's context window size in tokens. """ from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS if get_toolset_for_tool is None: @@ -135,7 +148,8 @@ def build_welcome_banner(console: Console, model: str, cwd: str, model_short = model.split("/")[-1] if "/" in model else model if len(model_short) > 28: model_short = model_short[:25] + "..." - left_lines.append(f"[#FFBF00]{model_short}[/] [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]") + ctx_str = f" [dim #B8860B]·[/] [dim #B8860B]{_format_context_length(context_length)} context[/]" if context_length else "" + left_lines.append(f"[#FFBF00]{model_short}[/]{ctx_str} [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]") left_lines.append(f"[dim #B8860B]{cwd}[/]") if session_id: left_lines.append(f"[dim #8B8682]Session: {session_id}[/]") diff --git a/run_agent.py b/run_agent.py index e02c5fa3..c320b0fa 100644 --- a/run_agent.py +++ b/run_agent.py @@ -82,6 +82,8 @@ from agent.prompt_builder import ( from agent.model_metadata import ( fetch_model_metadata, get_model_context_length, estimate_tokens_rough, estimate_messages_tokens_rough, + get_next_probe_tier, parse_context_limit_from_error, + save_context_length, ) from agent.context_compressor import ContextCompressor from agent.prompt_caching import apply_anthropic_cache_control @@ -536,6 +538,7 @@ class AIAgent: summary_target_tokens=500, summary_model_override=compression_summary_model, quiet_mode=self.quiet_mode, + base_url=self.base_url, ) self.compression_enabled = compression_enabled self._user_turn_count = 0 @@ -3236,6 +3239,13 @@ class AIAgent: } self.context_compressor.update_from_response(usage_dict) + # Cache discovered context length after successful call + if self.context_compressor._context_probed: + ctx = self.context_compressor.context_length + save_context_length(self.model, self.base_url, ctx) + print(f"{self.log_prefix}💾 Cached context length: {ctx:,} tokens for {self.model}") + self.context_compressor._context_probed = False + self.session_prompt_tokens += prompt_tokens self.session_completion_tokens += completion_tokens self.session_total_tokens += total_tokens @@ -3355,18 +3365,37 @@ class AIAgent: ]) if is_context_length_error: - print(f"{self.log_prefix}⚠️ Context length exceeded - attempting compression...") + compressor = self.context_compressor + old_ctx = compressor.context_length + + # Try to parse the actual limit from the error message + parsed_limit = parse_context_limit_from_error(error_msg) + if parsed_limit and parsed_limit < old_ctx: + new_ctx = parsed_limit + print(f"{self.log_prefix}⚠️ Context limit detected from API: {new_ctx:,} tokens (was {old_ctx:,})") + else: + # Step down to the next probe tier + new_ctx = get_next_probe_tier(old_ctx) + + if new_ctx and new_ctx < old_ctx: + compressor.context_length = new_ctx + compressor.threshold_tokens = int(new_ctx * compressor.threshold_percent) + compressor._context_probed = True + print(f"{self.log_prefix}⚠️ Context length exceeded — stepping down: {old_ctx:,} → {new_ctx:,} tokens") + else: + print(f"{self.log_prefix}⚠️ Context length exceeded at minimum tier — attempting compression...") original_len = len(messages) messages, active_system_prompt = self._compress_context( messages, system_message, approx_tokens=approx_tokens ) - if len(messages) < original_len: - print(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...") - continue # Retry with compressed messages + if len(messages) < original_len or new_ctx and new_ctx < old_ctx: + if len(messages) < original_len: + print(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...") + continue # Retry with compressed messages or new tier else: - # Can't compress further + # Can't compress further and already at minimum tier print(f"{self.log_prefix}❌ Context length exceeded and cannot compress further.") print(f"{self.log_prefix} 💡 The conversation has accumulated too much content.") logging.error(f"{self.log_prefix}Context length exceeded: {approx_tokens:,} tokens. Cannot compress further.") diff --git a/tests/agent/test_model_metadata.py b/tests/agent/test_model_metadata.py index 404ee6b2..ffc98cb2 100644 --- a/tests/agent/test_model_metadata.py +++ b/tests/agent/test_model_metadata.py @@ -1,13 +1,22 @@ """Tests for agent/model_metadata.py — token estimation and context lengths.""" +import os +import tempfile + import pytest +import yaml from unittest.mock import patch, MagicMock from agent.model_metadata import ( + CONTEXT_PROBE_TIERS, DEFAULT_CONTEXT_LENGTHS, estimate_tokens_rough, estimate_messages_tokens_rough, get_model_context_length, + get_next_probe_tier, + get_cached_context_length, + parse_context_limit_from_error, + save_context_length, fetch_model_metadata, _MODEL_CACHE_TTL, ) @@ -101,10 +110,10 @@ class TestGetModelContextLength: assert result == 200000 @patch("agent.model_metadata.fetch_model_metadata") - def test_unknown_model_returns_128k(self, mock_fetch): + def test_unknown_model_returns_first_probe_tier(self, mock_fetch): mock_fetch.return_value = {} result = get_model_context_length("unknown/never-heard-of-this") - assert result == 128000 + assert result == CONTEXT_PROBE_TIERS[0] # 2M — will be narrowed on context error @patch("agent.model_metadata.fetch_model_metadata") def test_partial_match_in_defaults(self, mock_fetch): @@ -154,3 +163,123 @@ class TestFetchModelMetadata: mock_get.side_effect = Exception("Network error") result = fetch_model_metadata(force_refresh=True) assert result == {} + + +# ========================================================================= +# Context probe tiers +# ========================================================================= + +class TestContextProbeTiers: + def test_tiers_descending(self): + for i in range(len(CONTEXT_PROBE_TIERS) - 1): + assert CONTEXT_PROBE_TIERS[i] > CONTEXT_PROBE_TIERS[i + 1] + + def test_first_tier_is_2m(self): + assert CONTEXT_PROBE_TIERS[0] == 2_000_000 + + def test_last_tier_is_32k(self): + assert CONTEXT_PROBE_TIERS[-1] == 32_000 + + +class TestGetNextProbeTier: + def test_from_2m(self): + assert get_next_probe_tier(2_000_000) == 1_000_000 + + def test_from_1m(self): + assert get_next_probe_tier(1_000_000) == 512_000 + + def test_from_128k(self): + assert get_next_probe_tier(128_000) == 64_000 + + def test_from_32k_returns_none(self): + assert get_next_probe_tier(32_000) is None + + def test_from_below_min_returns_none(self): + assert get_next_probe_tier(16_000) is None + + def test_from_arbitrary_value(self): + # 300K is between 512K and 200K, should return 200K + assert get_next_probe_tier(300_000) == 200_000 + + +# ========================================================================= +# Error message parsing +# ========================================================================= + +class TestParseContextLimitFromError: + def test_openai_format(self): + msg = "This model's maximum context length is 32768 tokens. However, your messages resulted in 45000 tokens." + assert parse_context_limit_from_error(msg) == 32768 + + def test_context_length_exceeded(self): + msg = "context_length_exceeded: maximum context length is 131072" + assert parse_context_limit_from_error(msg) == 131072 + + def test_context_size_exceeded(self): + msg = "Maximum context size 65536 exceeded" + assert parse_context_limit_from_error(msg) == 65536 + + def test_no_limit_in_message(self): + msg = "Something went wrong with the API" + assert parse_context_limit_from_error(msg) is None + + def test_unreasonable_number_rejected(self): + msg = "context length is 42 tokens" # too small + assert parse_context_limit_from_error(msg) is None + + def test_ollama_format(self): + msg = "Context size has been exceeded. Maximum context size is 32768" + assert parse_context_limit_from_error(msg) == 32768 + + +# ========================================================================= +# Persistent context length cache +# ========================================================================= + +class TestContextLengthCache: + def test_save_and_load(self, tmp_path): + cache_file = tmp_path / "context_length_cache.yaml" + with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): + save_context_length("test/model", "http://localhost:8080/v1", 32768) + result = get_cached_context_length("test/model", "http://localhost:8080/v1") + assert result == 32768 + + def test_missing_cache_returns_none(self, tmp_path): + cache_file = tmp_path / "nonexistent.yaml" + with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): + assert get_cached_context_length("test/model", "http://x") is None + + def test_multiple_models_cached(self, tmp_path): + cache_file = tmp_path / "context_length_cache.yaml" + with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): + save_context_length("model-a", "http://a", 64000) + save_context_length("model-b", "http://b", 128000) + assert get_cached_context_length("model-a", "http://a") == 64000 + assert get_cached_context_length("model-b", "http://b") == 128000 + + def test_same_model_different_providers(self, tmp_path): + cache_file = tmp_path / "context_length_cache.yaml" + with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): + save_context_length("llama-3", "http://local:8080", 32768) + save_context_length("llama-3", "https://openrouter.ai/api/v1", 131072) + assert get_cached_context_length("llama-3", "http://local:8080") == 32768 + assert get_cached_context_length("llama-3", "https://openrouter.ai/api/v1") == 131072 + + def test_idempotent_save(self, tmp_path): + cache_file = tmp_path / "context_length_cache.yaml" + with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): + save_context_length("model", "http://x", 32768) + save_context_length("model", "http://x", 32768) # same value + with open(cache_file) as f: + data = yaml.safe_load(f) + assert len(data["context_lengths"]) == 1 + + @patch("agent.model_metadata.fetch_model_metadata") + def test_cached_value_takes_priority(self, mock_fetch, tmp_path): + """Cached context length should be used before API or defaults.""" + mock_fetch.return_value = {} + cache_file = tmp_path / "context_length_cache.yaml" + with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): + save_context_length("unknown/model", "http://local", 65536) + result = get_model_context_length("unknown/model", base_url="http://local") + assert result == 65536