diff --git a/agent/context_compressor.py b/agent/context_compressor.py index 1e8129f2..9c601a1b 100644 --- a/agent/context_compressor.py +++ b/agent/context_compressor.py @@ -34,17 +34,20 @@ class ContextCompressor: summary_target_tokens: int = 2500, quiet_mode: bool = False, summary_model_override: str = None, + base_url: str = "", ): self.model = model + self.base_url = base_url self.threshold_percent = threshold_percent self.protect_first_n = protect_first_n self.protect_last_n = protect_last_n self.summary_target_tokens = summary_target_tokens self.quiet_mode = quiet_mode - self.context_length = get_model_context_length(model) + self.context_length = get_model_context_length(model, base_url=base_url) self.threshold_tokens = int(self.context_length * threshold_percent) self.compression_count = 0 + self._context_probed = False # True after a step-down from context error self.last_prompt_tokens = 0 self.last_completion_tokens = 0 diff --git a/agent/model_metadata.py b/agent/model_metadata.py index d5eebd07..cf379979 100644 --- a/agent/model_metadata.py +++ b/agent/model_metadata.py @@ -5,10 +5,14 @@ and run_agent.py for pre-flight context checks. """ import logging +import os +import re import time -from typing import Any, Dict, List +from pathlib import Path +from typing import Any, Dict, List, Optional import requests +import yaml from hermes_constants import OPENROUTER_MODELS_URL @@ -18,6 +22,18 @@ _model_metadata_cache: Dict[str, Dict[str, Any]] = {} _model_metadata_cache_time: float = 0 _MODEL_CACHE_TTL = 3600 +# Descending tiers for context length probing when the model is unknown. +# We start high and step down on context-length errors until one works. +CONTEXT_PROBE_TIERS = [ + 2_000_000, + 1_000_000, + 512_000, + 200_000, + 128_000, + 64_000, + 32_000, +] + DEFAULT_CONTEXT_LENGTHS = { "anthropic/claude-opus-4": 200000, "anthropic/claude-opus-4.5": 200000, @@ -71,17 +87,115 @@ def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any return _model_metadata_cache or {} -def get_model_context_length(model: str) -> int: - """Get the context length for a model (API first, then fallback defaults).""" +def _get_context_cache_path() -> Path: + """Return path to the persistent context length cache file.""" + hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) + return hermes_home / "context_length_cache.yaml" + + +def _load_context_cache() -> Dict[str, int]: + """Load the model+provider → context_length cache from disk.""" + path = _get_context_cache_path() + if not path.exists(): + return {} + try: + with open(path) as f: + data = yaml.safe_load(f) or {} + return data.get("context_lengths", {}) + except Exception as e: + logger.debug("Failed to load context length cache: %s", e) + return {} + + +def save_context_length(model: str, base_url: str, length: int) -> None: + """Persist a discovered context length for a model+provider combo. + + Cache key is ``model@base_url`` so the same model name served from + different providers can have different limits. 
+ """ + key = f"{model}@{base_url}" + cache = _load_context_cache() + if cache.get(key) == length: + return # already stored + cache[key] = length + path = _get_context_cache_path() + try: + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as f: + yaml.dump({"context_lengths": cache}, f, default_flow_style=False) + logger.info("Cached context length %s → %s tokens", key, f"{length:,}") + except Exception as e: + logger.debug("Failed to save context length cache: %s", e) + + +def get_cached_context_length(model: str, base_url: str) -> Optional[int]: + """Look up a previously discovered context length for model+provider.""" + key = f"{model}@{base_url}" + cache = _load_context_cache() + return cache.get(key) + + +def get_next_probe_tier(current_length: int) -> Optional[int]: + """Return the next lower probe tier, or None if already at minimum.""" + for tier in CONTEXT_PROBE_TIERS: + if tier < current_length: + return tier + return None + + +def parse_context_limit_from_error(error_msg: str) -> Optional[int]: + """Try to extract the actual context limit from an API error message. + + Many providers include the limit in their error text, e.g.: + - "maximum context length is 32768 tokens" + - "context_length_exceeded: 131072" + - "Maximum context size 32768 exceeded" + - "model's max context length is 65536" + """ + error_lower = error_msg.lower() + # Pattern: look for numbers near context-related keywords + patterns = [ + r'(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})', + r'context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})', + r'(\d{4,})\s*(?:token)?\s*(?:context|limit)', + ] + for pattern in patterns: + match = re.search(pattern, error_lower) + if match: + limit = int(match.group(1)) + # Sanity check: must be a reasonable context length + if 1024 <= limit <= 10_000_000: + return limit + return None + + +def get_model_context_length(model: str, base_url: str = "") -> int: + """Get the context length for a model. + + Resolution order: + 1. Persistent cache (previously discovered via probing) + 2. OpenRouter API metadata + 3. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match) + 4. First probe tier (2M) — will be narrowed on first context error + """ + # 1. Check persistent cache (model+provider) + if base_url: + cached = get_cached_context_length(model, base_url) + if cached is not None: + return cached + + # 2. OpenRouter API metadata metadata = fetch_model_metadata() if model in metadata: return metadata[model].get("context_length", 128000) + # 3. Hardcoded defaults (fuzzy match) for default_model, length in DEFAULT_CONTEXT_LENGTHS.items(): if default_model in model or model in default_model: return length - return 128000 + # 4. Unknown model — start at highest probe tier + return CONTEXT_PROBE_TIERS[0] def estimate_tokens_rough(text: str) -> int: diff --git a/cli.py b/cli.py index 591487e5..3c7c26c2 100755 --- a/cli.py +++ b/cli.py @@ -508,7 +508,18 @@ def _get_available_skills() -> Dict[str, List[str]]: return skills_by_category -def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dict] = None, enabled_toolsets: List[str] = None, session_id: str = None): +def _format_context_length(tokens: int) -> str: + """Format a token count for display (e.g. 
128000 → '128K', 1048576 → '1M').""" + if tokens >= 1_000_000: + val = tokens / 1_000_000 + return f"{val:g}M" + elif tokens >= 1_000: + val = tokens / 1_000 + return f"{val:g}K" + return str(tokens) + + +def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dict] = None, enabled_toolsets: List[str] = None, session_id: str = None, context_length: int = None): """ Build and print a Claude Code-style welcome banner with caduceus on left and info on right. @@ -519,6 +530,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dic tools: List of tool definitions enabled_toolsets: List of enabled toolset names session_id: Unique session identifier for logging + context_length: Model's context window size in tokens """ from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS @@ -544,7 +556,8 @@ def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dic if len(model_short) > 28: model_short = model_short[:25] + "..." - left_lines.append(f"[#FFBF00]{model_short}[/] [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]") + ctx_str = f" [dim #B8860B]·[/] [dim #B8860B]{_format_context_length(context_length)} context[/]" if context_length else "" + left_lines.append(f"[#FFBF00]{model_short}[/]{ctx_str} [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]") left_lines.append(f"[dim #B8860B]{cwd}[/]") # Add session ID if provided @@ -1079,6 +1092,11 @@ class HermesCLI: # Get terminal working directory (where commands will execute) cwd = os.getenv("TERMINAL_CWD", os.getcwd()) + # Get context length for display + ctx_len = None + if hasattr(self, 'agent') and self.agent and hasattr(self.agent, 'context_compressor'): + ctx_len = self.agent.context_compressor.context_length + # Build and display the banner build_welcome_banner( console=self.console, @@ -1087,6 +1105,7 @@ class HermesCLI: tools=tools, enabled_toolsets=self.enabled_toolsets, session_id=self.session_id, + context_length=ctx_len, ) # Show tool availability warnings if any tools are disabled diff --git a/hermes_cli/banner.py b/hermes_cli/banner.py index be1b3a95..127208e4 100644 --- a/hermes_cli/banner.py +++ b/hermes_cli/banner.py @@ -99,11 +99,23 @@ def get_available_skills() -> Dict[str, List[str]]: # Welcome banner # ========================================================================= +def _format_context_length(tokens: int) -> str: + """Format a token count for display (e.g. 128000 → '128K', 1048576 → '1M').""" + if tokens >= 1_000_000: + val = tokens / 1_000_000 + return f"{val:g}M" + elif tokens >= 1_000: + val = tokens / 1_000 + return f"{val:g}K" + return str(tokens) + + def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dict] = None, enabled_toolsets: List[str] = None, session_id: str = None, - get_toolset_for_tool=None): + get_toolset_for_tool=None, + context_length: int = None): """Build and print a welcome banner with caduceus on left and info on right. Args: @@ -114,6 +126,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str, enabled_toolsets: List of enabled toolset names. session_id: Session identifier. get_toolset_for_tool: Callable to map tool name -> toolset name. + context_length: Model's context window size in tokens. 
""" from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS if get_toolset_for_tool is None: @@ -135,7 +148,8 @@ def build_welcome_banner(console: Console, model: str, cwd: str, model_short = model.split("/")[-1] if "/" in model else model if len(model_short) > 28: model_short = model_short[:25] + "..." - left_lines.append(f"[#FFBF00]{model_short}[/] [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]") + ctx_str = f" [dim #B8860B]·[/] [dim #B8860B]{_format_context_length(context_length)} context[/]" if context_length else "" + left_lines.append(f"[#FFBF00]{model_short}[/]{ctx_str} [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]") left_lines.append(f"[dim #B8860B]{cwd}[/]") if session_id: left_lines.append(f"[dim #8B8682]Session: {session_id}[/]") diff --git a/run_agent.py b/run_agent.py index e02c5fa3..c320b0fa 100644 --- a/run_agent.py +++ b/run_agent.py @@ -82,6 +82,8 @@ from agent.prompt_builder import ( from agent.model_metadata import ( fetch_model_metadata, get_model_context_length, estimate_tokens_rough, estimate_messages_tokens_rough, + get_next_probe_tier, parse_context_limit_from_error, + save_context_length, ) from agent.context_compressor import ContextCompressor from agent.prompt_caching import apply_anthropic_cache_control @@ -536,6 +538,7 @@ class AIAgent: summary_target_tokens=500, summary_model_override=compression_summary_model, quiet_mode=self.quiet_mode, + base_url=self.base_url, ) self.compression_enabled = compression_enabled self._user_turn_count = 0 @@ -3236,6 +3239,13 @@ class AIAgent: } self.context_compressor.update_from_response(usage_dict) + # Cache discovered context length after successful call + if self.context_compressor._context_probed: + ctx = self.context_compressor.context_length + save_context_length(self.model, self.base_url, ctx) + print(f"{self.log_prefix}💾 Cached context length: {ctx:,} tokens for {self.model}") + self.context_compressor._context_probed = False + self.session_prompt_tokens += prompt_tokens self.session_completion_tokens += completion_tokens self.session_total_tokens += total_tokens @@ -3355,18 +3365,37 @@ class AIAgent: ]) if is_context_length_error: - print(f"{self.log_prefix}⚠️ Context length exceeded - attempting compression...") + compressor = self.context_compressor + old_ctx = compressor.context_length + + # Try to parse the actual limit from the error message + parsed_limit = parse_context_limit_from_error(error_msg) + if parsed_limit and parsed_limit < old_ctx: + new_ctx = parsed_limit + print(f"{self.log_prefix}⚠️ Context limit detected from API: {new_ctx:,} tokens (was {old_ctx:,})") + else: + # Step down to the next probe tier + new_ctx = get_next_probe_tier(old_ctx) + + if new_ctx and new_ctx < old_ctx: + compressor.context_length = new_ctx + compressor.threshold_tokens = int(new_ctx * compressor.threshold_percent) + compressor._context_probed = True + print(f"{self.log_prefix}⚠️ Context length exceeded — stepping down: {old_ctx:,} → {new_ctx:,} tokens") + else: + print(f"{self.log_prefix}⚠️ Context length exceeded at minimum tier — attempting compression...") original_len = len(messages) messages, active_system_prompt = self._compress_context( messages, system_message, approx_tokens=approx_tokens ) - if len(messages) < original_len: - print(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...") - continue # Retry with compressed messages + if len(messages) < original_len or new_ctx and new_ctx < old_ctx: + if len(messages) < original_len: + print(f"{self.log_prefix} 🗜️ 
Compressed {original_len} → {len(messages)} messages, retrying...") + continue # Retry with compressed messages or new tier else: - # Can't compress further + # Can't compress further and already at minimum tier print(f"{self.log_prefix}❌ Context length exceeded and cannot compress further.") print(f"{self.log_prefix} 💡 The conversation has accumulated too much content.") logging.error(f"{self.log_prefix}Context length exceeded: {approx_tokens:,} tokens. Cannot compress further.") diff --git a/tests/agent/test_model_metadata.py b/tests/agent/test_model_metadata.py index 404ee6b2..ffc98cb2 100644 --- a/tests/agent/test_model_metadata.py +++ b/tests/agent/test_model_metadata.py @@ -1,13 +1,22 @@ """Tests for agent/model_metadata.py — token estimation and context lengths.""" +import os +import tempfile + import pytest +import yaml from unittest.mock import patch, MagicMock from agent.model_metadata import ( + CONTEXT_PROBE_TIERS, DEFAULT_CONTEXT_LENGTHS, estimate_tokens_rough, estimate_messages_tokens_rough, get_model_context_length, + get_next_probe_tier, + get_cached_context_length, + parse_context_limit_from_error, + save_context_length, fetch_model_metadata, _MODEL_CACHE_TTL, ) @@ -101,10 +110,10 @@ class TestGetModelContextLength: assert result == 200000 @patch("agent.model_metadata.fetch_model_metadata") - def test_unknown_model_returns_128k(self, mock_fetch): + def test_unknown_model_returns_first_probe_tier(self, mock_fetch): mock_fetch.return_value = {} result = get_model_context_length("unknown/never-heard-of-this") - assert result == 128000 + assert result == CONTEXT_PROBE_TIERS[0] # 2M — will be narrowed on context error @patch("agent.model_metadata.fetch_model_metadata") def test_partial_match_in_defaults(self, mock_fetch): @@ -154,3 +163,123 @@ class TestFetchModelMetadata: mock_get.side_effect = Exception("Network error") result = fetch_model_metadata(force_refresh=True) assert result == {} + + +# ========================================================================= +# Context probe tiers +# ========================================================================= + +class TestContextProbeTiers: + def test_tiers_descending(self): + for i in range(len(CONTEXT_PROBE_TIERS) - 1): + assert CONTEXT_PROBE_TIERS[i] > CONTEXT_PROBE_TIERS[i + 1] + + def test_first_tier_is_2m(self): + assert CONTEXT_PROBE_TIERS[0] == 2_000_000 + + def test_last_tier_is_32k(self): + assert CONTEXT_PROBE_TIERS[-1] == 32_000 + + +class TestGetNextProbeTier: + def test_from_2m(self): + assert get_next_probe_tier(2_000_000) == 1_000_000 + + def test_from_1m(self): + assert get_next_probe_tier(1_000_000) == 512_000 + + def test_from_128k(self): + assert get_next_probe_tier(128_000) == 64_000 + + def test_from_32k_returns_none(self): + assert get_next_probe_tier(32_000) is None + + def test_from_below_min_returns_none(self): + assert get_next_probe_tier(16_000) is None + + def test_from_arbitrary_value(self): + # 300K is between 512K and 200K, should return 200K + assert get_next_probe_tier(300_000) == 200_000 + + +# ========================================================================= +# Error message parsing +# ========================================================================= + +class TestParseContextLimitFromError: + def test_openai_format(self): + msg = "This model's maximum context length is 32768 tokens. However, your messages resulted in 45000 tokens." 
+ assert parse_context_limit_from_error(msg) == 32768 + + def test_context_length_exceeded(self): + msg = "context_length_exceeded: maximum context length is 131072" + assert parse_context_limit_from_error(msg) == 131072 + + def test_context_size_exceeded(self): + msg = "Maximum context size 65536 exceeded" + assert parse_context_limit_from_error(msg) == 65536 + + def test_no_limit_in_message(self): + msg = "Something went wrong with the API" + assert parse_context_limit_from_error(msg) is None + + def test_unreasonable_number_rejected(self): + msg = "context length is 42 tokens" # too small + assert parse_context_limit_from_error(msg) is None + + def test_ollama_format(self): + msg = "Context size has been exceeded. Maximum context size is 32768" + assert parse_context_limit_from_error(msg) == 32768 + + +# ========================================================================= +# Persistent context length cache +# ========================================================================= + +class TestContextLengthCache: + def test_save_and_load(self, tmp_path): + cache_file = tmp_path / "context_length_cache.yaml" + with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): + save_context_length("test/model", "http://localhost:8080/v1", 32768) + result = get_cached_context_length("test/model", "http://localhost:8080/v1") + assert result == 32768 + + def test_missing_cache_returns_none(self, tmp_path): + cache_file = tmp_path / "nonexistent.yaml" + with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): + assert get_cached_context_length("test/model", "http://x") is None + + def test_multiple_models_cached(self, tmp_path): + cache_file = tmp_path / "context_length_cache.yaml" + with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): + save_context_length("model-a", "http://a", 64000) + save_context_length("model-b", "http://b", 128000) + assert get_cached_context_length("model-a", "http://a") == 64000 + assert get_cached_context_length("model-b", "http://b") == 128000 + + def test_same_model_different_providers(self, tmp_path): + cache_file = tmp_path / "context_length_cache.yaml" + with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): + save_context_length("llama-3", "http://local:8080", 32768) + save_context_length("llama-3", "https://openrouter.ai/api/v1", 131072) + assert get_cached_context_length("llama-3", "http://local:8080") == 32768 + assert get_cached_context_length("llama-3", "https://openrouter.ai/api/v1") == 131072 + + def test_idempotent_save(self, tmp_path): + cache_file = tmp_path / "context_length_cache.yaml" + with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): + save_context_length("model", "http://x", 32768) + save_context_length("model", "http://x", 32768) # same value + with open(cache_file) as f: + data = yaml.safe_load(f) + assert len(data["context_lengths"]) == 1 + + @patch("agent.model_metadata.fetch_model_metadata") + def test_cached_value_takes_priority(self, mock_fetch, tmp_path): + """Cached context length should be used before API or defaults.""" + mock_fetch.return_value = {} + cache_file = tmp_path / "context_length_cache.yaml" + with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): + save_context_length("unknown/model", "http://local", 65536) + result = get_model_context_length("unknown/model", base_url="http://local") + assert result == 65536
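
A minimal sketch (not part of the patch) of how the pieces introduced above are intended to fit together. It assumes the functions land in agent.model_metadata exactly as added in this diff; the helper name narrow_context_length and the surrounding usage loop are illustrative only — the real step-down logic lives in the AIAgent retry path in run_agent.py shown above.

# Illustrative sketch, not shipped code. Mirrors the step-down flow added to
# run_agent.py: prefer the limit parsed from the provider's error message,
# otherwise fall back to the next probe tier, and persist whatever value
# survives a successful call.
from typing import Optional

from agent.model_metadata import (
    get_model_context_length,
    parse_context_limit_from_error,
    get_next_probe_tier,
    save_context_length,
)


def narrow_context_length(current: int, error_msg: str) -> Optional[int]:
    """Hypothetical helper: choose a smaller context length after a context error.

    Returns the provider-reported limit when the error message contains one,
    else the next lower probe tier, or None when nothing smaller remains.
    """
    parsed = parse_context_limit_from_error(error_msg)
    if parsed and parsed < current:
        return parsed
    next_tier = get_next_probe_tier(current)
    if next_tier and next_tier < current:
        return next_tier
    return None


# Usage sketch: start optimistic for an unknown model (first probe tier, 2M),
# step down when the provider rejects the request, and cache the discovered
# value keyed as "model@base_url" so the next session skips the probing.
model, base_url = "some/unknown-model", "http://localhost:8080/v1"  # placeholder values
ctx = get_model_context_length(model, base_url=base_url)
error = "This model's maximum context length is 32768 tokens."
smaller = narrow_context_length(ctx, error)
if smaller is not None:
    save_context_length(model, base_url, smaller)  # persisted under HERMES_HOME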