The architecture has been updated

This commit is contained in:
Skyber_2 2026-03-31 23:31:36 +03:00
parent 805f7a017e
commit a01257ead9
1119 changed files with 226 additions and 352 deletions

View file

@ -0,0 +1,6 @@
"""Agent internals -- extracted modules from run_agent.py.
These modules contain pure utility functions and self-contained classes
that were previously embedded in the 3,600-line run_agent.py. Extracting
them makes run_agent.py focused on the AIAgent orchestrator class.
"""

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,658 @@
"""Automatic context window compression for long conversations.
Self-contained class with its own OpenAI client for summarization.
Uses auxiliary model (cheap/fast) to summarize middle turns while
protecting head and tail context.
Improvements over v1:
- Structured summary template (Goal, Progress, Decisions, Files, Next Steps)
- Iterative summary updates (preserves info across multiple compactions)
- Token-budget tail protection instead of fixed message count
- Tool output pruning before LLM summarization (cheap pre-pass)
- Scaled summary budget (proportional to compressed content)
- Richer tool call/result detail in summarizer input
"""
import logging
import os
from typing import Any, Dict, List, Optional
from agent.auxiliary_client import call_llm
from agent.model_metadata import (
get_model_context_length,
estimate_messages_tokens_rough,
)
logger = logging.getLogger(__name__)
SUMMARY_PREFIX = (
"[CONTEXT COMPACTION] Earlier turns in this conversation were compacted "
"to save context space. The summary below describes work that was "
"already completed, and the current session state may still reflect "
"that work (for example, files may already be changed). Use the summary "
"and the current state to continue from where things left off, and "
"avoid repeating work:"
)
LEGACY_SUMMARY_PREFIX = "[CONTEXT SUMMARY]:"
# Minimum / maximum tokens for the summary output
_MIN_SUMMARY_TOKENS = 2000
_MAX_SUMMARY_TOKENS = 8000
# Proportion of compressed content to allocate for summary
_SUMMARY_RATIO = 0.20
# Token budget for tail protection (keep most-recent context)
_DEFAULT_TAIL_TOKEN_BUDGET = 20_000
# Placeholder used when pruning old tool results
_PRUNED_TOOL_PLACEHOLDER = "[Old tool output cleared to save context space]"
# Chars per token rough estimate
_CHARS_PER_TOKEN = 4
class ContextCompressor:
"""Compresses conversation context when approaching the model's context limit.
Algorithm:
1. Prune old tool results (cheap, no LLM call)
2. Protect head messages (system prompt + first exchange)
3. Protect tail messages by token budget (most recent ~20K tokens)
4. Summarize middle turns with structured LLM prompt
5. On subsequent compactions, iteratively update the previous summary
"""
def __init__(
self,
model: str,
threshold_percent: float = 0.50,
protect_first_n: int = 3,
protect_last_n: int = 4,
summary_target_tokens: int = 2500,
quiet_mode: bool = False,
summary_model_override: str = None,
base_url: str = "",
api_key: str = "",
config_context_length: int | None = None,
provider: str = "",
):
self.model = model
self.base_url = base_url
self.api_key = api_key
self.provider = provider
self.threshold_percent = threshold_percent
self.protect_first_n = protect_first_n
self.protect_last_n = protect_last_n
self.summary_target_tokens = summary_target_tokens
self.quiet_mode = quiet_mode
self.context_length = get_model_context_length(
model, base_url=base_url, api_key=api_key,
config_context_length=config_context_length,
provider=provider,
)
self.threshold_tokens = int(self.context_length * threshold_percent)
self.compression_count = 0
if not quiet_mode:
logger.info(
"Context compressor initialized: model=%s context_length=%d "
"threshold=%d (%.0f%%) provider=%s base_url=%s",
model, self.context_length, self.threshold_tokens,
threshold_percent * 100, provider or "none", base_url or "none",
)
self._context_probed = False # True after a step-down from context error
self.last_prompt_tokens = 0
self.last_completion_tokens = 0
self.last_total_tokens = 0
self.summary_model = summary_model_override or ""
# Stores the previous compaction summary for iterative updates
self._previous_summary: Optional[str] = None
def update_from_response(self, usage: Dict[str, Any]):
"""Update tracked token usage from API response."""
self.last_prompt_tokens = usage.get("prompt_tokens", 0)
self.last_completion_tokens = usage.get("completion_tokens", 0)
self.last_total_tokens = usage.get("total_tokens", 0)
def should_compress(self, prompt_tokens: int = None) -> bool:
"""Check if context exceeds the compression threshold."""
tokens = prompt_tokens if prompt_tokens is not None else self.last_prompt_tokens
return tokens >= self.threshold_tokens
def should_compress_preflight(self, messages: List[Dict[str, Any]]) -> bool:
"""Quick pre-flight check using rough estimate (before API call)."""
rough_estimate = estimate_messages_tokens_rough(messages)
return rough_estimate >= self.threshold_tokens
def get_status(self) -> Dict[str, Any]:
"""Get current compression status for display/logging."""
return {
"last_prompt_tokens": self.last_prompt_tokens,
"threshold_tokens": self.threshold_tokens,
"context_length": self.context_length,
"usage_percent": (self.last_prompt_tokens / self.context_length * 100) if self.context_length else 0,
"compression_count": self.compression_count,
}
# ------------------------------------------------------------------
# Tool output pruning (cheap pre-pass, no LLM call)
# ------------------------------------------------------------------
def _prune_old_tool_results(
self, messages: List[Dict[str, Any]], protect_tail_count: int,
) -> tuple[List[Dict[str, Any]], int]:
"""Replace old tool result contents with a short placeholder.
Walks backward from the end, protecting the most recent
``protect_tail_count`` messages. Older tool results get their
content replaced with a placeholder string.
Returns (pruned_messages, pruned_count).
"""
if not messages:
return messages, 0
result = [m.copy() for m in messages]
pruned = 0
prune_boundary = len(result) - protect_tail_count
for i in range(prune_boundary):
msg = result[i]
if msg.get("role") != "tool":
continue
content = msg.get("content", "")
if not content or content == _PRUNED_TOOL_PLACEHOLDER:
continue
# Only prune if the content is substantial (>200 chars)
if len(content) > 200:
result[i] = {**msg, "content": _PRUNED_TOOL_PLACEHOLDER}
pruned += 1
return result, pruned
# ------------------------------------------------------------------
# Summarization
# ------------------------------------------------------------------
def _compute_summary_budget(self, turns_to_summarize: List[Dict[str, Any]]) -> int:
"""Scale summary token budget with the amount of content being compressed."""
content_tokens = estimate_messages_tokens_rough(turns_to_summarize)
budget = int(content_tokens * _SUMMARY_RATIO)
return max(_MIN_SUMMARY_TOKENS, min(budget, _MAX_SUMMARY_TOKENS))
def _serialize_for_summary(self, turns: List[Dict[str, Any]]) -> str:
"""Serialize conversation turns into labeled text for the summarizer.
Includes tool call arguments and result content (up to 3000 chars
per message) so the summarizer can preserve specific details like
file paths, commands, and outputs.
"""
parts = []
for msg in turns:
role = msg.get("role", "unknown")
content = msg.get("content") or ""
# Tool results: keep more content than before (3000 chars)
if role == "tool":
tool_id = msg.get("tool_call_id", "")
if len(content) > 3000:
content = content[:2000] + "\n...[truncated]...\n" + content[-800:]
parts.append(f"[TOOL RESULT {tool_id}]: {content}")
continue
# Assistant messages: include tool call names AND arguments
if role == "assistant":
if len(content) > 3000:
content = content[:2000] + "\n...[truncated]...\n" + content[-800:]
tool_calls = msg.get("tool_calls", [])
if tool_calls:
tc_parts = []
for tc in tool_calls:
if isinstance(tc, dict):
fn = tc.get("function", {})
name = fn.get("name", "?")
args = fn.get("arguments", "")
# Truncate long arguments but keep enough for context
if len(args) > 500:
args = args[:400] + "..."
tc_parts.append(f" {name}({args})")
else:
fn = getattr(tc, "function", None)
name = getattr(fn, "name", "?") if fn else "?"
tc_parts.append(f" {name}(...)")
content += "\n[Tool calls:\n" + "\n".join(tc_parts) + "\n]"
parts.append(f"[ASSISTANT]: {content}")
continue
# User and other roles
if len(content) > 3000:
content = content[:2000] + "\n...[truncated]...\n" + content[-800:]
parts.append(f"[{role.upper()}]: {content}")
return "\n\n".join(parts)
def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> Optional[str]:
"""Generate a structured summary of conversation turns.
Uses a structured template (Goal, Progress, Decisions, Files, Next Steps)
inspired by Pi-mono and OpenCode. When a previous summary exists,
generates an iterative update instead of summarizing from scratch.
Returns None if all attempts fail the caller should drop
the middle turns without a summary rather than inject a useless
placeholder.
"""
summary_budget = self._compute_summary_budget(turns_to_summarize)
content_to_summarize = self._serialize_for_summary(turns_to_summarize)
if self._previous_summary:
# Iterative update: preserve existing info, add new progress
prompt = f"""You are updating a context compaction summary. A previous compaction produced the summary below. New conversation turns have occurred since then and need to be incorporated.
PREVIOUS SUMMARY:
{self._previous_summary}
NEW TURNS TO INCORPORATE:
{content_to_summarize}
Update the summary using this exact structure. PRESERVE all existing information that is still relevant. ADD new progress. Move items from "In Progress" to "Done" when completed. Remove information only if it is clearly obsolete.
## Goal
[What the user is trying to accomplish preserve from previous summary, update if goal evolved]
## Constraints & Preferences
[User preferences, coding style, constraints, important decisions accumulate across compactions]
## Progress
### Done
[Completed work include specific file paths, commands run, results obtained]
### In Progress
[Work currently underway]
### Blocked
[Any blockers or issues encountered]
## Key Decisions
[Important technical decisions and why they were made]
## Relevant Files
[Files read, modified, or created with brief note on each. Accumulate across compactions.]
## Next Steps
[What needs to happen next to continue the work]
## Critical Context
[Any specific values, error messages, configuration details, or data that would be lost without explicit preservation]
Target ~{summary_budget} tokens. Be specific include file paths, command outputs, error messages, and concrete values rather than vague descriptions.
Write only the summary body. Do not include any preamble or prefix."""
else:
# First compaction: summarize from scratch
prompt = f"""Create a structured handoff summary for a later assistant that will continue this conversation after earlier turns are compacted.
TURNS TO SUMMARIZE:
{content_to_summarize}
Use this exact structure:
## Goal
[What the user is trying to accomplish]
## Constraints & Preferences
[User preferences, coding style, constraints, important decisions]
## Progress
### Done
[Completed work include specific file paths, commands run, results obtained]
### In Progress
[Work currently underway]
### Blocked
[Any blockers or issues encountered]
## Key Decisions
[Important technical decisions and why they were made]
## Relevant Files
[Files read, modified, or created with brief note on each]
## Next Steps
[What needs to happen next to continue the work]
## Critical Context
[Any specific values, error messages, configuration details, or data that would be lost without explicit preservation]
Target ~{summary_budget} tokens. Be specific include file paths, command outputs, error messages, and concrete values rather than vague descriptions. The goal is to prevent the next assistant from repeating work or losing important details.
Write only the summary body. Do not include any preamble or prefix."""
try:
call_kwargs = {
"task": "compression",
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.3,
"max_tokens": summary_budget * 2,
"timeout": 45.0,
}
if self.summary_model:
call_kwargs["model"] = self.summary_model
response = call_llm(**call_kwargs)
content = response.choices[0].message.content
# Handle cases where content is not a string (e.g., dict from llama.cpp)
if not isinstance(content, str):
content = str(content) if content else ""
summary = content.strip()
# Store for iterative updates on next compaction
self._previous_summary = summary
return self._with_summary_prefix(summary)
except RuntimeError:
logging.warning("Context compression: no provider available for "
"summary. Middle turns will be dropped without summary.")
return None
except Exception as e:
logging.warning("Failed to generate context summary: %s", e)
return None
@staticmethod
def _with_summary_prefix(summary: str) -> str:
"""Normalize summary text to the current compaction handoff format."""
text = (summary or "").strip()
for prefix in (LEGACY_SUMMARY_PREFIX, SUMMARY_PREFIX):
if text.startswith(prefix):
text = text[len(prefix):].lstrip()
break
return f"{SUMMARY_PREFIX}\n{text}" if text else SUMMARY_PREFIX
# ------------------------------------------------------------------
# Tool-call / tool-result pair integrity helpers
# ------------------------------------------------------------------
@staticmethod
def _get_tool_call_id(tc) -> str:
"""Extract the call ID from a tool_call entry (dict or SimpleNamespace)."""
if isinstance(tc, dict):
return tc.get("id", "")
return getattr(tc, "id", "") or ""
def _sanitize_tool_pairs(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Fix orphaned tool_call / tool_result pairs after compression.
Two failure modes:
1. A tool *result* references a call_id whose assistant tool_call was
removed (summarized/truncated). The API rejects this with
"No tool call found for function call output with call_id ...".
2. An assistant message has tool_calls whose results were dropped.
The API rejects this because every tool_call must be followed by
a tool result with the matching call_id.
This method removes orphaned results and inserts stub results for
orphaned calls so the message list is always well-formed.
"""
surviving_call_ids: set = set()
for msg in messages:
if msg.get("role") == "assistant":
for tc in msg.get("tool_calls") or []:
cid = self._get_tool_call_id(tc)
if cid:
surviving_call_ids.add(cid)
result_call_ids: set = set()
for msg in messages:
if msg.get("role") == "tool":
cid = msg.get("tool_call_id")
if cid:
result_call_ids.add(cid)
# 1. Remove tool results whose call_id has no matching assistant tool_call
orphaned_results = result_call_ids - surviving_call_ids
if orphaned_results:
messages = [
m for m in messages
if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results)
]
if not self.quiet_mode:
logger.info("Compression sanitizer: removed %d orphaned tool result(s)", len(orphaned_results))
# 2. Add stub results for assistant tool_calls whose results were dropped
missing_results = surviving_call_ids - result_call_ids
if missing_results:
patched: List[Dict[str, Any]] = []
for msg in messages:
patched.append(msg)
if msg.get("role") == "assistant":
for tc in msg.get("tool_calls") or []:
cid = self._get_tool_call_id(tc)
if cid in missing_results:
patched.append({
"role": "tool",
"content": "[Result from earlier conversation — see context summary above]",
"tool_call_id": cid,
})
messages = patched
if not self.quiet_mode:
logger.info("Compression sanitizer: added %d stub tool result(s)", len(missing_results))
return messages
def _align_boundary_forward(self, messages: List[Dict[str, Any]], idx: int) -> int:
"""Push a compress-start boundary forward past any orphan tool results.
If ``messages[idx]`` is a tool result, slide forward until we hit a
non-tool message so we don't start the summarised region mid-group.
"""
while idx < len(messages) and messages[idx].get("role") == "tool":
idx += 1
return idx
def _align_boundary_backward(self, messages: List[Dict[str, Any]], idx: int) -> int:
"""Pull a compress-end boundary backward to avoid splitting a
tool_call / result group.
If the boundary falls in the middle of a tool-result group (i.e.
there are consecutive tool messages before ``idx``), walk backward
past all of them to find the parent assistant message. If found,
move the boundary before the assistant so the entire
assistant + tool_results group is included in the summarised region
rather than being split (which causes silent data loss when
``_sanitize_tool_pairs`` removes the orphaned tail results).
"""
if idx <= 0 or idx >= len(messages):
return idx
# Walk backward past consecutive tool results
check = idx - 1
while check >= 0 and messages[check].get("role") == "tool":
check -= 1
# If we landed on the parent assistant with tool_calls, pull the
# boundary before it so the whole group gets summarised together.
if check >= 0 and messages[check].get("role") == "assistant" and messages[check].get("tool_calls"):
idx = check
return idx
# ------------------------------------------------------------------
# Tail protection by token budget
# ------------------------------------------------------------------
def _find_tail_cut_by_tokens(
self, messages: List[Dict[str, Any]], head_end: int,
token_budget: int = _DEFAULT_TAIL_TOKEN_BUDGET,
) -> int:
"""Walk backward from the end of messages, accumulating tokens until
the budget is reached. Returns the index where the tail starts.
Never cuts inside a tool_call/result group. Falls back to the old
``protect_last_n`` if the budget would protect fewer messages.
"""
n = len(messages)
min_tail = self.protect_last_n
accumulated = 0
cut_idx = n # start from beyond the end
for i in range(n - 1, head_end - 1, -1):
msg = messages[i]
content = msg.get("content") or ""
msg_tokens = len(content) // _CHARS_PER_TOKEN + 10 # +10 for role/metadata
# Include tool call arguments in estimate
for tc in msg.get("tool_calls") or []:
if isinstance(tc, dict):
args = tc.get("function", {}).get("arguments", "")
msg_tokens += len(args) // _CHARS_PER_TOKEN
if accumulated + msg_tokens > token_budget and (n - i) >= min_tail:
break
accumulated += msg_tokens
cut_idx = i
# Ensure we protect at least protect_last_n messages
fallback_cut = n - min_tail
if cut_idx > fallback_cut:
cut_idx = fallback_cut
# If the token budget would protect everything (small conversations),
# fall back to the fixed protect_last_n approach so compression can
# still remove middle turns.
if cut_idx <= head_end:
cut_idx = fallback_cut
# Align to avoid splitting tool groups
cut_idx = self._align_boundary_backward(messages, cut_idx)
return max(cut_idx, head_end + 1)
# ------------------------------------------------------------------
# Main compression entry point
# ------------------------------------------------------------------
def compress(self, messages: List[Dict[str, Any]], current_tokens: int = None) -> List[Dict[str, Any]]:
"""Compress conversation messages by summarizing middle turns.
Algorithm:
1. Prune old tool results (cheap pre-pass, no LLM call)
2. Protect head messages (system prompt + first exchange)
3. Find tail boundary by token budget (~20K tokens of recent context)
4. Summarize middle turns with structured LLM prompt
5. On re-compression, iteratively update the previous summary
After compression, orphaned tool_call / tool_result pairs are cleaned
up so the API never receives mismatched IDs.
"""
n_messages = len(messages)
if n_messages <= self.protect_first_n + self.protect_last_n + 1:
if not self.quiet_mode:
logger.warning(
"Cannot compress: only %d messages (need > %d)",
n_messages,
self.protect_first_n + self.protect_last_n + 1,
)
return messages
display_tokens = current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages)
# Phase 1: Prune old tool results (cheap, no LLM call)
messages, pruned_count = self._prune_old_tool_results(
messages, protect_tail_count=self.protect_last_n * 3,
)
if pruned_count and not self.quiet_mode:
logger.info("Pre-compression: pruned %d old tool result(s)", pruned_count)
# Phase 2: Determine boundaries
compress_start = self.protect_first_n
compress_start = self._align_boundary_forward(messages, compress_start)
# Use token-budget tail protection instead of fixed message count
compress_end = self._find_tail_cut_by_tokens(messages, compress_start)
if compress_start >= compress_end:
return messages
turns_to_summarize = messages[compress_start:compress_end]
if not self.quiet_mode:
logger.info(
"Context compression triggered (%d tokens >= %d threshold)",
display_tokens,
self.threshold_tokens,
)
logger.info(
"Model context limit: %d tokens (%.0f%% = %d)",
self.context_length,
self.threshold_percent * 100,
self.threshold_tokens,
)
tail_msgs = n_messages - compress_end
logger.info(
"Summarizing turns %d-%d (%d turns), protecting %d head + %d tail messages",
compress_start + 1,
compress_end,
len(turns_to_summarize),
compress_start,
tail_msgs,
)
# Phase 3: Generate structured summary
summary = self._generate_summary(turns_to_summarize)
# Phase 4: Assemble compressed message list
compressed = []
for i in range(compress_start):
msg = messages[i].copy()
if i == 0 and msg.get("role") == "system" and self.compression_count == 0:
msg["content"] = (
(msg.get("content") or "")
+ "\n\n[Note: Some earlier conversation turns have been compacted into a handoff summary to preserve context space. The current session state may still reflect earlier work, so build on that summary and state rather than re-doing work.]"
)
compressed.append(msg)
_merge_summary_into_tail = False
if summary:
last_head_role = messages[compress_start - 1].get("role", "user") if compress_start > 0 else "user"
first_tail_role = messages[compress_end].get("role", "user") if compress_end < n_messages else "user"
# Pick a role that avoids consecutive same-role with both neighbors.
# Priority: avoid colliding with head (already committed), then tail.
if last_head_role in ("assistant", "tool"):
summary_role = "user"
else:
summary_role = "assistant"
# If the chosen role collides with the tail AND flipping wouldn't
# collide with the head, flip it.
if summary_role == first_tail_role:
flipped = "assistant" if summary_role == "user" else "user"
if flipped != last_head_role:
summary_role = flipped
else:
# Both roles would create consecutive same-role messages
# (e.g. head=assistant, tail=user — neither role works).
# Merge the summary into the first tail message instead
# of inserting a standalone message that breaks alternation.
_merge_summary_into_tail = True
if not _merge_summary_into_tail:
compressed.append({"role": summary_role, "content": summary})
else:
if not self.quiet_mode:
logger.warning("No summary model available — middle turns dropped without summary")
for i in range(compress_end, n_messages):
msg = messages[i].copy()
if _merge_summary_into_tail and i == compress_end:
original = msg.get("content") or ""
msg["content"] = summary + "\n\n" + original
_merge_summary_into_tail = False
compressed.append(msg)
self.compression_count += 1
compressed = self._sanitize_tool_pairs(compressed)
if not self.quiet_mode:
new_estimate = estimate_messages_tokens_rough(compressed)
saved_estimate = display_tokens - new_estimate
logger.info(
"Compressed: %d -> %d messages (~%d tokens saved)",
n_messages,
len(compressed),
saved_estimate,
)
logger.info("Compression #%d complete", self.compression_count)
return compressed

View file

@ -0,0 +1,485 @@
from __future__ import annotations
import asyncio
import inspect
import json
import mimetypes
import os
import re
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from typing import Awaitable, Callable
from agent.model_metadata import estimate_tokens_rough
REFERENCE_PATTERN = re.compile(
r"(?<![\w/])@(?:(?P<simple>diff|staged)\b|(?P<kind>file|folder|git|url):(?P<value>\S+))"
)
TRAILING_PUNCTUATION = ",.;!?"
_SENSITIVE_HOME_DIRS = (".ssh", ".aws", ".gnupg", ".kube")
_SENSITIVE_HERMES_DIRS = (Path("skills") / ".hub",)
_SENSITIVE_HOME_FILES = (
Path(".ssh") / "authorized_keys",
Path(".ssh") / "id_rsa",
Path(".ssh") / "id_ed25519",
Path(".ssh") / "config",
Path(".bashrc"),
Path(".zshrc"),
Path(".profile"),
Path(".bash_profile"),
Path(".zprofile"),
Path(".netrc"),
Path(".pgpass"),
Path(".npmrc"),
Path(".pypirc"),
)
@dataclass(frozen=True)
class ContextReference:
raw: str
kind: str
target: str
start: int
end: int
line_start: int | None = None
line_end: int | None = None
@dataclass
class ContextReferenceResult:
message: str
original_message: str
references: list[ContextReference] = field(default_factory=list)
warnings: list[str] = field(default_factory=list)
injected_tokens: int = 0
expanded: bool = False
blocked: bool = False
def parse_context_references(message: str) -> list[ContextReference]:
refs: list[ContextReference] = []
if not message:
return refs
for match in REFERENCE_PATTERN.finditer(message):
simple = match.group("simple")
if simple:
refs.append(
ContextReference(
raw=match.group(0),
kind=simple,
target="",
start=match.start(),
end=match.end(),
)
)
continue
kind = match.group("kind")
value = _strip_trailing_punctuation(match.group("value") or "")
line_start = None
line_end = None
target = value
if kind == "file":
range_match = re.match(r"^(?P<path>.+?):(?P<start>\d+)(?:-(?P<end>\d+))?$", value)
if range_match:
target = range_match.group("path")
line_start = int(range_match.group("start"))
line_end = int(range_match.group("end") or range_match.group("start"))
refs.append(
ContextReference(
raw=match.group(0),
kind=kind,
target=target,
start=match.start(),
end=match.end(),
line_start=line_start,
line_end=line_end,
)
)
return refs
def preprocess_context_references(
message: str,
*,
cwd: str | Path,
context_length: int,
url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
allowed_root: str | Path | None = None,
) -> ContextReferenceResult:
coro = preprocess_context_references_async(
message,
cwd=cwd,
context_length=context_length,
url_fetcher=url_fetcher,
allowed_root=allowed_root,
)
# Safe for both CLI (no loop) and gateway (loop already running).
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
return pool.submit(asyncio.run, coro).result()
return asyncio.run(coro)
async def preprocess_context_references_async(
message: str,
*,
cwd: str | Path,
context_length: int,
url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
allowed_root: str | Path | None = None,
) -> ContextReferenceResult:
refs = parse_context_references(message)
if not refs:
return ContextReferenceResult(message=message, original_message=message)
cwd_path = Path(cwd).expanduser().resolve()
# Default to the current working directory so @ references cannot escape
# the active workspace unless a caller explicitly widens the root.
allowed_root_path = (
Path(allowed_root).expanduser().resolve() if allowed_root is not None else cwd_path
)
warnings: list[str] = []
blocks: list[str] = []
injected_tokens = 0
for ref in refs:
warning, block = await _expand_reference(
ref,
cwd_path,
url_fetcher=url_fetcher,
allowed_root=allowed_root_path,
)
if warning:
warnings.append(warning)
if block:
blocks.append(block)
injected_tokens += estimate_tokens_rough(block)
hard_limit = max(1, int(context_length * 0.50))
soft_limit = max(1, int(context_length * 0.25))
if injected_tokens > hard_limit:
warnings.append(
f"@ context injection refused: {injected_tokens} tokens exceeds the 50% hard limit ({hard_limit})."
)
return ContextReferenceResult(
message=message,
original_message=message,
references=refs,
warnings=warnings,
injected_tokens=injected_tokens,
expanded=False,
blocked=True,
)
if injected_tokens > soft_limit:
warnings.append(
f"@ context injection warning: {injected_tokens} tokens exceeds the 25% soft limit ({soft_limit})."
)
stripped = _remove_reference_tokens(message, refs)
final = stripped
if warnings:
final = f"{final}\n\n--- Context Warnings ---\n" + "\n".join(f"- {warning}" for warning in warnings)
if blocks:
final = f"{final}\n\n--- Attached Context ---\n\n" + "\n\n".join(blocks)
return ContextReferenceResult(
message=final.strip(),
original_message=message,
references=refs,
warnings=warnings,
injected_tokens=injected_tokens,
expanded=bool(blocks or warnings),
blocked=False,
)
async def _expand_reference(
ref: ContextReference,
cwd: Path,
*,
url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
allowed_root: Path | None = None,
) -> tuple[str | None, str | None]:
try:
if ref.kind == "file":
return _expand_file_reference(ref, cwd, allowed_root=allowed_root)
if ref.kind == "folder":
return _expand_folder_reference(ref, cwd, allowed_root=allowed_root)
if ref.kind == "diff":
return _expand_git_reference(ref, cwd, ["diff"], "git diff")
if ref.kind == "staged":
return _expand_git_reference(ref, cwd, ["diff", "--staged"], "git diff --staged")
if ref.kind == "git":
count = max(1, min(int(ref.target or "1"), 10))
return _expand_git_reference(ref, cwd, ["log", f"-{count}", "-p"], f"git log -{count} -p")
if ref.kind == "url":
content = await _fetch_url_content(ref.target, url_fetcher=url_fetcher)
if not content:
return f"{ref.raw}: no content extracted", None
return None, f"🌐 {ref.raw} ({estimate_tokens_rough(content)} tokens)\n{content}"
except Exception as exc:
return f"{ref.raw}: {exc}", None
return f"{ref.raw}: unsupported reference type", None
def _expand_file_reference(
ref: ContextReference,
cwd: Path,
*,
allowed_root: Path | None = None,
) -> tuple[str | None, str | None]:
path = _resolve_path(cwd, ref.target, allowed_root=allowed_root)
_ensure_reference_path_allowed(path)
if not path.exists():
return f"{ref.raw}: file not found", None
if not path.is_file():
return f"{ref.raw}: path is not a file", None
if _is_binary_file(path):
return f"{ref.raw}: binary files are not supported", None
text = path.read_text(encoding="utf-8")
if ref.line_start is not None:
lines = text.splitlines()
start_idx = max(ref.line_start - 1, 0)
end_idx = min(ref.line_end or ref.line_start, len(lines))
text = "\n".join(lines[start_idx:end_idx])
lang = _code_fence_language(path)
label = ref.raw
return None, f"📄 {label} ({estimate_tokens_rough(text)} tokens)\n```{lang}\n{text}\n```"
def _expand_folder_reference(
ref: ContextReference,
cwd: Path,
*,
allowed_root: Path | None = None,
) -> tuple[str | None, str | None]:
path = _resolve_path(cwd, ref.target, allowed_root=allowed_root)
_ensure_reference_path_allowed(path)
if not path.exists():
return f"{ref.raw}: folder not found", None
if not path.is_dir():
return f"{ref.raw}: path is not a folder", None
listing = _build_folder_listing(path, cwd)
return None, f"📁 {ref.raw} ({estimate_tokens_rough(listing)} tokens)\n{listing}"
def _expand_git_reference(
ref: ContextReference,
cwd: Path,
args: list[str],
label: str,
) -> tuple[str | None, str | None]:
result = subprocess.run(
["git", *args],
cwd=cwd,
capture_output=True,
text=True,
)
if result.returncode != 0:
stderr = (result.stderr or "").strip() or "git command failed"
return f"{ref.raw}: {stderr}", None
content = result.stdout.strip()
if not content:
content = "(no output)"
return None, f"🧾 {label} ({estimate_tokens_rough(content)} tokens)\n```diff\n{content}\n```"
async def _fetch_url_content(
url: str,
*,
url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
) -> str:
fetcher = url_fetcher or _default_url_fetcher
content = fetcher(url)
if inspect.isawaitable(content):
content = await content
return str(content or "").strip()
async def _default_url_fetcher(url: str) -> str:
from tools.web_tools import web_extract_tool
raw = await web_extract_tool([url], format="markdown", use_llm_processing=True)
payload = json.loads(raw)
docs = payload.get("data", {}).get("documents", [])
if not docs:
return ""
doc = docs[0]
return str(doc.get("content") or doc.get("raw_content") or "").strip()
def _resolve_path(cwd: Path, target: str, *, allowed_root: Path | None = None) -> Path:
path = Path(os.path.expanduser(target))
if not path.is_absolute():
path = cwd / path
resolved = path.resolve()
if allowed_root is not None:
try:
resolved.relative_to(allowed_root)
except ValueError as exc:
raise ValueError("path is outside the allowed workspace") from exc
return resolved
def _ensure_reference_path_allowed(path: Path) -> None:
home = Path(os.path.expanduser("~")).resolve()
hermes_home = Path(
os.getenv("HERMES_HOME", str(home / ".hermes"))
).expanduser().resolve()
blocked_exact = {home / rel for rel in _SENSITIVE_HOME_FILES}
blocked_exact.add(hermes_home / ".env")
blocked_dirs = [home / rel for rel in _SENSITIVE_HOME_DIRS]
blocked_dirs.extend(hermes_home / rel for rel in _SENSITIVE_HERMES_DIRS)
if path in blocked_exact:
raise ValueError("path is a sensitive credential file and cannot be attached")
for blocked_dir in blocked_dirs:
try:
path.relative_to(blocked_dir)
except ValueError:
continue
raise ValueError("path is a sensitive credential or internal Hermes path and cannot be attached")
def _strip_trailing_punctuation(value: str) -> str:
stripped = value.rstrip(TRAILING_PUNCTUATION)
while stripped.endswith((")", "]", "}")):
closer = stripped[-1]
opener = {")": "(", "]": "[", "}": "{"}[closer]
if stripped.count(closer) > stripped.count(opener):
stripped = stripped[:-1]
continue
break
return stripped
def _remove_reference_tokens(message: str, refs: list[ContextReference]) -> str:
pieces: list[str] = []
cursor = 0
for ref in refs:
pieces.append(message[cursor:ref.start])
cursor = ref.end
pieces.append(message[cursor:])
text = "".join(pieces)
text = re.sub(r"\s{2,}", " ", text)
text = re.sub(r"\s+([,.;:!?])", r"\1", text)
return text.strip()
def _is_binary_file(path: Path) -> bool:
mime, _ = mimetypes.guess_type(path.name)
if mime and not mime.startswith("text/") and not any(
path.name.endswith(ext) for ext in (".py", ".md", ".txt", ".json", ".yaml", ".yml", ".toml", ".js", ".ts")
):
return True
chunk = path.read_bytes()[:4096]
return b"\x00" in chunk
def _build_folder_listing(path: Path, cwd: Path, limit: int = 200) -> str:
lines = [f"{path.relative_to(cwd)}/"]
entries = _iter_visible_entries(path, cwd, limit=limit)
for entry in entries:
rel = entry.relative_to(cwd)
indent = " " * max(len(rel.parts) - len(path.relative_to(cwd).parts) - 1, 0)
if entry.is_dir():
lines.append(f"{indent}- {entry.name}/")
else:
meta = _file_metadata(entry)
lines.append(f"{indent}- {entry.name} ({meta})")
if len(entries) >= limit:
lines.append("- ...")
return "\n".join(lines)
def _iter_visible_entries(path: Path, cwd: Path, limit: int) -> list[Path]:
rg_entries = _rg_files(path, cwd, limit=limit)
if rg_entries is not None:
output: list[Path] = []
seen_dirs: set[Path] = set()
for rel in rg_entries:
full = cwd / rel
for parent in full.parents:
if parent == cwd or parent in seen_dirs or path not in {parent, *parent.parents}:
continue
seen_dirs.add(parent)
output.append(parent)
output.append(full)
return sorted({p for p in output if p.exists()}, key=lambda p: (not p.is_dir(), str(p)))
output = []
for root, dirs, files in os.walk(path):
dirs[:] = sorted(d for d in dirs if not d.startswith(".") and d != "__pycache__")
files = sorted(f for f in files if not f.startswith("."))
root_path = Path(root)
for d in dirs:
output.append(root_path / d)
if len(output) >= limit:
return output
for f in files:
output.append(root_path / f)
if len(output) >= limit:
return output
return output
def _rg_files(path: Path, cwd: Path, limit: int) -> list[Path] | None:
try:
result = subprocess.run(
["rg", "--files", str(path.relative_to(cwd))],
cwd=cwd,
capture_output=True,
text=True,
)
except FileNotFoundError:
return None
if result.returncode != 0:
return None
files = [Path(line.strip()) for line in result.stdout.splitlines() if line.strip()]
return files[:limit]
def _file_metadata(path: Path) -> str:
if _is_binary_file(path):
return f"{path.stat().st_size} bytes"
try:
line_count = path.read_text(encoding="utf-8").count("\n") + 1
except Exception:
return f"{path.stat().st_size} bytes"
return f"{line_count} lines"
def _code_fence_language(path: Path) -> str:
mapping = {
".py": "python",
".js": "javascript",
".ts": "typescript",
".tsx": "tsx",
".jsx": "jsx",
".json": "json",
".md": "markdown",
".sh": "bash",
".yml": "yaml",
".yaml": "yaml",
".toml": "toml",
}
return mapping.get(path.suffix.lower(), "")

View file

@ -0,0 +1,447 @@
"""OpenAI-compatible shim that forwards Hermes requests to `copilot --acp`.
This adapter lets Hermes treat the GitHub Copilot ACP server as a chat-style
backend. Each request starts a short-lived ACP session, sends the formatted
conversation as a single prompt, collects text chunks, and converts the result
back into the minimal shape Hermes expects from an OpenAI client.
"""
from __future__ import annotations
import json
import os
import queue
import shlex
import subprocess
import threading
import time
from collections import deque
from pathlib import Path
from types import SimpleNamespace
from typing import Any
ACP_MARKER_BASE_URL = "acp://copilot"
_DEFAULT_TIMEOUT_SECONDS = 900.0
def _resolve_command() -> str:
return (
os.getenv("HERMES_COPILOT_ACP_COMMAND", "").strip()
or os.getenv("COPILOT_CLI_PATH", "").strip()
or "copilot"
)
def _resolve_args() -> list[str]:
raw = os.getenv("HERMES_COPILOT_ACP_ARGS", "").strip()
if not raw:
return ["--acp", "--stdio"]
return shlex.split(raw)
def _jsonrpc_error(message_id: Any, code: int, message: str) -> dict[str, Any]:
return {
"jsonrpc": "2.0",
"id": message_id,
"error": {
"code": code,
"message": message,
},
}
def _format_messages_as_prompt(messages: list[dict[str, Any]], model: str | None = None) -> str:
sections: list[str] = [
"You are being used as the active ACP agent backend for Hermes.",
"Use your own ACP capabilities and respond directly in natural language.",
"Do not emit OpenAI tool-call JSON.",
]
if model:
sections.append(f"Hermes requested model hint: {model}")
transcript: list[str] = []
for message in messages:
if not isinstance(message, dict):
continue
role = str(message.get("role") or "unknown").strip().lower()
if role == "tool":
role = "tool"
elif role not in {"system", "user", "assistant"}:
role = "context"
content = message.get("content")
rendered = _render_message_content(content)
if not rendered:
continue
label = {
"system": "System",
"user": "User",
"assistant": "Assistant",
"tool": "Tool",
"context": "Context",
}.get(role, role.title())
transcript.append(f"{label}:\n{rendered}")
if transcript:
sections.append("Conversation transcript:\n\n" + "\n\n".join(transcript))
sections.append("Continue the conversation from the latest user request.")
return "\n\n".join(section.strip() for section in sections if section and section.strip())
def _render_message_content(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content.strip()
if isinstance(content, dict):
if "text" in content:
return str(content.get("text") or "").strip()
if "content" in content and isinstance(content.get("content"), str):
return str(content.get("content") or "").strip()
return json.dumps(content, ensure_ascii=True)
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str) and text.strip():
parts.append(text.strip())
return "\n".join(parts).strip()
return str(content).strip()
def _ensure_path_within_cwd(path_text: str, cwd: str) -> Path:
candidate = Path(path_text)
if not candidate.is_absolute():
raise PermissionError("ACP file-system paths must be absolute.")
resolved = candidate.resolve()
root = Path(cwd).resolve()
try:
resolved.relative_to(root)
except ValueError as exc:
raise PermissionError(f"Path '{resolved}' is outside the session cwd '{root}'.") from exc
return resolved
class _ACPChatCompletions:
def __init__(self, client: "CopilotACPClient"):
self._client = client
def create(self, **kwargs: Any) -> Any:
return self._client._create_chat_completion(**kwargs)
class _ACPChatNamespace:
def __init__(self, client: "CopilotACPClient"):
self.completions = _ACPChatCompletions(client)
class CopilotACPClient:
"""Minimal OpenAI-client-compatible facade for Copilot ACP."""
def __init__(
self,
*,
api_key: str | None = None,
base_url: str | None = None,
default_headers: dict[str, str] | None = None,
acp_command: str | None = None,
acp_args: list[str] | None = None,
acp_cwd: str | None = None,
command: str | None = None,
args: list[str] | None = None,
**_: Any,
):
self.api_key = api_key or "copilot-acp"
self.base_url = base_url or ACP_MARKER_BASE_URL
self._default_headers = dict(default_headers or {})
self._acp_command = acp_command or command or _resolve_command()
self._acp_args = list(acp_args or args or _resolve_args())
self._acp_cwd = str(Path(acp_cwd or os.getcwd()).resolve())
self.chat = _ACPChatNamespace(self)
self.is_closed = False
self._active_process: subprocess.Popen[str] | None = None
self._active_process_lock = threading.Lock()
def close(self) -> None:
proc: subprocess.Popen[str] | None
with self._active_process_lock:
proc = self._active_process
self._active_process = None
self.is_closed = True
if proc is None:
return
try:
proc.terminate()
proc.wait(timeout=2)
except Exception:
try:
proc.kill()
except Exception:
pass
def _create_chat_completion(
self,
*,
model: str | None = None,
messages: list[dict[str, Any]] | None = None,
timeout: float | None = None,
**_: Any,
) -> Any:
prompt_text = _format_messages_as_prompt(messages or [], model=model)
response_text, reasoning_text = self._run_prompt(
prompt_text,
timeout_seconds=float(timeout or _DEFAULT_TIMEOUT_SECONDS),
)
usage = SimpleNamespace(
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
prompt_tokens_details=SimpleNamespace(cached_tokens=0),
)
assistant_message = SimpleNamespace(
content=response_text,
tool_calls=[],
reasoning=reasoning_text or None,
reasoning_content=reasoning_text or None,
reasoning_details=None,
)
choice = SimpleNamespace(message=assistant_message, finish_reason="stop")
return SimpleNamespace(
choices=[choice],
usage=usage,
model=model or "copilot-acp",
)
def _run_prompt(self, prompt_text: str, *, timeout_seconds: float) -> tuple[str, str]:
try:
proc = subprocess.Popen(
[self._acp_command] + self._acp_args,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
bufsize=1,
cwd=self._acp_cwd,
)
except FileNotFoundError as exc:
raise RuntimeError(
f"Could not start Copilot ACP command '{self._acp_command}'. "
"Install GitHub Copilot CLI or set HERMES_COPILOT_ACP_COMMAND/COPILOT_CLI_PATH."
) from exc
if proc.stdin is None or proc.stdout is None:
proc.kill()
raise RuntimeError("Copilot ACP process did not expose stdin/stdout pipes.")
self.is_closed = False
with self._active_process_lock:
self._active_process = proc
inbox: queue.Queue[dict[str, Any]] = queue.Queue()
stderr_tail: deque[str] = deque(maxlen=40)
def _stdout_reader() -> None:
for line in proc.stdout:
try:
inbox.put(json.loads(line))
except Exception:
inbox.put({"raw": line.rstrip("\n")})
def _stderr_reader() -> None:
if proc.stderr is None:
return
for line in proc.stderr:
stderr_tail.append(line.rstrip("\n"))
out_thread = threading.Thread(target=_stdout_reader, daemon=True)
err_thread = threading.Thread(target=_stderr_reader, daemon=True)
out_thread.start()
err_thread.start()
next_id = 0
def _request(method: str, params: dict[str, Any], *, text_parts: list[str] | None = None, reasoning_parts: list[str] | None = None) -> Any:
nonlocal next_id
next_id += 1
request_id = next_id
payload = {
"jsonrpc": "2.0",
"id": request_id,
"method": method,
"params": params,
}
proc.stdin.write(json.dumps(payload) + "\n")
proc.stdin.flush()
deadline = time.time() + timeout_seconds
while time.time() < deadline:
if proc.poll() is not None:
break
try:
msg = inbox.get(timeout=0.1)
except queue.Empty:
continue
if self._handle_server_message(
msg,
process=proc,
cwd=self._acp_cwd,
text_parts=text_parts,
reasoning_parts=reasoning_parts,
):
continue
if msg.get("id") != request_id:
continue
if "error" in msg:
err = msg.get("error") or {}
raise RuntimeError(
f"Copilot ACP {method} failed: {err.get('message') or err}"
)
return msg.get("result")
stderr_text = "\n".join(stderr_tail).strip()
if proc.poll() is not None and stderr_text:
raise RuntimeError(f"Copilot ACP process exited early: {stderr_text}")
raise TimeoutError(f"Timed out waiting for Copilot ACP response to {method}.")
try:
_request(
"initialize",
{
"protocolVersion": 1,
"clientCapabilities": {
"fs": {
"readTextFile": True,
"writeTextFile": True,
}
},
"clientInfo": {
"name": "hermes-agent",
"title": "Hermes Agent",
"version": "0.0.0",
},
},
)
session = _request(
"session/new",
{
"cwd": self._acp_cwd,
"mcpServers": [],
},
) or {}
session_id = str(session.get("sessionId") or "").strip()
if not session_id:
raise RuntimeError("Copilot ACP did not return a sessionId.")
text_parts: list[str] = []
reasoning_parts: list[str] = []
_request(
"session/prompt",
{
"sessionId": session_id,
"prompt": [
{
"type": "text",
"text": prompt_text,
}
],
},
text_parts=text_parts,
reasoning_parts=reasoning_parts,
)
return "".join(text_parts), "".join(reasoning_parts)
finally:
self.close()
def _handle_server_message(
self,
msg: dict[str, Any],
*,
process: subprocess.Popen[str],
cwd: str,
text_parts: list[str] | None,
reasoning_parts: list[str] | None,
) -> bool:
method = msg.get("method")
if not isinstance(method, str):
return False
if method == "session/update":
params = msg.get("params") or {}
update = params.get("update") or {}
kind = str(update.get("sessionUpdate") or "").strip()
content = update.get("content") or {}
chunk_text = ""
if isinstance(content, dict):
chunk_text = str(content.get("text") or "")
if kind == "agent_message_chunk" and chunk_text and text_parts is not None:
text_parts.append(chunk_text)
elif kind == "agent_thought_chunk" and chunk_text and reasoning_parts is not None:
reasoning_parts.append(chunk_text)
return True
if process.stdin is None:
return True
message_id = msg.get("id")
params = msg.get("params") or {}
if method == "session/request_permission":
response = {
"jsonrpc": "2.0",
"id": message_id,
"result": {
"outcome": {
"outcome": "allow_once",
}
},
}
elif method == "fs/read_text_file":
try:
path = _ensure_path_within_cwd(str(params.get("path") or ""), cwd)
content = path.read_text() if path.exists() else ""
line = params.get("line")
limit = params.get("limit")
if isinstance(line, int) and line > 1:
lines = content.splitlines(keepends=True)
start = line - 1
end = start + limit if isinstance(limit, int) and limit > 0 else None
content = "".join(lines[start:end])
response = {
"jsonrpc": "2.0",
"id": message_id,
"result": {
"content": content,
},
}
except Exception as exc:
response = _jsonrpc_error(message_id, -32602, str(exc))
elif method == "fs/write_text_file":
try:
path = _ensure_path_within_cwd(str(params.get("path") or ""), cwd)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(str(params.get("content") or ""))
response = {
"jsonrpc": "2.0",
"id": message_id,
"result": None,
}
except Exception as exc:
response = _jsonrpc_error(message_id, -32602, str(exc))
else:
response = _jsonrpc_error(
message_id,
-32601,
f"ACP client method '{method}' is not supported by Hermes yet.",
)
process.stdin.write(json.dumps(response) + "\n")
process.stdin.flush()
return True

View file

@ -0,0 +1,722 @@
"""CLI presentation -- spinner, kawaii faces, tool preview formatting.
Pure display functions and classes with no AIAgent dependency.
Used by AIAgent._execute_tool_calls for CLI feedback.
"""
import json
import logging
import os
import sys
import threading
import time
# ANSI escape codes for coloring tool failure indicators
_RED = "\033[31m"
_RESET = "\033[0m"
logger = logging.getLogger(__name__)
# =========================================================================
# Skin-aware helpers (lazy import to avoid circular deps)
# =========================================================================
def _get_skin():
"""Get the active skin config, or None if not available."""
try:
from hermes_cli.skin_engine import get_active_skin
return get_active_skin()
except Exception:
return None
def get_skin_faces(key: str, default: list) -> list:
"""Get spinner face list from active skin, falling back to default."""
skin = _get_skin()
if skin:
faces = skin.get_spinner_list(key)
if faces:
return faces
return default
def get_skin_verbs() -> list:
"""Get thinking verbs from active skin."""
skin = _get_skin()
if skin:
verbs = skin.get_spinner_list("thinking_verbs")
if verbs:
return verbs
return KawaiiSpinner.THINKING_VERBS
def get_skin_tool_prefix() -> str:
"""Get tool output prefix character from active skin."""
skin = _get_skin()
if skin:
return skin.tool_prefix
return ""
def get_tool_emoji(tool_name: str, default: str = "") -> str:
"""Get the display emoji for a tool.
Resolution order:
1. Active skin's ``tool_emojis`` overrides (if a skin is loaded)
2. Tool registry's per-tool ``emoji`` field
3. *default* fallback
"""
# 1. Skin override
skin = _get_skin()
if skin and skin.tool_emojis:
override = skin.tool_emojis.get(tool_name)
if override:
return override
# 2. Registry default
try:
from tools.registry import registry
emoji = registry.get_emoji(tool_name, default="")
if emoji:
return emoji
except Exception:
pass
# 3. Hardcoded fallback
return default
# =========================================================================
# Tool preview (one-line summary of a tool call's primary argument)
# =========================================================================
def _oneline(text: str) -> str:
"""Collapse whitespace (including newlines) to single spaces."""
return " ".join(text.split())
def build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str | None:
"""Build a short preview of a tool call's primary argument for display."""
if not args:
return None
primary_args = {
"terminal": "command", "web_search": "query", "web_extract": "urls",
"read_file": "path", "write_file": "path", "patch": "path",
"search_files": "pattern", "browser_navigate": "url",
"browser_click": "ref", "browser_type": "text",
"image_generate": "prompt", "text_to_speech": "text",
"vision_analyze": "question", "mixture_of_agents": "user_prompt",
"skill_view": "name", "skills_list": "category",
"cronjob": "action",
"execute_code": "code", "delegate_task": "goal",
"clarify": "question", "skill_manage": "name",
}
if tool_name == "process":
action = args.get("action", "")
sid = args.get("session_id", "")
data = args.get("data", "")
timeout_val = args.get("timeout")
parts = [action]
if sid:
parts.append(sid[:16])
if data:
parts.append(f'"{_oneline(data[:20])}"')
if timeout_val and action == "wait":
parts.append(f"{timeout_val}s")
return " ".join(parts) if parts else None
if tool_name == "todo":
todos_arg = args.get("todos")
merge = args.get("merge", False)
if todos_arg is None:
return "reading task list"
elif merge:
return f"updating {len(todos_arg)} task(s)"
else:
return f"planning {len(todos_arg)} task(s)"
if tool_name == "session_search":
query = _oneline(args.get("query", ""))
return f"recall: \"{query[:25]}{'...' if len(query) > 25 else ''}\""
if tool_name == "memory":
action = args.get("action", "")
target = args.get("target", "")
if action == "add":
content = _oneline(args.get("content", ""))
return f"+{target}: \"{content[:25]}{'...' if len(content) > 25 else ''}\""
elif action == "replace":
return f"~{target}: \"{_oneline(args.get('old_text', '')[:20])}\""
elif action == "remove":
return f"-{target}: \"{_oneline(args.get('old_text', '')[:20])}\""
return action
if tool_name == "send_message":
target = args.get("target", "?")
msg = _oneline(args.get("message", ""))
if len(msg) > 20:
msg = msg[:17] + "..."
return f"to {target}: \"{msg}\""
if tool_name.startswith("rl_"):
rl_previews = {
"rl_list_environments": "listing envs",
"rl_select_environment": args.get("name", ""),
"rl_get_current_config": "reading config",
"rl_edit_config": f"{args.get('field', '')}={args.get('value', '')}",
"rl_start_training": "starting",
"rl_check_status": args.get("run_id", "")[:16],
"rl_stop_training": f"stopping {args.get('run_id', '')[:16]}",
"rl_get_results": args.get("run_id", "")[:16],
"rl_list_runs": "listing runs",
"rl_test_inference": f"{args.get('num_steps', 3)} steps",
}
return rl_previews.get(tool_name)
key = primary_args.get(tool_name)
if not key:
for fallback_key in ("query", "text", "command", "path", "name", "prompt", "code", "goal"):
if fallback_key in args:
key = fallback_key
break
if not key or key not in args:
return None
value = args[key]
if isinstance(value, list):
value = value[0] if value else ""
preview = _oneline(str(value))
if not preview:
return None
if len(preview) > max_len:
preview = preview[:max_len - 3] + "..."
return preview
# =========================================================================
# KawaiiSpinner
# =========================================================================
class KawaiiSpinner:
"""Animated spinner with kawaii faces for CLI feedback during tool execution."""
SPINNERS = {
'dots': ['', '', '', '', '', '', '', '', '', ''],
'bounce': ['', '', '', '', '', '', '', ''],
'grow': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''],
'arrows': ['', '', '', '', '', '', '', ''],
'star': ['', '', '', '', '', '', '', ''],
'moon': ['🌑', '🌒', '🌓', '🌔', '🌕', '🌖', '🌗', '🌘'],
'pulse': ['', '', '', '', '', ''],
'brain': ['🧠', '💭', '💡', '', '💫', '🌟', '💡', '💭'],
'sparkle': ['', '˚', '*', '', '', '', '*', '˚'],
}
KAWAII_WAITING = [
"(。◕‿◕。)", "(◕‿◕✿)", "٩(◕‿◕。)۶", "(✿◠‿◠)", "( ˘▽˘)っ",
"♪(´ε` )", "(◕ᴗ◕✿)", "ヾ(^∇^)", "(≧◡≦)", "(★ω★)",
]
KAWAII_THINKING = [
"(。•́︿•̀。)", "(◔_◔)", "(¬‿¬)", "( •_•)>⌐■-■", "(⌐■_■)",
"(´・_・`)", "◉_◉", "(°ロ°)", "( ˘⌣˘)♡", "ヽ(>∀<☆)☆",
"٩(๑❛ᴗ❛๑)۶", "(⊙_⊙)", "(¬_¬)", "( ͡° ͜ʖ ͡°)", "ಠ_ಠ",
]
THINKING_VERBS = [
"pondering", "contemplating", "musing", "cogitating", "ruminating",
"deliberating", "mulling", "reflecting", "processing", "reasoning",
"analyzing", "computing", "synthesizing", "formulating", "brainstorming",
]
def __init__(self, message: str = "", spinner_type: str = 'dots'):
self.message = message
self.spinner_frames = self.SPINNERS.get(spinner_type, self.SPINNERS['dots'])
self.running = False
self.thread = None
self.frame_idx = 0
self.start_time = None
self.last_line_len = 0
self._last_flush_time = 0.0 # Rate-limit flushes for patch_stdout compat
# Capture stdout NOW, before any redirect_stdout(devnull) from
# child agents can replace sys.stdout with a black hole.
self._out = sys.stdout
def _write(self, text: str, end: str = '\n', flush: bool = False):
"""Write to the stdout captured at spinner creation time."""
try:
self._out.write(text + end)
if flush:
self._out.flush()
except (ValueError, OSError):
pass
def _animate(self):
# When stdout is not a real terminal (e.g. Docker, systemd, pipe),
# skip the animation entirely — it creates massive log bloat.
# Just log the start once and let stop() log the completion.
if not hasattr(self._out, 'isatty') or not self._out.isatty():
self._write(f" [tool] {self.message}", flush=True)
while self.running:
time.sleep(0.5)
return
# Cache skin wings at start (avoid per-frame imports)
skin = _get_skin()
wings = skin.get_spinner_wings() if skin else []
while self.running:
if os.getenv("HERMES_SPINNER_PAUSE"):
time.sleep(0.1)
continue
frame = self.spinner_frames[self.frame_idx % len(self.spinner_frames)]
elapsed = time.time() - self.start_time
if wings:
left, right = wings[self.frame_idx % len(wings)]
line = f" {left} {frame} {self.message} {right} ({elapsed:.1f}s)"
else:
line = f" {frame} {self.message} ({elapsed:.1f}s)"
pad = max(self.last_line_len - len(line), 0)
# Rate-limit flush() calls to avoid spinner spam under
# prompt_toolkit's patch_stdout. Each flush() pushes a queue
# item that may trigger a separate run_in_terminal() call; if
# items are processed one-at-a-time the \r overwrite is lost
# and every frame appears on its own line. By flushing at
# most every 0.4s we guarantee multiple \r-frames are batched
# into a single write, so the terminal collapses them correctly.
now = time.time()
should_flush = (now - self._last_flush_time) >= 0.4
self._write(f"\r{line}{' ' * pad}", end='', flush=should_flush)
if should_flush:
self._last_flush_time = now
self.last_line_len = len(line)
self.frame_idx += 1
time.sleep(0.12)
def start(self):
if self.running:
return
self.running = True
self.start_time = time.time()
self.thread = threading.Thread(target=self._animate, daemon=True)
self.thread.start()
def update_text(self, new_message: str):
self.message = new_message
def print_above(self, text: str):
"""Print a line above the spinner without disrupting animation.
Clears the current spinner line, prints the text, and lets the
next animation tick redraw the spinner on the line below.
Thread-safe: uses the captured stdout reference (self._out).
Works inside redirect_stdout(devnull) because _write bypasses
sys.stdout and writes to the stdout captured at spinner creation.
"""
if not self.running:
self._write(f" {text}", flush=True)
return
# Clear spinner line with spaces (not \033[K) to avoid garbled escape
# codes when prompt_toolkit's patch_stdout is active — same approach
# as stop(). Then print text; spinner redraws on next tick.
blanks = ' ' * max(self.last_line_len + 5, 40)
self._write(f"\r{blanks}\r {text}", flush=True)
def stop(self, final_message: str = None):
self.running = False
if self.thread:
self.thread.join(timeout=0.5)
is_tty = hasattr(self._out, 'isatty') and self._out.isatty()
if is_tty:
# Clear the spinner line with spaces instead of \033[K to avoid
# garbled escape codes when prompt_toolkit's patch_stdout is active.
blanks = ' ' * max(self.last_line_len + 5, 40)
self._write(f"\r{blanks}\r", end='', flush=True)
if final_message:
elapsed = f" ({time.time() - self.start_time:.1f}s)" if self.start_time else ""
if is_tty:
self._write(f" {final_message}", flush=True)
else:
self._write(f" [done] {final_message}{elapsed}", flush=True)
def __enter__(self):
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
return False
# =========================================================================
# Kawaii face arrays (used by AIAgent._execute_tool_calls for spinner text)
# =========================================================================
KAWAII_SEARCH = [
"♪(´ε` )", "(。◕‿◕。)", "ヾ(^∇^)", "(◕ᴗ◕✿)", "( ˘▽˘)っ",
"٩(◕‿◕。)۶", "(✿◠‿◠)", "♪~(´ε` )", "(ノ´ヮ`)*:・゚✧", "(◎o◎)",
]
KAWAII_READ = [
"φ(゜▽゜*)♪", "( ˘▽˘)っ", "(⌐■_■)", "٩(。•́‿•̀。)۶", "(◕‿◕✿)",
"ヾ(@⌒ー⌒@)", "(✧ω✧)", "♪(๑ᴖ◡ᴖ๑)♪", "(≧◡≦)", "( ´ ▽ ` )",
]
KAWAII_TERMINAL = [
"ヽ(>∀<☆)", "(ノ°∀°)", "٩(^ᴗ^)۶", "ヾ(⌐■_■)ノ♪", "(•̀ᴗ•́)و",
"┗(0)┓", "(`・ω・´)", "( ̄▽ ̄)", "(ง •̀_•́)ง", "ヽ(´▽`)/",
]
KAWAII_BROWSER = [
"(ノ°∀°)", "(☞゚ヮ゚)☞", "( ͡° ͜ʖ ͡°)", "┌( ಠ_ಠ)┘", "(⊙_⊙)",
"ヾ(•ω•`)o", "( ̄ω ̄)", "( ˇωˇ )", "(ᵔᴥᵔ)", "(◎o◎)",
]
KAWAII_CREATE = [
"✧*。٩(ˊᗜˋ*)و✧", "(ノ◕ヮ◕)ノ*:・゚✧", "ヽ(>∀<☆)", "٩(♡ε♡)۶", "(◕‿◕)♡",
"✿◕ ‿ ◕✿", "(*≧▽≦)", "ヾ(-)", "(☆▽☆)", "°˖✧◝(⁰▿⁰)◜✧˖°",
]
KAWAII_SKILL = [
"ヾ(@⌒ー⌒@)", "(๑˃ᴗ˂)ﻭ", "٩(◕‿◕。)۶", "(✿╹◡╹)", "ヽ(・∀・)",
"(ノ´ヮ`)*:・゚✧", "♪(๑ᴖ◡ᴖ๑)♪", "(◠‿◠)", "٩(ˊᗜˋ*)و", "(^▽^)",
"ヾ(^∇^)", "(★ω★)/", "٩(。•́‿•̀。)۶", "(◕ᴗ◕✿)", "(◎o◎)",
"(✧ω✧)", "ヽ(>∀<☆)", "( ˘▽˘)っ", "(≧◡≦) ♡", "ヾ( ̄▽ ̄)",
]
KAWAII_THINK = [
"(っ°Д°;)っ", "(;′⌒`)", "(・_・ヾ", "( ´_ゝ`)", "( ̄ヘ ̄)",
"(。-`ω´-)", "( ˘︹˘ )", "(¬_¬)", "ヽ(ー_ー )", "(一_一)",
]
KAWAII_GENERIC = [
"♪(´ε` )", "(◕‿◕✿)", "ヾ(^∇^)", "٩(◕‿◕。)۶", "(✿◠‿◠)",
"(ノ´ヮ`)*:・゚✧", "ヽ(>∀<☆)", "(☆▽☆)", "( ˘▽˘)っ", "(≧◡≦)",
]
# =========================================================================
# Cute tool message (completion line that replaces the spinner)
# =========================================================================
def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]:
"""Inspect a tool result string for signs of failure.
Returns ``(is_failure, suffix)`` where *suffix* is an informational tag
like ``" [exit 1]"`` for terminal failures, or ``" [error]"`` for generic
failures. On success, returns ``(False, "")``.
"""
if result is None:
return False, ""
if tool_name == "terminal":
try:
data = json.loads(result)
exit_code = data.get("exit_code")
if exit_code is not None and exit_code != 0:
return True, f" [exit {exit_code}]"
except (json.JSONDecodeError, TypeError, AttributeError):
logger.debug("Could not parse terminal result as JSON for exit code check")
return False, ""
# Memory-specific: distinguish "full" from real errors
if tool_name == "memory":
try:
data = json.loads(result)
if data.get("success") is False and "exceed the limit" in data.get("error", ""):
return True, " [full]"
except (json.JSONDecodeError, TypeError, AttributeError):
logger.debug("Could not parse memory result as JSON for capacity check")
# Generic heuristic for non-terminal tools
lower = result[:500].lower()
if '"error"' in lower or '"failed"' in lower or result.startswith("Error"):
return True, " [error]"
return False, ""
def get_cute_tool_message(
tool_name: str, args: dict, duration: float, result: str | None = None,
) -> str:
"""Generate a formatted tool completion line for CLI quiet mode.
Format: ``| {emoji} {verb:9} {detail} {duration}``
When *result* is provided the line is checked for failure indicators.
Failed tool calls get a red prefix and an informational suffix.
"""
dur = f"{duration:.1f}s"
is_failure, failure_suffix = _detect_tool_failure(tool_name, result)
skin_prefix = get_skin_tool_prefix()
def _trunc(s, n=40):
s = str(s)
return (s[:n-3] + "...") if len(s) > n else s
def _path(p, n=35):
p = str(p)
return ("..." + p[-(n-3):]) if len(p) > n else p
def _wrap(line: str) -> str:
"""Apply skin tool prefix and failure suffix."""
if skin_prefix != "":
line = line.replace("", skin_prefix, 1)
if not is_failure:
return line
return f"{line}{failure_suffix}"
if tool_name == "web_search":
return _wrap(f"┊ 🔍 search {_trunc(args.get('query', ''), 42)} {dur}")
if tool_name == "web_extract":
urls = args.get("urls", [])
if urls:
url = urls[0] if isinstance(urls, list) else str(urls)
domain = url.replace("https://", "").replace("http://", "").split("/")[0]
extra = f" +{len(urls)-1}" if len(urls) > 1 else ""
return _wrap(f"┊ 📄 fetch {_trunc(domain, 35)}{extra} {dur}")
return _wrap(f"┊ 📄 fetch pages {dur}")
if tool_name == "web_crawl":
url = args.get("url", "")
domain = url.replace("https://", "").replace("http://", "").split("/")[0]
return _wrap(f"┊ 🕸️ crawl {_trunc(domain, 35)} {dur}")
if tool_name == "terminal":
return _wrap(f"┊ 💻 $ {_trunc(args.get('command', ''), 42)} {dur}")
if tool_name == "process":
action = args.get("action", "?")
sid = args.get("session_id", "")[:12]
labels = {"list": "ls processes", "poll": f"poll {sid}", "log": f"log {sid}",
"wait": f"wait {sid}", "kill": f"kill {sid}", "write": f"write {sid}", "submit": f"submit {sid}"}
return _wrap(f"┊ ⚙️ proc {labels.get(action, f'{action} {sid}')} {dur}")
if tool_name == "read_file":
return _wrap(f"┊ 📖 read {_path(args.get('path', ''))} {dur}")
if tool_name == "write_file":
return _wrap(f"┊ ✍️ write {_path(args.get('path', ''))} {dur}")
if tool_name == "patch":
return _wrap(f"┊ 🔧 patch {_path(args.get('path', ''))} {dur}")
if tool_name == "search_files":
pattern = _trunc(args.get("pattern", ""), 35)
target = args.get("target", "content")
verb = "find" if target == "files" else "grep"
return _wrap(f"┊ 🔎 {verb:9} {pattern} {dur}")
if tool_name == "browser_navigate":
url = args.get("url", "")
domain = url.replace("https://", "").replace("http://", "").split("/")[0]
return _wrap(f"┊ 🌐 navigate {_trunc(domain, 35)} {dur}")
if tool_name == "browser_snapshot":
mode = "full" if args.get("full") else "compact"
return _wrap(f"┊ 📸 snapshot {mode} {dur}")
if tool_name == "browser_click":
return _wrap(f"┊ 👆 click {args.get('ref', '?')} {dur}")
if tool_name == "browser_type":
return _wrap(f"┊ ⌨️ type \"{_trunc(args.get('text', ''), 30)}\" {dur}")
if tool_name == "browser_scroll":
d = args.get("direction", "down")
arrow = {"down": "", "up": "", "right": "", "left": ""}.get(d, "")
return _wrap(f"{arrow} scroll {d} {dur}")
if tool_name == "browser_back":
return _wrap(f"┊ ◀️ back {dur}")
if tool_name == "browser_press":
return _wrap(f"┊ ⌨️ press {args.get('key', '?')} {dur}")
if tool_name == "browser_close":
return _wrap(f"┊ 🚪 close browser {dur}")
if tool_name == "browser_get_images":
return _wrap(f"┊ 🖼️ images extracting {dur}")
if tool_name == "browser_vision":
return _wrap(f"┊ 👁️ vision analyzing page {dur}")
if tool_name == "todo":
todos_arg = args.get("todos")
merge = args.get("merge", False)
if todos_arg is None:
return _wrap(f"┊ 📋 plan reading tasks {dur}")
elif merge:
return _wrap(f"┊ 📋 plan update {len(todos_arg)} task(s) {dur}")
else:
return _wrap(f"┊ 📋 plan {len(todos_arg)} task(s) {dur}")
if tool_name == "session_search":
return _wrap(f"┊ 🔍 recall \"{_trunc(args.get('query', ''), 35)}\" {dur}")
if tool_name == "memory":
action = args.get("action", "?")
target = args.get("target", "")
if action == "add":
return _wrap(f"┊ 🧠 memory +{target}: \"{_trunc(args.get('content', ''), 30)}\" {dur}")
elif action == "replace":
return _wrap(f"┊ 🧠 memory ~{target}: \"{_trunc(args.get('old_text', ''), 20)}\" {dur}")
elif action == "remove":
return _wrap(f"┊ 🧠 memory -{target}: \"{_trunc(args.get('old_text', ''), 20)}\" {dur}")
return _wrap(f"┊ 🧠 memory {action} {dur}")
if tool_name == "skills_list":
return _wrap(f"┊ 📚 skills list {args.get('category', 'all')} {dur}")
if tool_name == "skill_view":
return _wrap(f"┊ 📚 skill {_trunc(args.get('name', ''), 30)} {dur}")
if tool_name == "image_generate":
return _wrap(f"┊ 🎨 create {_trunc(args.get('prompt', ''), 35)} {dur}")
if tool_name == "text_to_speech":
return _wrap(f"┊ 🔊 speak {_trunc(args.get('text', ''), 30)} {dur}")
if tool_name == "vision_analyze":
return _wrap(f"┊ 👁️ vision {_trunc(args.get('question', ''), 30)} {dur}")
if tool_name == "mixture_of_agents":
return _wrap(f"┊ 🧠 reason {_trunc(args.get('user_prompt', ''), 30)} {dur}")
if tool_name == "send_message":
return _wrap(f"┊ 📨 send {args.get('target', '?')}: \"{_trunc(args.get('message', ''), 25)}\" {dur}")
if tool_name == "cronjob":
action = args.get("action", "?")
if action == "create":
skills = args.get("skills") or ([] if not args.get("skill") else [args.get("skill")])
label = args.get("name") or (skills[0] if skills else None) or args.get("prompt", "task")
return _wrap(f"┊ ⏰ cron create {_trunc(label, 24)} {dur}")
if action == "list":
return _wrap(f"┊ ⏰ cron listing {dur}")
return _wrap(f"┊ ⏰ cron {action} {args.get('job_id', '')} {dur}")
if tool_name.startswith("rl_"):
rl = {
"rl_list_environments": "list envs", "rl_select_environment": f"select {args.get('name', '')}",
"rl_get_current_config": "get config", "rl_edit_config": f"set {args.get('field', '?')}",
"rl_start_training": "start training", "rl_check_status": f"status {args.get('run_id', '?')[:12]}",
"rl_stop_training": f"stop {args.get('run_id', '?')[:12]}", "rl_get_results": f"results {args.get('run_id', '?')[:12]}",
"rl_list_runs": "list runs", "rl_test_inference": "test inference",
}
return _wrap(f"┊ 🧪 rl {rl.get(tool_name, tool_name.replace('rl_', ''))} {dur}")
if tool_name == "execute_code":
code = args.get("code", "")
first_line = code.strip().split("\n")[0] if code.strip() else ""
return _wrap(f"┊ 🐍 exec {_trunc(first_line, 35)} {dur}")
if tool_name == "delegate_task":
tasks = args.get("tasks")
if tasks and isinstance(tasks, list):
return _wrap(f"┊ 🔀 delegate {len(tasks)} parallel tasks {dur}")
return _wrap(f"┊ 🔀 delegate {_trunc(args.get('goal', ''), 35)} {dur}")
preview = build_tool_preview(tool_name, args) or ""
return _wrap(f"┊ ⚡ {tool_name[:9]:9} {_trunc(preview, 35)} {dur}")
# =========================================================================
# Honcho session line (one-liner with clickable OSC 8 hyperlink)
# =========================================================================
_DIM = "\033[2m"
_SKY_BLUE = "\033[38;5;117m"
_ANSI_RESET = "\033[0m"
def honcho_session_url(workspace: str, session_name: str) -> str:
"""Build a Honcho app URL for a session."""
from urllib.parse import quote
return (
f"https://app.honcho.dev/explore"
f"?workspace={quote(workspace, safe='')}"
f"&view=sessions"
f"&session={quote(session_name, safe='')}"
)
def _osc8_link(url: str, text: str) -> str:
"""OSC 8 terminal hyperlink (clickable in iTerm2, Ghostty, WezTerm, etc.)."""
return f"\033]8;;{url}\033\\{text}\033]8;;\033\\"
def honcho_session_line(workspace: str, session_name: str) -> str:
"""One-line session indicator: `Honcho session: <clickable name>`."""
url = honcho_session_url(workspace, session_name)
linked_name = _osc8_link(url, f"{_SKY_BLUE}{session_name}{_ANSI_RESET}")
return f"{_DIM}Honcho session:{_ANSI_RESET} {linked_name}"
def write_tty(text: str) -> None:
"""Write directly to /dev/tty, bypassing stdout capture."""
try:
fd = os.open("/dev/tty", os.O_WRONLY)
os.write(fd, text.encode("utf-8"))
os.close(fd)
except OSError:
sys.stdout.write(text)
sys.stdout.flush()
# =========================================================================
# Context pressure display (CLI user-facing warnings)
# =========================================================================
# ANSI color codes for context pressure tiers
_CYAN = "\033[36m"
_YELLOW = "\033[33m"
_BOLD = "\033[1m"
_DIM_ANSI = "\033[2m"
# Bar characters
_BAR_FILLED = ""
_BAR_EMPTY = ""
_BAR_WIDTH = 20
def format_context_pressure(
compaction_progress: float,
threshold_tokens: int,
threshold_percent: float,
compression_enabled: bool = True,
) -> str:
"""Build a formatted context pressure line for CLI display.
The bar and percentage show progress toward the compaction threshold,
NOT the raw context window. 100% = compaction fires.
Uses ANSI colors:
- cyan at ~60% to compaction = informational
- bold yellow at ~85% to compaction = warning
Args:
compaction_progress: How close to compaction (0.01.0, 1.0 = fires).
threshold_tokens: Compaction threshold in tokens.
threshold_percent: Compaction threshold as a fraction of context window.
compression_enabled: Whether auto-compression is active.
"""
pct_int = int(compaction_progress * 100)
filled = min(int(compaction_progress * _BAR_WIDTH), _BAR_WIDTH)
bar = _BAR_FILLED * filled + _BAR_EMPTY * (_BAR_WIDTH - filled)
threshold_k = f"{threshold_tokens // 1000}k" if threshold_tokens >= 1000 else str(threshold_tokens)
threshold_pct_int = int(threshold_percent * 100)
# Tier styling
if compaction_progress >= 0.85:
color = f"{_BOLD}{_YELLOW}"
icon = ""
if compression_enabled:
hint = "compaction imminent"
else:
hint = "no auto-compaction"
else:
color = _CYAN
icon = ""
hint = "approaching compaction"
return (
f" {color}{icon} context {bar} {pct_int}% to compaction{_ANSI_RESET}"
f" {_DIM_ANSI}{threshold_k} threshold ({threshold_pct_int}%) · {hint}{_ANSI_RESET}"
)
def format_context_pressure_gateway(
compaction_progress: float,
threshold_percent: float,
compression_enabled: bool = True,
) -> str:
"""Build a plain-text context pressure notification for messaging platforms.
No ANSI just Unicode and plain text suitable for Telegram/Discord/etc.
The percentage shows progress toward the compaction threshold.
"""
pct_int = int(compaction_progress * 100)
filled = min(int(compaction_progress * _BAR_WIDTH), _BAR_WIDTH)
bar = _BAR_FILLED * filled + _BAR_EMPTY * (_BAR_WIDTH - filled)
threshold_pct_int = int(threshold_percent * 100)
if compaction_progress >= 0.85:
icon = "⚠️"
if compression_enabled:
hint = f"Context compaction is imminent (threshold: {threshold_pct_int}% of window)."
else:
hint = "Auto-compaction is disabled — context may be truncated."
else:
icon = ""
hint = f"Compaction threshold is at {threshold_pct_int}% of context window."
return f"{icon} Context: {bar} {pct_int}% to compaction\n{hint}"

View file

@ -0,0 +1,792 @@
"""
Session Insights Engine for Hermes Agent.
Analyzes historical session data from the SQLite state database to produce
comprehensive usage insights token consumption, cost estimates, tool usage
patterns, activity trends, model/platform breakdowns, and session metrics.
Inspired by Claude Code's /insights command, adapted for Hermes Agent's
multi-platform architecture with additional cost estimation and platform
breakdown capabilities.
Usage:
from agent.insights import InsightsEngine
engine = InsightsEngine(db)
report = engine.generate(days=30)
print(engine.format_terminal(report))
"""
import json
import time
from collections import Counter, defaultdict
from datetime import datetime
from typing import Any, Dict, List
from agent.usage_pricing import (
CanonicalUsage,
DEFAULT_PRICING,
estimate_usage_cost,
format_duration_compact,
get_pricing,
has_known_pricing,
)
_DEFAULT_PRICING = DEFAULT_PRICING
def _has_known_pricing(model_name: str, provider: str = None, base_url: str = None) -> bool:
"""Check if a model has known pricing (vs unknown/custom endpoint)."""
return has_known_pricing(model_name, provider=provider, base_url=base_url)
def _get_pricing(model_name: str) -> Dict[str, float]:
"""Look up pricing for a model. Uses fuzzy matching on model name.
Returns _DEFAULT_PRICING (zero cost) for unknown/custom models
we can't assume costs for self-hosted endpoints, local inference, etc.
"""
return get_pricing(model_name)
def _estimate_cost(
session_or_model: Dict[str, Any] | str,
input_tokens: int = 0,
output_tokens: int = 0,
*,
cache_read_tokens: int = 0,
cache_write_tokens: int = 0,
provider: str = None,
base_url: str = None,
) -> tuple[float, str]:
"""Estimate the USD cost for a session row or a model/token tuple."""
if isinstance(session_or_model, dict):
session = session_or_model
model = session.get("model") or ""
usage = CanonicalUsage(
input_tokens=session.get("input_tokens") or 0,
output_tokens=session.get("output_tokens") or 0,
cache_read_tokens=session.get("cache_read_tokens") or 0,
cache_write_tokens=session.get("cache_write_tokens") or 0,
)
provider = session.get("billing_provider")
base_url = session.get("billing_base_url")
else:
model = session_or_model or ""
usage = CanonicalUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_read_tokens=cache_read_tokens,
cache_write_tokens=cache_write_tokens,
)
result = estimate_usage_cost(
model,
usage,
provider=provider,
base_url=base_url,
)
return float(result.amount_usd or 0.0), result.status
def _format_duration(seconds: float) -> str:
"""Format seconds into a human-readable duration string."""
return format_duration_compact(seconds)
def _bar_chart(values: List[int], max_width: int = 20) -> List[str]:
"""Create simple horizontal bar chart strings from values."""
peak = max(values) if values else 1
if peak == 0:
return ["" for _ in values]
return ["" * max(1, int(v / peak * max_width)) if v > 0 else "" for v in values]
class InsightsEngine:
"""
Analyzes session history and produces usage insights.
Works directly with a SessionDB instance (or raw sqlite3 connection)
to query session and message data.
"""
def __init__(self, db):
"""
Initialize with a SessionDB instance.
Args:
db: A SessionDB instance (from hermes_state.py)
"""
self.db = db
self._conn = db._conn
def generate(self, days: int = 30, source: str = None) -> Dict[str, Any]:
"""
Generate a complete insights report.
Args:
days: Number of days to look back (default: 30)
source: Optional filter by source platform
Returns:
Dict with all computed insights
"""
cutoff = time.time() - (days * 86400)
# Gather raw data
sessions = self._get_sessions(cutoff, source)
tool_usage = self._get_tool_usage(cutoff, source)
message_stats = self._get_message_stats(cutoff, source)
if not sessions:
return {
"days": days,
"source_filter": source,
"empty": True,
"overview": {},
"models": [],
"platforms": [],
"tools": [],
"activity": {},
"top_sessions": [],
}
# Compute insights
overview = self._compute_overview(sessions, message_stats)
models = self._compute_model_breakdown(sessions)
platforms = self._compute_platform_breakdown(sessions)
tools = self._compute_tool_breakdown(tool_usage)
activity = self._compute_activity_patterns(sessions)
top_sessions = self._compute_top_sessions(sessions)
return {
"days": days,
"source_filter": source,
"empty": False,
"generated_at": time.time(),
"overview": overview,
"models": models,
"platforms": platforms,
"tools": tools,
"activity": activity,
"top_sessions": top_sessions,
}
# =========================================================================
# Data gathering (SQL queries)
# =========================================================================
# Columns we actually need (skip system_prompt, model_config blobs)
_SESSION_COLS = ("id, source, model, started_at, ended_at, "
"message_count, tool_call_count, input_tokens, output_tokens, "
"cache_read_tokens, cache_write_tokens, billing_provider, "
"billing_base_url, billing_mode, estimated_cost_usd, "
"actual_cost_usd, cost_status, cost_source")
# Pre-computed query strings — f-string evaluated once at class definition,
# not at runtime, so no user-controlled value can alter the query structure.
_GET_SESSIONS_WITH_SOURCE = (
f"SELECT {_SESSION_COLS} FROM sessions"
" WHERE started_at >= ? AND source = ?"
" ORDER BY started_at DESC"
)
_GET_SESSIONS_ALL = (
f"SELECT {_SESSION_COLS} FROM sessions"
" WHERE started_at >= ?"
" ORDER BY started_at DESC"
)
def _get_sessions(self, cutoff: float, source: str = None) -> List[Dict]:
"""Fetch sessions within the time window."""
if source:
cursor = self._conn.execute(self._GET_SESSIONS_WITH_SOURCE, (cutoff, source))
else:
cursor = self._conn.execute(self._GET_SESSIONS_ALL, (cutoff,))
return [dict(row) for row in cursor.fetchall()]
def _get_tool_usage(self, cutoff: float, source: str = None) -> List[Dict]:
"""Get tool call counts from messages.
Uses two sources:
1. tool_name column on 'tool' role messages (set by gateway)
2. tool_calls JSON on 'assistant' role messages (covers CLI where
tool_name is not populated on tool responses)
"""
tool_counts = Counter()
# Source 1: explicit tool_name on tool response messages
if source:
cursor = self._conn.execute(
"""SELECT m.tool_name, COUNT(*) as count
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ? AND s.source = ?
AND m.role = 'tool' AND m.tool_name IS NOT NULL
GROUP BY m.tool_name
ORDER BY count DESC""",
(cutoff, source),
)
else:
cursor = self._conn.execute(
"""SELECT m.tool_name, COUNT(*) as count
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ?
AND m.role = 'tool' AND m.tool_name IS NOT NULL
GROUP BY m.tool_name
ORDER BY count DESC""",
(cutoff,),
)
for row in cursor.fetchall():
tool_counts[row["tool_name"]] += row["count"]
# Source 2: extract from tool_calls JSON on assistant messages
# (covers CLI sessions where tool_name is NULL on tool responses)
if source:
cursor2 = self._conn.execute(
"""SELECT m.tool_calls
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ? AND s.source = ?
AND m.role = 'assistant' AND m.tool_calls IS NOT NULL""",
(cutoff, source),
)
else:
cursor2 = self._conn.execute(
"""SELECT m.tool_calls
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ?
AND m.role = 'assistant' AND m.tool_calls IS NOT NULL""",
(cutoff,),
)
tool_calls_counts = Counter()
for row in cursor2.fetchall():
try:
calls = row["tool_calls"]
if isinstance(calls, str):
calls = json.loads(calls)
if isinstance(calls, list):
for call in calls:
func = call.get("function", {}) if isinstance(call, dict) else {}
name = func.get("name")
if name:
tool_calls_counts[name] += 1
except (json.JSONDecodeError, TypeError, AttributeError):
continue
# Merge: prefer tool_name source, supplement with tool_calls source
# for tools not already counted
if not tool_counts and tool_calls_counts:
# No tool_name data at all — use tool_calls exclusively
tool_counts = tool_calls_counts
elif tool_counts and tool_calls_counts:
# Both sources have data — use whichever has the higher count per tool
# (they may overlap, so take the max to avoid double-counting)
all_tools = set(tool_counts) | set(tool_calls_counts)
merged = Counter()
for tool in all_tools:
merged[tool] = max(tool_counts.get(tool, 0), tool_calls_counts.get(tool, 0))
tool_counts = merged
# Convert to the expected format
return [
{"tool_name": name, "count": count}
for name, count in tool_counts.most_common()
]
def _get_message_stats(self, cutoff: float, source: str = None) -> Dict:
"""Get aggregate message statistics."""
if source:
cursor = self._conn.execute(
"""SELECT
COUNT(*) as total_messages,
SUM(CASE WHEN m.role = 'user' THEN 1 ELSE 0 END) as user_messages,
SUM(CASE WHEN m.role = 'assistant' THEN 1 ELSE 0 END) as assistant_messages,
SUM(CASE WHEN m.role = 'tool' THEN 1 ELSE 0 END) as tool_messages
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ? AND s.source = ?""",
(cutoff, source),
)
else:
cursor = self._conn.execute(
"""SELECT
COUNT(*) as total_messages,
SUM(CASE WHEN m.role = 'user' THEN 1 ELSE 0 END) as user_messages,
SUM(CASE WHEN m.role = 'assistant' THEN 1 ELSE 0 END) as assistant_messages,
SUM(CASE WHEN m.role = 'tool' THEN 1 ELSE 0 END) as tool_messages
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ?""",
(cutoff,),
)
row = cursor.fetchone()
return dict(row) if row else {
"total_messages": 0, "user_messages": 0,
"assistant_messages": 0, "tool_messages": 0,
}
# =========================================================================
# Computation
# =========================================================================
def _compute_overview(self, sessions: List[Dict], message_stats: Dict) -> Dict:
"""Compute high-level overview statistics."""
total_input = sum(s.get("input_tokens") or 0 for s in sessions)
total_output = sum(s.get("output_tokens") or 0 for s in sessions)
total_cache_read = sum(s.get("cache_read_tokens") or 0 for s in sessions)
total_cache_write = sum(s.get("cache_write_tokens") or 0 for s in sessions)
total_tokens = total_input + total_output + total_cache_read + total_cache_write
total_tool_calls = sum(s.get("tool_call_count") or 0 for s in sessions)
total_messages = sum(s.get("message_count") or 0 for s in sessions)
# Cost estimation (weighted by model)
total_cost = 0.0
actual_cost = 0.0
models_with_pricing = set()
models_without_pricing = set()
unknown_cost_sessions = 0
included_cost_sessions = 0
for s in sessions:
model = s.get("model") or ""
estimated, status = _estimate_cost(s)
total_cost += estimated
actual_cost += s.get("actual_cost_usd") or 0.0
display = model.split("/")[-1] if "/" in model else (model or "unknown")
if status == "included":
included_cost_sessions += 1
elif status == "unknown":
unknown_cost_sessions += 1
if _has_known_pricing(model, s.get("billing_provider"), s.get("billing_base_url")):
models_with_pricing.add(display)
else:
models_without_pricing.add(display)
# Session duration stats (guard against negative durations from clock drift)
durations = []
for s in sessions:
start = s.get("started_at")
end = s.get("ended_at")
if start and end and end > start:
durations.append(end - start)
total_hours = sum(durations) / 3600 if durations else 0
avg_duration = sum(durations) / len(durations) if durations else 0
# Earliest and latest session
started_timestamps = [s["started_at"] for s in sessions if s.get("started_at")]
date_range_start = min(started_timestamps) if started_timestamps else None
date_range_end = max(started_timestamps) if started_timestamps else None
return {
"total_sessions": len(sessions),
"total_messages": total_messages,
"total_tool_calls": total_tool_calls,
"total_input_tokens": total_input,
"total_output_tokens": total_output,
"total_cache_read_tokens": total_cache_read,
"total_cache_write_tokens": total_cache_write,
"total_tokens": total_tokens,
"estimated_cost": total_cost,
"actual_cost": actual_cost,
"total_hours": total_hours,
"avg_session_duration": avg_duration,
"avg_messages_per_session": total_messages / len(sessions) if sessions else 0,
"avg_tokens_per_session": total_tokens / len(sessions) if sessions else 0,
"user_messages": message_stats.get("user_messages") or 0,
"assistant_messages": message_stats.get("assistant_messages") or 0,
"tool_messages": message_stats.get("tool_messages") or 0,
"date_range_start": date_range_start,
"date_range_end": date_range_end,
"models_with_pricing": sorted(models_with_pricing),
"models_without_pricing": sorted(models_without_pricing),
"unknown_cost_sessions": unknown_cost_sessions,
"included_cost_sessions": included_cost_sessions,
}
def _compute_model_breakdown(self, sessions: List[Dict]) -> List[Dict]:
"""Break down usage by model."""
model_data = defaultdict(lambda: {
"sessions": 0, "input_tokens": 0, "output_tokens": 0,
"cache_read_tokens": 0, "cache_write_tokens": 0,
"total_tokens": 0, "tool_calls": 0, "cost": 0.0,
})
for s in sessions:
model = s.get("model") or "unknown"
# Normalize: strip provider prefix for display
display_model = model.split("/")[-1] if "/" in model else model
d = model_data[display_model]
d["sessions"] += 1
inp = s.get("input_tokens") or 0
out = s.get("output_tokens") or 0
cache_read = s.get("cache_read_tokens") or 0
cache_write = s.get("cache_write_tokens") or 0
d["input_tokens"] += inp
d["output_tokens"] += out
d["cache_read_tokens"] += cache_read
d["cache_write_tokens"] += cache_write
d["total_tokens"] += inp + out + cache_read + cache_write
d["tool_calls"] += s.get("tool_call_count") or 0
estimate, status = _estimate_cost(s)
d["cost"] += estimate
d["has_pricing"] = _has_known_pricing(model, s.get("billing_provider"), s.get("billing_base_url"))
d["cost_status"] = status
result = [
{"model": model, **data}
for model, data in model_data.items()
]
# Sort by tokens first, fall back to session count when tokens are 0
result.sort(key=lambda x: (x["total_tokens"], x["sessions"]), reverse=True)
return result
def _compute_platform_breakdown(self, sessions: List[Dict]) -> List[Dict]:
"""Break down usage by platform/source."""
platform_data = defaultdict(lambda: {
"sessions": 0, "messages": 0, "input_tokens": 0,
"output_tokens": 0, "cache_read_tokens": 0,
"cache_write_tokens": 0, "total_tokens": 0, "tool_calls": 0,
})
for s in sessions:
source = s.get("source") or "unknown"
d = platform_data[source]
d["sessions"] += 1
d["messages"] += s.get("message_count") or 0
inp = s.get("input_tokens") or 0
out = s.get("output_tokens") or 0
cache_read = s.get("cache_read_tokens") or 0
cache_write = s.get("cache_write_tokens") or 0
d["input_tokens"] += inp
d["output_tokens"] += out
d["cache_read_tokens"] += cache_read
d["cache_write_tokens"] += cache_write
d["total_tokens"] += inp + out + cache_read + cache_write
d["tool_calls"] += s.get("tool_call_count") or 0
result = [
{"platform": platform, **data}
for platform, data in platform_data.items()
]
result.sort(key=lambda x: x["sessions"], reverse=True)
return result
def _compute_tool_breakdown(self, tool_usage: List[Dict]) -> List[Dict]:
"""Process tool usage data into a ranked list with percentages."""
total_calls = sum(t["count"] for t in tool_usage) if tool_usage else 0
result = []
for t in tool_usage:
pct = (t["count"] / total_calls * 100) if total_calls else 0
result.append({
"tool": t["tool_name"],
"count": t["count"],
"percentage": pct,
})
return result
def _compute_activity_patterns(self, sessions: List[Dict]) -> Dict:
"""Analyze activity patterns by day of week and hour."""
day_counts = Counter() # 0=Monday ... 6=Sunday
hour_counts = Counter()
daily_counts = Counter() # date string -> count
for s in sessions:
ts = s.get("started_at")
if not ts:
continue
dt = datetime.fromtimestamp(ts)
day_counts[dt.weekday()] += 1
hour_counts[dt.hour] += 1
daily_counts[dt.strftime("%Y-%m-%d")] += 1
day_names = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
day_breakdown = [
{"day": day_names[i], "count": day_counts.get(i, 0)}
for i in range(7)
]
hour_breakdown = [
{"hour": i, "count": hour_counts.get(i, 0)}
for i in range(24)
]
# Busiest day and hour
busiest_day = max(day_breakdown, key=lambda x: x["count"]) if day_breakdown else None
busiest_hour = max(hour_breakdown, key=lambda x: x["count"]) if hour_breakdown else None
# Active days (days with at least one session)
active_days = len(daily_counts)
# Streak calculation
if daily_counts:
all_dates = sorted(daily_counts.keys())
current_streak = 1
max_streak = 1
for i in range(1, len(all_dates)):
d1 = datetime.strptime(all_dates[i - 1], "%Y-%m-%d")
d2 = datetime.strptime(all_dates[i], "%Y-%m-%d")
if (d2 - d1).days == 1:
current_streak += 1
max_streak = max(max_streak, current_streak)
else:
current_streak = 1
else:
max_streak = 0
return {
"by_day": day_breakdown,
"by_hour": hour_breakdown,
"busiest_day": busiest_day,
"busiest_hour": busiest_hour,
"active_days": active_days,
"max_streak": max_streak,
}
def _compute_top_sessions(self, sessions: List[Dict]) -> List[Dict]:
"""Find notable sessions (longest, most messages, most tokens)."""
top = []
# Longest by duration
sessions_with_duration = [
s for s in sessions
if s.get("started_at") and s.get("ended_at")
]
if sessions_with_duration:
longest = max(
sessions_with_duration,
key=lambda s: (s["ended_at"] - s["started_at"]),
)
dur = longest["ended_at"] - longest["started_at"]
top.append({
"label": "Longest session",
"session_id": longest["id"][:16],
"value": _format_duration(dur),
"date": datetime.fromtimestamp(longest["started_at"]).strftime("%b %d"),
})
# Most messages
most_msgs = max(sessions, key=lambda s: s.get("message_count") or 0)
if (most_msgs.get("message_count") or 0) > 0:
top.append({
"label": "Most messages",
"session_id": most_msgs["id"][:16],
"value": f"{most_msgs['message_count']} msgs",
"date": datetime.fromtimestamp(most_msgs["started_at"]).strftime("%b %d") if most_msgs.get("started_at") else "?",
})
# Most tokens
most_tokens = max(
sessions,
key=lambda s: (s.get("input_tokens") or 0) + (s.get("output_tokens") or 0),
)
token_total = (most_tokens.get("input_tokens") or 0) + (most_tokens.get("output_tokens") or 0)
if token_total > 0:
top.append({
"label": "Most tokens",
"session_id": most_tokens["id"][:16],
"value": f"{token_total:,} tokens",
"date": datetime.fromtimestamp(most_tokens["started_at"]).strftime("%b %d") if most_tokens.get("started_at") else "?",
})
# Most tool calls
most_tools = max(sessions, key=lambda s: s.get("tool_call_count") or 0)
if (most_tools.get("tool_call_count") or 0) > 0:
top.append({
"label": "Most tool calls",
"session_id": most_tools["id"][:16],
"value": f"{most_tools['tool_call_count']} calls",
"date": datetime.fromtimestamp(most_tools["started_at"]).strftime("%b %d") if most_tools.get("started_at") else "?",
})
return top
# =========================================================================
# Formatting
# =========================================================================
def format_terminal(self, report: Dict) -> str:
"""Format the insights report for terminal display (CLI)."""
if report.get("empty"):
days = report.get("days", 30)
src = f" (source: {report['source_filter']})" if report.get("source_filter") else ""
return f" No sessions found in the last {days} days{src}."
lines = []
o = report["overview"]
days = report["days"]
src_filter = report.get("source_filter")
# Header
lines.append("")
lines.append(" ╔══════════════════════════════════════════════════════════╗")
lines.append(" ║ 📊 Hermes Insights ║")
period_label = f"Last {days} days"
if src_filter:
period_label += f" ({src_filter})"
padding = 58 - len(period_label) - 2
left_pad = padding // 2
right_pad = padding - left_pad
lines.append(f"{' ' * left_pad} {period_label} {' ' * right_pad}")
lines.append(" ╚══════════════════════════════════════════════════════════╝")
lines.append("")
# Date range
if o.get("date_range_start") and o.get("date_range_end"):
start_str = datetime.fromtimestamp(o["date_range_start"]).strftime("%b %d, %Y")
end_str = datetime.fromtimestamp(o["date_range_end"]).strftime("%b %d, %Y")
lines.append(f" Period: {start_str}{end_str}")
lines.append("")
# Overview
lines.append(" 📋 Overview")
lines.append(" " + "" * 56)
lines.append(f" Sessions: {o['total_sessions']:<12} Messages: {o['total_messages']:,}")
lines.append(f" Tool calls: {o['total_tool_calls']:<12,} User messages: {o['user_messages']:,}")
lines.append(f" Input tokens: {o['total_input_tokens']:<12,} Output tokens: {o['total_output_tokens']:,}")
cost_str = f"${o['estimated_cost']:.2f}"
if o.get("models_without_pricing"):
cost_str += " *"
lines.append(f" Total tokens: {o['total_tokens']:<12,} Est. cost: {cost_str}")
if o["total_hours"] > 0:
lines.append(f" Active time: ~{_format_duration(o['total_hours'] * 3600):<11} Avg session: ~{_format_duration(o['avg_session_duration'])}")
lines.append(f" Avg msgs/session: {o['avg_messages_per_session']:.1f}")
lines.append("")
# Model breakdown
if report["models"]:
lines.append(" 🤖 Models Used")
lines.append(" " + "" * 56)
lines.append(f" {'Model':<30} {'Sessions':>8} {'Tokens':>12} {'Cost':>8}")
for m in report["models"]:
model_name = m["model"][:28]
if m.get("has_pricing"):
cost_cell = f"${m['cost']:>6.2f}"
else:
cost_cell = " N/A"
lines.append(f" {model_name:<30} {m['sessions']:>8} {m['total_tokens']:>12,} {cost_cell}")
if o.get("models_without_pricing"):
lines.append(f" * Cost N/A for custom/self-hosted models")
lines.append("")
# Platform breakdown
if len(report["platforms"]) > 1 or (report["platforms"] and report["platforms"][0]["platform"] != "cli"):
lines.append(" 📱 Platforms")
lines.append(" " + "" * 56)
lines.append(f" {'Platform':<14} {'Sessions':>8} {'Messages':>10} {'Tokens':>14}")
for p in report["platforms"]:
lines.append(f" {p['platform']:<14} {p['sessions']:>8} {p['messages']:>10,} {p['total_tokens']:>14,}")
lines.append("")
# Tool usage
if report["tools"]:
lines.append(" 🔧 Top Tools")
lines.append(" " + "" * 56)
lines.append(f" {'Tool':<28} {'Calls':>8} {'%':>8}")
for t in report["tools"][:15]: # Top 15
lines.append(f" {t['tool']:<28} {t['count']:>8,} {t['percentage']:>7.1f}%")
if len(report["tools"]) > 15:
lines.append(f" ... and {len(report['tools']) - 15} more tools")
lines.append("")
# Activity patterns
act = report.get("activity", {})
if act.get("by_day"):
lines.append(" 📅 Activity Patterns")
lines.append(" " + "" * 56)
# Day of week chart
day_values = [d["count"] for d in act["by_day"]]
bars = _bar_chart(day_values, max_width=15)
for i, d in enumerate(act["by_day"]):
bar = bars[i]
lines.append(f" {d['day']} {bar:<15} {d['count']}")
lines.append("")
# Peak hours (show top 5 busiest hours)
busy_hours = sorted(act["by_hour"], key=lambda x: x["count"], reverse=True)
busy_hours = [h for h in busy_hours if h["count"] > 0][:5]
if busy_hours:
hour_strs = []
for h in busy_hours:
hr = h["hour"]
ampm = "AM" if hr < 12 else "PM"
display_hr = hr % 12 or 12
hour_strs.append(f"{display_hr}{ampm} ({h['count']})")
lines.append(f" Peak hours: {', '.join(hour_strs)}")
if act.get("active_days"):
lines.append(f" Active days: {act['active_days']}")
if act.get("max_streak") and act["max_streak"] > 1:
lines.append(f" Best streak: {act['max_streak']} consecutive days")
lines.append("")
# Notable sessions
if report.get("top_sessions"):
lines.append(" 🏆 Notable Sessions")
lines.append(" " + "" * 56)
for ts in report["top_sessions"]:
lines.append(f" {ts['label']:<20} {ts['value']:<18} ({ts['date']}, {ts['session_id']})")
lines.append("")
return "\n".join(lines)
def format_gateway(self, report: Dict) -> str:
"""Format the insights report for gateway/messaging (shorter)."""
if report.get("empty"):
days = report.get("days", 30)
return f"No sessions found in the last {days} days."
lines = []
o = report["overview"]
days = report["days"]
lines.append(f"📊 **Hermes Insights** — Last {days} days\n")
# Overview
lines.append(f"**Sessions:** {o['total_sessions']} | **Messages:** {o['total_messages']:,} | **Tool calls:** {o['total_tool_calls']:,}")
lines.append(f"**Tokens:** {o['total_tokens']:,} (in: {o['total_input_tokens']:,} / out: {o['total_output_tokens']:,})")
cost_note = ""
if o.get("models_without_pricing"):
cost_note = " _(excludes custom/self-hosted models)_"
lines.append(f"**Est. cost:** ${o['estimated_cost']:.2f}{cost_note}")
if o["total_hours"] > 0:
lines.append(f"**Active time:** ~{_format_duration(o['total_hours'] * 3600)} | **Avg session:** ~{_format_duration(o['avg_session_duration'])}")
lines.append("")
# Models (top 5)
if report["models"]:
lines.append("**🤖 Models:**")
for m in report["models"][:5]:
cost_str = f"${m['cost']:.2f}" if m.get("has_pricing") else "N/A"
lines.append(f" {m['model'][:25]}{m['sessions']} sessions, {m['total_tokens']:,} tokens, {cost_str}")
lines.append("")
# Platforms (if multi-platform)
if len(report["platforms"]) > 1:
lines.append("**📱 Platforms:**")
for p in report["platforms"]:
lines.append(f" {p['platform']}{p['sessions']} sessions, {p['messages']:,} msgs")
lines.append("")
# Tools (top 8)
if report["tools"]:
lines.append("**🔧 Top Tools:**")
for t in report["tools"][:8]:
lines.append(f" {t['tool']}{t['count']:,} calls ({t['percentage']:.1f}%)")
lines.append("")
# Activity summary
act = report.get("activity", {})
if act.get("busiest_day") and act.get("busiest_hour"):
hr = act["busiest_hour"]["hour"]
ampm = "AM" if hr < 12 else "PM"
display_hr = hr % 12 or 12
lines.append(f"**📅 Busiest:** {act['busiest_day']['day']}s ({act['busiest_day']['count']} sessions), {display_hr}{ampm} ({act['busiest_hour']['count']} sessions)")
if act.get("active_days"):
lines.append(f"**Active days:** {act['active_days']}", )
if act.get("max_streak", 0) > 1:
lines.append(f"**Best streak:** {act['max_streak']} consecutive days")
return "\n".join(lines)

View file

@ -0,0 +1,897 @@
"""Model metadata, context lengths, and token estimation utilities.
Pure utility functions with no AIAgent dependency. Used by ContextCompressor
and run_agent.py for pre-flight context checks.
"""
import logging
import os
import re
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
import requests
import yaml
from hermes_constants import OPENROUTER_MODELS_URL
logger = logging.getLogger(__name__)
# Provider names that can appear as a "provider:" prefix before a model ID.
# Only these are stripped — Ollama-style "model:tag" colons (e.g. "qwen3.5:27b")
# are preserved so the full model name reaches cache lookups and server queries.
_PROVIDER_PREFIXES: frozenset[str] = frozenset({
"openrouter", "nous", "openai-codex", "copilot", "copilot-acp",
"zai", "kimi-coding", "minimax", "minimax-cn", "anthropic", "deepseek",
"opencode-zen", "opencode-go", "ai-gateway", "kilocode", "alibaba",
"custom", "local",
# Common aliases
"glm", "z-ai", "z.ai", "zhipu", "github", "github-copilot",
"github-models", "kimi", "moonshot", "claude", "deep-seek",
"opencode", "zen", "go", "vercel", "kilo", "dashscope", "aliyun", "qwen",
})
_OLLAMA_TAG_PATTERN = re.compile(
r"^(\d+\.?\d*b|latest|stable|q\d|fp?\d|instruct|chat|coder|vision|text)",
re.IGNORECASE,
)
def _strip_provider_prefix(model: str) -> str:
"""Strip a recognised provider prefix from a model string.
``"local:my-model"`` ``"my-model"``
``"qwen3.5:27b"`` ``"qwen3.5:27b"`` (unchanged not a provider prefix)
``"qwen:0.5b"`` ``"qwen:0.5b"`` (unchanged Ollama model:tag)
``"deepseek:latest"`` ``"deepseek:latest"``(unchanged Ollama model:tag)
"""
if ":" not in model or model.startswith("http"):
return model
prefix, suffix = model.split(":", 1)
prefix_lower = prefix.strip().lower()
if prefix_lower in _PROVIDER_PREFIXES:
# Don't strip if suffix looks like an Ollama tag (e.g. "7b", "latest", "q4_0")
if _OLLAMA_TAG_PATTERN.match(suffix.strip()):
return model
return suffix
return model
_model_metadata_cache: Dict[str, Dict[str, Any]] = {}
_model_metadata_cache_time: float = 0
_MODEL_CACHE_TTL = 3600
_endpoint_model_metadata_cache: Dict[str, Dict[str, Dict[str, Any]]] = {}
_endpoint_model_metadata_cache_time: Dict[str, float] = {}
_ENDPOINT_MODEL_CACHE_TTL = 300
# Descending tiers for context length probing when the model is unknown.
# We start at 128K (a safe default for most modern models) and step down
# on context-length errors until one works.
CONTEXT_PROBE_TIERS = [
128_000,
64_000,
32_000,
16_000,
8_000,
]
# Default context length when no detection method succeeds.
DEFAULT_FALLBACK_CONTEXT = CONTEXT_PROBE_TIERS[0]
# Thin fallback defaults — only broad model family patterns.
# These fire only when provider is unknown AND models.dev/OpenRouter/Anthropic
# all miss. Replaced the previous 80+ entry dict.
# For provider-specific context lengths, models.dev is the primary source.
DEFAULT_CONTEXT_LENGTHS = {
# Anthropic Claude 4.6 (1M context) — bare IDs only to avoid
# fuzzy-match collisions (e.g. "anthropic/claude-sonnet-4" is a
# substring of "anthropic/claude-sonnet-4.6").
# OpenRouter-prefixed models resolve via OpenRouter live API or models.dev.
"claude-opus-4-6": 1000000,
"claude-sonnet-4-6": 1000000,
"claude-opus-4.6": 1000000,
"claude-sonnet-4.6": 1000000,
# Catch-all for older Claude models (must sort after specific entries)
"claude": 200000,
# OpenAI
"gpt-4.1": 1047576,
"gpt-5": 128000,
"gpt-4": 128000,
# Google
"gemini": 1048576,
# DeepSeek
"deepseek": 128000,
# Meta
"llama": 131072,
# Qwen
"qwen": 131072,
# MiniMax
"minimax": 204800,
# GLM
"glm": 202752,
# Kimi
"kimi": 262144,
}
_CONTEXT_LENGTH_KEYS = (
"context_length",
"context_window",
"max_context_length",
"max_position_embeddings",
"max_model_len",
"max_input_tokens",
"max_sequence_length",
"max_seq_len",
"n_ctx_train",
"n_ctx",
)
_MAX_COMPLETION_KEYS = (
"max_completion_tokens",
"max_output_tokens",
"max_tokens",
)
# Local server hostnames / address patterns
_LOCAL_HOSTS = ("localhost", "127.0.0.1", "::1", "0.0.0.0")
def _normalize_base_url(base_url: str) -> str:
return (base_url or "").strip().rstrip("/")
def _is_openrouter_base_url(base_url: str) -> bool:
return "openrouter.ai" in _normalize_base_url(base_url).lower()
def _is_custom_endpoint(base_url: str) -> bool:
normalized = _normalize_base_url(base_url)
return bool(normalized) and not _is_openrouter_base_url(normalized)
_URL_TO_PROVIDER: Dict[str, str] = {
"api.openai.com": "openai",
"chatgpt.com": "openai",
"api.anthropic.com": "anthropic",
"api.z.ai": "zai",
"api.moonshot.ai": "kimi-coding",
"api.kimi.com": "kimi-coding",
"api.minimax": "minimax",
"dashscope.aliyuncs.com": "alibaba",
"dashscope-intl.aliyuncs.com": "alibaba",
"openrouter.ai": "openrouter",
"inference-api.nousresearch.com": "nous",
"api.deepseek.com": "deepseek",
"api.githubcopilot.com": "copilot",
"models.github.ai": "copilot",
}
def _infer_provider_from_url(base_url: str) -> Optional[str]:
"""Infer the models.dev provider name from a base URL.
This allows context length resolution via models.dev for custom endpoints
like DashScope (Alibaba), Z.AI, Kimi, etc. without requiring the user to
explicitly set the provider name in config.
"""
normalized = _normalize_base_url(base_url)
if not normalized:
return None
parsed = urlparse(normalized if "://" in normalized else f"https://{normalized}")
host = parsed.netloc.lower() or parsed.path.lower()
for url_part, provider in _URL_TO_PROVIDER.items():
if url_part in host:
return provider
return None
def _is_known_provider_base_url(base_url: str) -> bool:
return _infer_provider_from_url(base_url) is not None
def is_local_endpoint(base_url: str) -> bool:
"""Return True if base_url points to a local machine (localhost / RFC-1918 / WSL)."""
normalized = _normalize_base_url(base_url)
if not normalized:
return False
url = normalized if "://" in normalized else f"http://{normalized}"
try:
parsed = urlparse(url)
host = parsed.hostname or ""
except Exception:
return False
if host in _LOCAL_HOSTS:
return True
# RFC-1918 private ranges and link-local
import ipaddress
try:
addr = ipaddress.ip_address(host)
return addr.is_private or addr.is_loopback or addr.is_link_local
except ValueError:
pass
# Bare IP that looks like a private range (e.g. 172.26.x.x for WSL)
parts = host.split(".")
if len(parts) == 4:
try:
first, second = int(parts[0]), int(parts[1])
if first == 10:
return True
if first == 172 and 16 <= second <= 31:
return True
if first == 192 and second == 168:
return True
except ValueError:
pass
return False
def detect_local_server_type(base_url: str) -> Optional[str]:
"""Detect which local server is running at base_url by probing known endpoints.
Returns one of: "ollama", "lm-studio", "vllm", "llamacpp", or None.
"""
import httpx
normalized = _normalize_base_url(base_url)
server_url = normalized
if server_url.endswith("/v1"):
server_url = server_url[:-3]
try:
with httpx.Client(timeout=2.0) as client:
# LM Studio exposes /api/v1/models — check first (most specific)
try:
r = client.get(f"{server_url}/api/v1/models")
if r.status_code == 200:
return "lm-studio"
except Exception:
pass
# Ollama exposes /api/tags and responds with {"models": [...]}
# LM Studio returns {"error": "Unexpected endpoint"} with status 200
# on this path, so we must verify the response contains "models".
try:
r = client.get(f"{server_url}/api/tags")
if r.status_code == 200:
try:
data = r.json()
if "models" in data:
return "ollama"
except Exception:
pass
except Exception:
pass
# llama.cpp exposes /v1/props (older builds used /props without the /v1 prefix)
try:
r = client.get(f"{server_url}/v1/props")
if r.status_code != 200:
r = client.get(f"{server_url}/props") # fallback for older builds
if r.status_code == 200 and "default_generation_settings" in r.text:
return "llamacpp"
except Exception:
pass
# vLLM: /version
try:
r = client.get(f"{server_url}/version")
if r.status_code == 200:
data = r.json()
if "version" in data:
return "vllm"
except Exception:
pass
except Exception:
pass
return None
def _iter_nested_dicts(value: Any):
if isinstance(value, dict):
yield value
for nested in value.values():
yield from _iter_nested_dicts(nested)
elif isinstance(value, list):
for item in value:
yield from _iter_nested_dicts(item)
def _coerce_reasonable_int(value: Any, minimum: int = 1024, maximum: int = 10_000_000) -> Optional[int]:
try:
if isinstance(value, bool):
return None
if isinstance(value, str):
value = value.strip().replace(",", "")
result = int(value)
except (TypeError, ValueError):
return None
if minimum <= result <= maximum:
return result
return None
def _extract_first_int(payload: Dict[str, Any], keys: tuple[str, ...]) -> Optional[int]:
keyset = {key.lower() for key in keys}
for mapping in _iter_nested_dicts(payload):
for key, value in mapping.items():
if str(key).lower() not in keyset:
continue
coerced = _coerce_reasonable_int(value)
if coerced is not None:
return coerced
return None
def _extract_context_length(payload: Dict[str, Any]) -> Optional[int]:
return _extract_first_int(payload, _CONTEXT_LENGTH_KEYS)
def _extract_max_completion_tokens(payload: Dict[str, Any]) -> Optional[int]:
return _extract_first_int(payload, _MAX_COMPLETION_KEYS)
def _extract_pricing(payload: Dict[str, Any]) -> Dict[str, Any]:
alias_map = {
"prompt": ("prompt", "input", "input_cost_per_token", "prompt_token_cost"),
"completion": ("completion", "output", "output_cost_per_token", "completion_token_cost"),
"request": ("request", "request_cost"),
"cache_read": ("cache_read", "cached_prompt", "input_cache_read", "cache_read_cost_per_token"),
"cache_write": ("cache_write", "cache_creation", "input_cache_write", "cache_write_cost_per_token"),
}
for mapping in _iter_nested_dicts(payload):
normalized = {str(key).lower(): value for key, value in mapping.items()}
if not any(any(alias in normalized for alias in aliases) for aliases in alias_map.values()):
continue
pricing: Dict[str, Any] = {}
for target, aliases in alias_map.items():
for alias in aliases:
if alias in normalized and normalized[alias] not in (None, ""):
pricing[target] = normalized[alias]
break
if pricing:
return pricing
return {}
def _add_model_aliases(cache: Dict[str, Dict[str, Any]], model_id: str, entry: Dict[str, Any]) -> None:
cache[model_id] = entry
if "/" in model_id:
bare_model = model_id.split("/", 1)[1]
cache.setdefault(bare_model, entry)
def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any]]:
"""Fetch model metadata from OpenRouter (cached for 1 hour)."""
global _model_metadata_cache, _model_metadata_cache_time
if not force_refresh and _model_metadata_cache and (time.time() - _model_metadata_cache_time) < _MODEL_CACHE_TTL:
return _model_metadata_cache
try:
response = requests.get(OPENROUTER_MODELS_URL, timeout=10)
response.raise_for_status()
data = response.json()
cache = {}
for model in data.get("data", []):
model_id = model.get("id", "")
entry = {
"context_length": model.get("context_length", 128000),
"max_completion_tokens": model.get("top_provider", {}).get("max_completion_tokens", 4096),
"name": model.get("name", model_id),
"pricing": model.get("pricing", {}),
}
_add_model_aliases(cache, model_id, entry)
canonical = model.get("canonical_slug", "")
if canonical and canonical != model_id:
_add_model_aliases(cache, canonical, entry)
_model_metadata_cache = cache
_model_metadata_cache_time = time.time()
logger.debug("Fetched metadata for %s models from OpenRouter", len(cache))
return cache
except Exception as e:
logging.warning(f"Failed to fetch model metadata from OpenRouter: {e}")
return _model_metadata_cache or {}
def fetch_endpoint_model_metadata(
base_url: str,
api_key: str = "",
force_refresh: bool = False,
) -> Dict[str, Dict[str, Any]]:
"""Fetch model metadata from an OpenAI-compatible ``/models`` endpoint.
This is used for explicit custom endpoints where hardcoded global model-name
defaults are unreliable. Results are cached in memory per base URL.
"""
normalized = _normalize_base_url(base_url)
if not normalized or _is_openrouter_base_url(normalized):
return {}
if not force_refresh:
cached = _endpoint_model_metadata_cache.get(normalized)
cached_at = _endpoint_model_metadata_cache_time.get(normalized, 0)
if cached is not None and (time.time() - cached_at) < _ENDPOINT_MODEL_CACHE_TTL:
return cached
candidates = [normalized]
if normalized.endswith("/v1"):
alternate = normalized[:-3].rstrip("/")
else:
alternate = normalized + "/v1"
if alternate and alternate not in candidates:
candidates.append(alternate)
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
last_error: Optional[Exception] = None
for candidate in candidates:
url = candidate.rstrip("/") + "/models"
try:
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
payload = response.json()
cache: Dict[str, Dict[str, Any]] = {}
for model in payload.get("data", []):
if not isinstance(model, dict):
continue
model_id = model.get("id")
if not model_id:
continue
entry: Dict[str, Any] = {"name": model.get("name", model_id)}
context_length = _extract_context_length(model)
if context_length is not None:
entry["context_length"] = context_length
max_completion_tokens = _extract_max_completion_tokens(model)
if max_completion_tokens is not None:
entry["max_completion_tokens"] = max_completion_tokens
pricing = _extract_pricing(model)
if pricing:
entry["pricing"] = pricing
_add_model_aliases(cache, model_id, entry)
# If this is a llama.cpp server, query /props for actual allocated context
is_llamacpp = any(
m.get("owned_by") == "llamacpp"
for m in payload.get("data", []) if isinstance(m, dict)
)
if is_llamacpp:
try:
# Try /v1/props first (current llama.cpp); fall back to /props for older builds
base = candidate.rstrip("/").replace("/v1", "")
props_resp = requests.get(base + "/v1/props", headers=headers, timeout=5)
if not props_resp.ok:
props_resp = requests.get(base + "/props", headers=headers, timeout=5)
if props_resp.ok:
props = props_resp.json()
gen_settings = props.get("default_generation_settings", {})
n_ctx = gen_settings.get("n_ctx")
model_alias = props.get("model_alias", "")
if n_ctx and model_alias and model_alias in cache:
cache[model_alias]["context_length"] = n_ctx
except Exception:
pass
_endpoint_model_metadata_cache[normalized] = cache
_endpoint_model_metadata_cache_time[normalized] = time.time()
return cache
except Exception as exc:
last_error = exc
if last_error:
logger.debug("Failed to fetch model metadata from %s/models: %s", normalized, last_error)
_endpoint_model_metadata_cache[normalized] = {}
_endpoint_model_metadata_cache_time[normalized] = time.time()
return {}
def _get_context_cache_path() -> Path:
"""Return path to the persistent context length cache file."""
hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
return hermes_home / "context_length_cache.yaml"
def _load_context_cache() -> Dict[str, int]:
"""Load the model+provider -> context_length cache from disk."""
path = _get_context_cache_path()
if not path.exists():
return {}
try:
with open(path) as f:
data = yaml.safe_load(f) or {}
return data.get("context_lengths", {})
except Exception as e:
logger.debug("Failed to load context length cache: %s", e)
return {}
def save_context_length(model: str, base_url: str, length: int) -> None:
"""Persist a discovered context length for a model+provider combo.
Cache key is ``model@base_url`` so the same model name served from
different providers can have different limits.
"""
key = f"{model}@{base_url}"
cache = _load_context_cache()
if cache.get(key) == length:
return # already stored
cache[key] = length
path = _get_context_cache_path()
try:
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w") as f:
yaml.dump({"context_lengths": cache}, f, default_flow_style=False)
logger.info("Cached context length %s -> %s tokens", key, f"{length:,}")
except Exception as e:
logger.debug("Failed to save context length cache: %s", e)
def get_cached_context_length(model: str, base_url: str) -> Optional[int]:
"""Look up a previously discovered context length for model+provider."""
key = f"{model}@{base_url}"
cache = _load_context_cache()
return cache.get(key)
def get_next_probe_tier(current_length: int) -> Optional[int]:
"""Return the next lower probe tier, or None if already at minimum."""
for tier in CONTEXT_PROBE_TIERS:
if tier < current_length:
return tier
return None
def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
"""Try to extract the actual context limit from an API error message.
Many providers include the limit in their error text, e.g.:
- "maximum context length is 32768 tokens"
- "context_length_exceeded: 131072"
- "Maximum context size 32768 exceeded"
- "model's max context length is 65536"
"""
error_lower = error_msg.lower()
# Pattern: look for numbers near context-related keywords
patterns = [
r'(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})',
r'context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})',
r'(\d{4,})\s*(?:token)?\s*(?:context|limit)',
r'>\s*(\d{4,})\s*(?:max|limit|token)', # "250000 tokens > 200000 maximum"
r'(\d{4,})\s*(?:max(?:imum)?)\b', # "200000 maximum"
]
for pattern in patterns:
match = re.search(pattern, error_lower)
if match:
limit = int(match.group(1))
# Sanity check: must be a reasonable context length
if 1024 <= limit <= 10_000_000:
return limit
return None
def _model_id_matches(candidate_id: str, lookup_model: str) -> bool:
"""Return True if *candidate_id* (from server) matches *lookup_model* (configured).
Supports two forms:
- Exact match: "nvidia-nemotron-super-49b-v1" == "nvidia-nemotron-super-49b-v1"
- Slug match: "nvidia/nvidia-nemotron-super-49b-v1" matches "nvidia-nemotron-super-49b-v1"
(the part after the last "/" equals lookup_model)
This covers LM Studio's native API which stores models as "publisher/slug"
while users typically configure only the slug after the "local:" prefix.
"""
if candidate_id == lookup_model:
return True
# Slug match: basename of candidate equals the lookup name
if "/" in candidate_id and candidate_id.rsplit("/", 1)[1] == lookup_model:
return True
return False
def _query_local_context_length(model: str, base_url: str) -> Optional[int]:
"""Query a local server for the model's context length."""
import httpx
# Strip recognised provider prefix (e.g., "local:model-name" → "model-name").
# Ollama "model:tag" colons (e.g. "qwen3.5:27b") are intentionally preserved.
model = _strip_provider_prefix(model)
# Strip /v1 suffix to get the server root
server_url = base_url.rstrip("/")
if server_url.endswith("/v1"):
server_url = server_url[:-3]
try:
server_type = detect_local_server_type(base_url)
except Exception:
server_type = None
try:
with httpx.Client(timeout=3.0) as client:
# Ollama: /api/show returns model details with context info
if server_type == "ollama":
resp = client.post(f"{server_url}/api/show", json={"name": model})
if resp.status_code == 200:
data = resp.json()
# Check model_info for context length
model_info = data.get("model_info", {})
for key, value in model_info.items():
if "context_length" in key and isinstance(value, (int, float)):
return int(value)
# Check parameters string for num_ctx
params = data.get("parameters", "")
if "num_ctx" in params:
for line in params.split("\n"):
if "num_ctx" in line:
parts = line.strip().split()
if len(parts) >= 2:
try:
return int(parts[-1])
except ValueError:
pass
# LM Studio native API: /api/v1/models returns max_context_length.
# This is more reliable than the OpenAI-compat /v1/models which
# doesn't include context window information for LM Studio servers.
# Use _model_id_matches for fuzzy matching: LM Studio stores models as
# "publisher/slug" but users configure only "slug" after "local:" prefix.
if server_type == "lm-studio":
resp = client.get(f"{server_url}/api/v1/models")
if resp.status_code == 200:
data = resp.json()
for m in data.get("models", []):
if _model_id_matches(m.get("key", ""), model) or _model_id_matches(m.get("id", ""), model):
# Prefer loaded instance context (actual runtime value)
for inst in m.get("loaded_instances", []):
cfg = inst.get("config", {})
ctx = cfg.get("context_length")
if ctx and isinstance(ctx, (int, float)):
return int(ctx)
# Fall back to max_context_length (theoretical model max)
ctx = m.get("max_context_length") or m.get("context_length")
if ctx and isinstance(ctx, (int, float)):
return int(ctx)
# LM Studio / vLLM / llama.cpp: try /v1/models/{model}
resp = client.get(f"{server_url}/v1/models/{model}")
if resp.status_code == 200:
data = resp.json()
# vLLM returns max_model_len
ctx = data.get("max_model_len") or data.get("context_length") or data.get("max_tokens")
if ctx and isinstance(ctx, (int, float)):
return int(ctx)
# Try /v1/models and find the model in the list.
# Use _model_id_matches to handle "publisher/slug" vs bare "slug".
resp = client.get(f"{server_url}/v1/models")
if resp.status_code == 200:
data = resp.json()
models_list = data.get("data", [])
for m in models_list:
if _model_id_matches(m.get("id", ""), model):
ctx = m.get("max_model_len") or m.get("context_length") or m.get("max_tokens")
if ctx and isinstance(ctx, (int, float)):
return int(ctx)
except Exception:
pass
return None
def _normalize_model_version(model: str) -> str:
"""Normalize version separators for matching.
Nous uses dashes: claude-opus-4-6, claude-sonnet-4-5
OpenRouter uses dots: claude-opus-4.6, claude-sonnet-4.5
Normalize both to dashes for comparison.
"""
return model.replace(".", "-")
def _query_anthropic_context_length(model: str, base_url: str, api_key: str) -> Optional[int]:
"""Query Anthropic's /v1/models endpoint for context length.
Only works with regular ANTHROPIC_API_KEY (sk-ant-api*).
OAuth tokens (sk-ant-oat*) from Claude Code return 401.
"""
if not api_key or api_key.startswith("sk-ant-oat"):
return None # OAuth tokens can't access /v1/models
try:
base = base_url.rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
url = f"{base}/v1/models?limit=1000"
headers = {
"x-api-key": api_key,
"anthropic-version": "2023-06-01",
}
resp = requests.get(url, headers=headers, timeout=10)
if resp.status_code != 200:
return None
data = resp.json()
for m in data.get("data", []):
if m.get("id") == model:
ctx = m.get("max_input_tokens")
if isinstance(ctx, int) and ctx > 0:
return ctx
except Exception as e:
logger.debug("Anthropic /v1/models query failed: %s", e)
return None
def _resolve_nous_context_length(model: str) -> Optional[int]:
"""Resolve Nous Portal model context length via OpenRouter metadata.
Nous model IDs are bare (e.g. 'claude-opus-4-6') while OpenRouter uses
prefixed IDs (e.g. 'anthropic/claude-opus-4.6'). Try suffix matching
with version normalization (dotdash).
"""
metadata = fetch_model_metadata() # OpenRouter cache
# Exact match first
if model in metadata:
return metadata[model].get("context_length")
normalized = _normalize_model_version(model).lower()
for or_id, entry in metadata.items():
bare = or_id.split("/", 1)[1] if "/" in or_id else or_id
if bare.lower() == model.lower() or _normalize_model_version(bare).lower() == normalized:
return entry.get("context_length")
# Partial prefix match for cases like gemini-3-flash → gemini-3-flash-preview
# Require match to be at a word boundary (followed by -, :, or end of string)
model_lower = model.lower()
for or_id, entry in metadata.items():
bare = or_id.split("/", 1)[1] if "/" in or_id else or_id
for candidate, query in [(bare.lower(), model_lower), (_normalize_model_version(bare).lower(), normalized)]:
if candidate.startswith(query) and (
len(candidate) == len(query) or candidate[len(query)] in "-:."
):
return entry.get("context_length")
return None
def get_model_context_length(
model: str,
base_url: str = "",
api_key: str = "",
config_context_length: int | None = None,
provider: str = "",
) -> int:
"""Get the context length for a model.
Resolution order:
0. Explicit config override (model.context_length or custom_providers per-model)
1. Persistent cache (previously discovered via probing)
2. Active endpoint metadata (/models for explicit custom endpoints)
3. Local server query (for local endpoints)
4. Anthropic /v1/models API (API-key users only, not OAuth)
5. OpenRouter live API metadata
6. Nous suffix-match via OpenRouter cache
7. models.dev registry lookup (provider-aware)
8. Thin hardcoded defaults (broad family patterns)
9. Default fallback (128K)
"""
# 0. Explicit config override — user knows best
if config_context_length is not None and isinstance(config_context_length, int) and config_context_length > 0:
return config_context_length
# Normalise provider-prefixed model names (e.g. "local:model-name" →
# "model-name") so cache lookups and server queries use the bare ID that
# local servers actually know about. Ollama "model:tag" colons are preserved.
model = _strip_provider_prefix(model)
# 1. Check persistent cache (model+provider)
if base_url:
cached = get_cached_context_length(model, base_url)
if cached is not None:
return cached
# 2. Active endpoint metadata for truly custom/unknown endpoints.
# Known providers (Copilot, OpenAI, Anthropic, etc.) skip this — their
# /models endpoint may report a provider-imposed limit (e.g. Copilot
# returns 128k) instead of the model's full context (400k). models.dev
# has the correct per-provider values and is checked at step 5+.
if _is_custom_endpoint(base_url) and not _is_known_provider_base_url(base_url):
endpoint_metadata = fetch_endpoint_model_metadata(base_url, api_key=api_key)
matched = endpoint_metadata.get(model)
if not matched:
# Single-model servers: if only one model is loaded, use it
if len(endpoint_metadata) == 1:
matched = next(iter(endpoint_metadata.values()))
else:
# Fuzzy match: substring in either direction
for key, entry in endpoint_metadata.items():
if model in key or key in model:
matched = entry
break
if matched:
context_length = matched.get("context_length")
if isinstance(context_length, int):
return context_length
if not _is_known_provider_base_url(base_url):
# 3. Try querying local server directly
if is_local_endpoint(base_url):
local_ctx = _query_local_context_length(model, base_url)
if local_ctx and local_ctx > 0:
save_context_length(model, base_url, local_ctx)
return local_ctx
logger.info(
"Could not detect context length for model %r at %s"
"defaulting to %s tokens (probe-down). Set model.context_length "
"in config.yaml to override.",
model, base_url, f"{DEFAULT_FALLBACK_CONTEXT:,}",
)
return DEFAULT_FALLBACK_CONTEXT
# 4. Anthropic /v1/models API (only for regular API keys, not OAuth)
if provider == "anthropic" or (
base_url and "api.anthropic.com" in base_url
):
ctx = _query_anthropic_context_length(model, base_url or "https://api.anthropic.com", api_key)
if ctx:
return ctx
# 5. Provider-aware lookups (before generic OpenRouter cache)
# These are provider-specific and take priority over the generic OR cache,
# since the same model can have different context limits per provider
# (e.g. claude-opus-4.6 is 1M on Anthropic but 128K on GitHub Copilot).
# If provider is generic (openrouter/custom/empty), try to infer from URL.
effective_provider = provider
if not effective_provider or effective_provider in ("openrouter", "custom"):
if base_url:
inferred = _infer_provider_from_url(base_url)
if inferred:
effective_provider = inferred
if effective_provider == "nous":
ctx = _resolve_nous_context_length(model)
if ctx:
return ctx
if effective_provider:
from agent.models_dev import lookup_models_dev_context
ctx = lookup_models_dev_context(effective_provider, model)
if ctx:
return ctx
# 6. OpenRouter live API metadata (provider-unaware fallback)
metadata = fetch_model_metadata()
if model in metadata:
return metadata[model].get("context_length", 128000)
# 8. Hardcoded defaults (fuzzy match — longest key first for specificity)
# Only check `default_model in model` (is the key a substring of the input).
# The reverse (`model in default_model`) causes shorter names like
# "claude-sonnet-4" to incorrectly match "claude-sonnet-4-6" and return 1M.
model_lower = model.lower()
for default_model, length in sorted(
DEFAULT_CONTEXT_LENGTHS.items(), key=lambda x: len(x[0]), reverse=True
):
if default_model in model_lower:
return length
# 9. Query local server as last resort
if base_url and is_local_endpoint(base_url):
local_ctx = _query_local_context_length(model, base_url)
if local_ctx and local_ctx > 0:
save_context_length(model, base_url, local_ctx)
return local_ctx
# 10. Default fallback — 128K
return DEFAULT_FALLBACK_CONTEXT
def estimate_tokens_rough(text: str) -> int:
"""Rough token estimate (~4 chars/token) for pre-flight checks."""
if not text:
return 0
return len(text) // 4
def estimate_messages_tokens_rough(messages: List[Dict[str, Any]]) -> int:
"""Rough token estimate for a message list (pre-flight only)."""
total_chars = sum(len(str(msg)) for msg in messages)
return total_chars // 4

View file

@ -0,0 +1,171 @@
"""Models.dev registry integration for provider-aware context length detection.
Fetches model metadata from https://models.dev/api.json a community-maintained
database of 3800+ models across 100+ providers, including per-provider context
windows, pricing, and capabilities.
Data is cached in memory (1hr TTL) and on disk (~/.hermes/models_dev_cache.json)
to avoid cold-start network latency.
"""
import json
import logging
import os
import time
from pathlib import Path
from typing import Any, Dict, Optional
import requests
logger = logging.getLogger(__name__)
MODELS_DEV_URL = "https://models.dev/api.json"
_MODELS_DEV_CACHE_TTL = 3600 # 1 hour in-memory
# In-memory cache
_models_dev_cache: Dict[str, Any] = {}
_models_dev_cache_time: float = 0
# Provider ID mapping: Hermes provider names → models.dev provider IDs
PROVIDER_TO_MODELS_DEV: Dict[str, str] = {
"openrouter": "openrouter",
"anthropic": "anthropic",
"zai": "zai",
"kimi-coding": "kimi-for-coding",
"minimax": "minimax",
"minimax-cn": "minimax-cn",
"deepseek": "deepseek",
"alibaba": "alibaba",
"copilot": "github-copilot",
"ai-gateway": "vercel",
"opencode-zen": "opencode",
"opencode-go": "opencode-go",
"kilocode": "kilo",
}
def _get_cache_path() -> Path:
"""Return path to disk cache file."""
env_val = os.environ.get("HERMES_HOME", "")
hermes_home = Path(env_val) if env_val else Path.home() / ".hermes"
return hermes_home / "models_dev_cache.json"
def _load_disk_cache() -> Dict[str, Any]:
"""Load models.dev data from disk cache."""
try:
cache_path = _get_cache_path()
if cache_path.exists():
with open(cache_path, encoding="utf-8") as f:
return json.load(f)
except Exception as e:
logger.debug("Failed to load models.dev disk cache: %s", e)
return {}
def _save_disk_cache(data: Dict[str, Any]) -> None:
"""Save models.dev data to disk cache."""
try:
cache_path = _get_cache_path()
cache_path.parent.mkdir(parents=True, exist_ok=True)
with open(cache_path, "w", encoding="utf-8") as f:
json.dump(data, f, separators=(",", ":"))
except Exception as e:
logger.debug("Failed to save models.dev disk cache: %s", e)
def fetch_models_dev(force_refresh: bool = False) -> Dict[str, Any]:
"""Fetch models.dev registry. In-memory cache (1hr) + disk fallback.
Returns the full registry dict keyed by provider ID, or empty dict on failure.
"""
global _models_dev_cache, _models_dev_cache_time
# Check in-memory cache
if (
not force_refresh
and _models_dev_cache
and (time.time() - _models_dev_cache_time) < _MODELS_DEV_CACHE_TTL
):
return _models_dev_cache
# Try network fetch
try:
response = requests.get(MODELS_DEV_URL, timeout=15)
response.raise_for_status()
data = response.json()
if isinstance(data, dict) and len(data) > 0:
_models_dev_cache = data
_models_dev_cache_time = time.time()
_save_disk_cache(data)
logger.debug(
"Fetched models.dev registry: %d providers, %d total models",
len(data),
sum(len(p.get("models", {})) for p in data.values() if isinstance(p, dict)),
)
return data
except Exception as e:
logger.debug("Failed to fetch models.dev: %s", e)
# Fall back to disk cache — use a short TTL (5 min) so we retry
# the network fetch soon instead of serving stale data for a full hour.
if not _models_dev_cache:
_models_dev_cache = _load_disk_cache()
if _models_dev_cache:
_models_dev_cache_time = time.time() - _MODELS_DEV_CACHE_TTL + 300
logger.debug("Loaded models.dev from disk cache (%d providers)", len(_models_dev_cache))
return _models_dev_cache
def lookup_models_dev_context(provider: str, model: str) -> Optional[int]:
"""Look up context_length for a provider+model combo in models.dev.
Returns the context window in tokens, or None if not found.
Handles case-insensitive matching and filters out context=0 entries.
"""
mdev_provider_id = PROVIDER_TO_MODELS_DEV.get(provider)
if not mdev_provider_id:
return None
data = fetch_models_dev()
provider_data = data.get(mdev_provider_id)
if not isinstance(provider_data, dict):
return None
models = provider_data.get("models", {})
if not isinstance(models, dict):
return None
# Exact match
entry = models.get(model)
if entry:
ctx = _extract_context(entry)
if ctx:
return ctx
# Case-insensitive match
model_lower = model.lower()
for mid, mdata in models.items():
if mid.lower() == model_lower:
ctx = _extract_context(mdata)
if ctx:
return ctx
return None
def _extract_context(entry: Dict[str, Any]) -> Optional[int]:
"""Extract context_length from a models.dev model entry.
Returns None for invalid/zero values (some audio/image models have context=0).
"""
if not isinstance(entry, dict):
return None
limit = entry.get("limit")
if not isinstance(limit, dict):
return None
ctx = limit.get("context")
if isinstance(ctx, (int, float)) and ctx > 0:
return int(ctx)
return None

View file

@ -0,0 +1,604 @@
"""System prompt assembly -- identity, platform hints, skills index, context files.
All functions are stateless. AIAgent._build_system_prompt() calls these to
assemble pieces, then combines them with memory and ephemeral prompts.
"""
import logging
import os
import re
from pathlib import Path
from typing import Optional
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Context file scanning — detect prompt injection in AGENTS.md, .cursorrules,
# SOUL.md before they get injected into the system prompt.
# ---------------------------------------------------------------------------
_CONTEXT_THREAT_PATTERNS = [
(r'ignore\s+(previous|all|above|prior)\s+instructions', "prompt_injection"),
(r'do\s+not\s+tell\s+the\s+user', "deception_hide"),
(r'system\s+prompt\s+override', "sys_prompt_override"),
(r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"),
(r'act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)', "bypass_restrictions"),
(r'<!--[^>]*(?:ignore|override|system|secret|hidden)[^>]*-->', "html_comment_injection"),
(r'<\s*div\s+style\s*=\s*["\'].*display\s*:\s*none', "hidden_div"),
(r'translate\s+.*\s+into\s+.*\s+and\s+(execute|run|eval)', "translate_execute"),
(r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"),
(r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)', "read_secrets"),
]
_CONTEXT_INVISIBLE_CHARS = {
'\u200b', '\u200c', '\u200d', '\u2060', '\ufeff',
'\u202a', '\u202b', '\u202c', '\u202d', '\u202e',
}
def _scan_context_content(content: str, filename: str) -> str:
"""Scan context file content for injection. Returns sanitized content."""
findings = []
# Check invisible unicode
for char in _CONTEXT_INVISIBLE_CHARS:
if char in content:
findings.append(f"invisible unicode U+{ord(char):04X}")
# Check threat patterns
for pattern, pid in _CONTEXT_THREAT_PATTERNS:
if re.search(pattern, content, re.IGNORECASE):
findings.append(pid)
if findings:
logger.warning("Context file %s blocked: %s", filename, ", ".join(findings))
return f"[BLOCKED: {filename} contained potential prompt injection ({', '.join(findings)}). Content not loaded.]"
return content
def _find_git_root(start: Path) -> Optional[Path]:
"""Walk *start* and its parents looking for a ``.git`` directory.
Returns the directory containing ``.git``, or ``None`` if we hit the
filesystem root without finding one.
"""
current = start.resolve()
for parent in [current, *current.parents]:
if (parent / ".git").exists():
return parent
return None
_HERMES_MD_NAMES = (".hermes.md", "HERMES.md")
def _find_hermes_md(cwd: Path) -> Optional[Path]:
"""Discover the nearest ``.hermes.md`` or ``HERMES.md``.
Search order: *cwd* first, then each parent directory up to (and
including) the git repository root. Returns the first match, or
``None`` if nothing is found.
"""
stop_at = _find_git_root(cwd)
current = cwd.resolve()
for directory in [current, *current.parents]:
for name in _HERMES_MD_NAMES:
candidate = directory / name
if candidate.is_file():
return candidate
# Stop walking at the git root (or filesystem root).
if stop_at and directory == stop_at:
break
return None
def _strip_yaml_frontmatter(content: str) -> str:
"""Remove optional YAML frontmatter (``---`` delimited) from *content*.
The frontmatter may contain structured config (model overrides, tool
settings) that will be handled separately in a future PR. For now we
strip it so only the human-readable markdown body is injected into the
system prompt.
"""
if content.startswith("---"):
end = content.find("\n---", 3)
if end != -1:
# Skip past the closing --- and any trailing newline
body = content[end + 4:].lstrip("\n")
return body if body else content
return content
# =========================================================================
# Constants
# =========================================================================
DEFAULT_AGENT_IDENTITY = (
"You are Hermes Agent, an intelligent AI assistant created by Nous Research. "
"You are helpful, knowledgeable, and direct. You assist users with a wide "
"range of tasks including answering questions, writing and editing code, "
"analyzing information, creative work, and executing actions via your tools. "
"You communicate clearly, admit uncertainty when appropriate, and prioritize "
"being genuinely useful over being verbose unless otherwise directed below. "
"Be targeted and efficient in your exploration and investigations."
)
MEMORY_GUIDANCE = (
"You have persistent memory across sessions. Save durable facts using the memory "
"tool: user preferences, environment details, tool quirks, and stable conventions. "
"Memory is injected into every turn, so keep it compact and focused on facts that "
"will still matter later.\n"
"Prioritize what reduces future user steering — the most valuable memory is one "
"that prevents the user from having to correct or remind you again. "
"User preferences and recurring corrections matter more than procedural task details.\n"
"Do NOT save task progress, session outcomes, completed-work logs, or temporary TODO "
"state to memory; use session_search to recall those from past transcripts. "
"If you've discovered a new way to do something, solved a problem that could be "
"necessary later, save it as a skill with the skill tool."
)
SESSION_SEARCH_GUIDANCE = (
"When the user references something from a past conversation or you suspect "
"relevant cross-session context exists, use session_search to recall it before "
"asking them to repeat themselves."
)
SKILLS_GUIDANCE = (
"After completing a complex task (5+ tool calls), fixing a tricky error, "
"or discovering a non-trivial workflow, save the approach as a "
"skill with skill_manage so you can reuse it next time.\n"
"When using a skill and finding it outdated, incomplete, or wrong, "
"patch it immediately with skill_manage(action='patch') — don't wait to be asked. "
"Skills that aren't maintained become liabilities."
)
PLATFORM_HINTS = {
"whatsapp": (
"You are on a text messaging communication platform, WhatsApp. "
"Please do not use markdown as it does not render. "
"You can send media files natively: to deliver a file to the user, "
"include MEDIA:/absolute/path/to/file in your response. The file "
"will be sent as a native WhatsApp attachment — images (.jpg, .png, "
".webp) appear as photos, videos (.mp4, .mov) play inline, and other "
"files arrive as downloadable documents. You can also include image "
"URLs in markdown format ![alt](url) and they will be sent as photos."
),
"telegram": (
"You are on a text messaging communication platform, Telegram. "
"Please do not use markdown as it does not render. "
"You can send media files natively: to deliver a file to the user, "
"include MEDIA:/absolute/path/to/file in your response. Images "
"(.png, .jpg, .webp) appear as photos, audio (.ogg) sends as voice "
"bubbles, and videos (.mp4) play inline. You can also include image "
"URLs in markdown format ![alt](url) and they will be sent as native photos."
),
"discord": (
"You are in a Discord server or group chat communicating with your user. "
"You can send media files natively: include MEDIA:/absolute/path/to/file "
"in your response. Images (.png, .jpg, .webp) are sent as photo "
"attachments, audio as file attachments. You can also include image URLs "
"in markdown format ![alt](url) and they will be sent as attachments."
),
"slack": (
"You are in a Slack workspace communicating with your user. "
"You can send media files natively: include MEDIA:/absolute/path/to/file "
"in your response. Images (.png, .jpg, .webp) are uploaded as photo "
"attachments, audio as file attachments. You can also include image URLs "
"in markdown format ![alt](url) and they will be uploaded as attachments."
),
"signal": (
"You are on a text messaging communication platform, Signal. "
"Please do not use markdown as it does not render. "
"You can send media files natively: to deliver a file to the user, "
"include MEDIA:/absolute/path/to/file in your response. Images "
"(.png, .jpg, .webp) appear as photos, audio as attachments, and other "
"files arrive as downloadable documents. You can also include image "
"URLs in markdown format ![alt](url) and they will be sent as photos."
),
"email": (
"You are communicating via email. Write clear, well-structured responses "
"suitable for email. Use plain text formatting (no markdown). "
"Keep responses concise but complete. You can send file attachments — "
"include MEDIA:/absolute/path/to/file in your response. The subject line "
"is preserved for threading. Do not include greetings or sign-offs unless "
"contextually appropriate."
),
"cron": (
"You are running as a scheduled cron job. There is no user present — you "
"cannot ask questions, request clarification, or wait for follow-up. Execute "
"the task fully and autonomously, making reasonable decisions where needed. "
"Your final response is automatically delivered to the job's configured "
"destination — put the primary content directly in your response."
),
"cli": (
"You are a CLI AI Agent. Try not to use markdown but simple text "
"renderable inside a terminal."
),
"sms": (
"You are communicating via SMS. Keep responses concise and use plain text "
"only — no markdown, no formatting. SMS messages are limited to ~1600 "
"characters, so be brief and direct."
),
}
CONTEXT_FILE_MAX_CHARS = 20_000
CONTEXT_TRUNCATE_HEAD_RATIO = 0.7
CONTEXT_TRUNCATE_TAIL_RATIO = 0.2
# =========================================================================
# Skills index
# =========================================================================
def _parse_skill_file(skill_file: Path) -> tuple[bool, dict, str]:
"""Read a SKILL.md once and return platform compatibility, frontmatter, and description.
Returns (is_compatible, frontmatter, description). On any error, returns
(True, {}, "") to err on the side of showing the skill.
"""
try:
from tools.skills_tool import _parse_frontmatter, skill_matches_platform
raw = skill_file.read_text(encoding="utf-8")[:2000]
frontmatter, _ = _parse_frontmatter(raw)
if not skill_matches_platform(frontmatter):
return False, {}, ""
desc = ""
raw_desc = frontmatter.get("description", "")
if raw_desc:
desc = str(raw_desc).strip().strip("'\"")
if len(desc) > 60:
desc = desc[:57] + "..."
return True, frontmatter, desc
except Exception as e:
logger.debug("Failed to parse skill file %s: %s", skill_file, e)
return True, {}, ""
def _read_skill_conditions(skill_file: Path) -> dict:
"""Extract conditional activation fields from SKILL.md frontmatter."""
try:
from tools.skills_tool import _parse_frontmatter
raw = skill_file.read_text(encoding="utf-8")[:2000]
frontmatter, _ = _parse_frontmatter(raw)
hermes = frontmatter.get("metadata", {}).get("hermes", {})
return {
"fallback_for_toolsets": hermes.get("fallback_for_toolsets", []),
"requires_toolsets": hermes.get("requires_toolsets", []),
"fallback_for_tools": hermes.get("fallback_for_tools", []),
"requires_tools": hermes.get("requires_tools", []),
}
except Exception as e:
logger.debug("Failed to read skill conditions from %s: %s", skill_file, e)
return {}
def _skill_should_show(
conditions: dict,
available_tools: "set[str] | None",
available_toolsets: "set[str] | None",
) -> bool:
"""Return False if the skill's conditional activation rules exclude it."""
if available_tools is None and available_toolsets is None:
return True # No filtering info — show everything (backward compat)
at = available_tools or set()
ats = available_toolsets or set()
# fallback_for: hide when the primary tool/toolset IS available
for ts in conditions.get("fallback_for_toolsets", []):
if ts in ats:
return False
for t in conditions.get("fallback_for_tools", []):
if t in at:
return False
# requires: hide when a required tool/toolset is NOT available
for ts in conditions.get("requires_toolsets", []):
if ts not in ats:
return False
for t in conditions.get("requires_tools", []):
if t not in at:
return False
return True
def build_skills_system_prompt(
available_tools: "set[str] | None" = None,
available_toolsets: "set[str] | None" = None,
) -> str:
"""Build a compact skill index for the system prompt.
Scans ~/.hermes/skills/ for SKILL.md files grouped by category.
Includes per-skill descriptions from frontmatter so the model can
match skills by meaning, not just name.
Filters out skills incompatible with the current OS platform.
"""
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
skills_dir = hermes_home / "skills"
if not skills_dir.exists():
return ""
# Collect skills with descriptions, grouped by category.
# Each entry: (skill_name, description)
# Supports sub-categories: skills/mlops/training/axolotl/SKILL.md
# -> category "mlops/training", skill "axolotl"
# Load disabled skill names once for the entire scan
try:
from tools.skills_tool import _get_disabled_skill_names
disabled = _get_disabled_skill_names()
except Exception:
disabled = set()
skills_by_category: dict[str, list[tuple[str, str]]] = {}
for skill_file in skills_dir.rglob("SKILL.md"):
is_compatible, frontmatter, desc = _parse_skill_file(skill_file)
if not is_compatible:
continue
rel_path = skill_file.relative_to(skills_dir)
parts = rel_path.parts
if len(parts) >= 2:
skill_name = parts[-2]
category = "/".join(parts[:-2]) if len(parts) > 2 else parts[0]
else:
category = "general"
skill_name = skill_file.parent.name
# Respect user's disabled skills config
fm_name = frontmatter.get("name", skill_name)
if fm_name in disabled or skill_name in disabled:
continue
# Skip skills whose conditional activation rules exclude them
conditions = _read_skill_conditions(skill_file)
if not _skill_should_show(conditions, available_tools, available_toolsets):
continue
skills_by_category.setdefault(category, []).append((skill_name, desc))
if not skills_by_category:
return ""
# Read category-level descriptions from DESCRIPTION.md
# Checks both the exact category path and parent directories
category_descriptions = {}
for category in skills_by_category:
cat_path = Path(category)
desc_file = skills_dir / cat_path / "DESCRIPTION.md"
if desc_file.exists():
try:
content = desc_file.read_text(encoding="utf-8")
match = re.search(r"^---\s*\n.*?description:\s*(.+?)\s*\n.*?^---", content, re.MULTILINE | re.DOTALL)
if match:
category_descriptions[category] = match.group(1).strip()
except Exception as e:
logger.debug("Could not read skill description %s: %s", desc_file, e)
index_lines = []
for category in sorted(skills_by_category.keys()):
cat_desc = category_descriptions.get(category, "")
if cat_desc:
index_lines.append(f" {category}: {cat_desc}")
else:
index_lines.append(f" {category}:")
# Deduplicate and sort skills within each category
seen = set()
for name, desc in sorted(skills_by_category[category], key=lambda x: x[0]):
if name in seen:
continue
seen.add(name)
if desc:
index_lines.append(f" - {name}: {desc}")
else:
index_lines.append(f" - {name}")
return (
"## Skills (mandatory)\n"
"Before replying, scan the skills below. If one clearly matches your task, "
"load it with skill_view(name) and follow its instructions. "
"If a skill has issues, fix it with skill_manage(action='patch').\n"
"After difficult/iterative tasks, offer to save as a skill. "
"If a skill you loaded was missing steps, had wrong commands, or needed "
"pitfalls you discovered, update it before finishing.\n"
"\n"
"<available_skills>\n"
+ "\n".join(index_lines) + "\n"
"</available_skills>\n"
"\n"
"If none match, proceed normally without loading a skill."
)
# =========================================================================
# Context files (SOUL.md, AGENTS.md, .cursorrules)
# =========================================================================
def _truncate_content(content: str, filename: str, max_chars: int = CONTEXT_FILE_MAX_CHARS) -> str:
"""Head/tail truncation with a marker in the middle."""
if len(content) <= max_chars:
return content
head_chars = int(max_chars * CONTEXT_TRUNCATE_HEAD_RATIO)
tail_chars = int(max_chars * CONTEXT_TRUNCATE_TAIL_RATIO)
head = content[:head_chars]
tail = content[-tail_chars:]
marker = f"\n\n[...truncated {filename}: kept {head_chars}+{tail_chars} of {len(content)} chars. Use file tools to read the full file.]\n\n"
return head + marker + tail
def load_soul_md() -> Optional[str]:
"""Load SOUL.md from HERMES_HOME and return its content, or None.
Used as the agent identity (slot #1 in the system prompt). When this
returns content, ``build_context_files_prompt`` should be called with
``skip_soul=True`` so SOUL.md isn't injected twice.
"""
try:
from hermes_cli.config import ensure_hermes_home
ensure_hermes_home()
except Exception as e:
logger.debug("Could not ensure HERMES_HOME before loading SOUL.md: %s", e)
soul_path = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "SOUL.md"
if not soul_path.exists():
return None
try:
content = soul_path.read_text(encoding="utf-8").strip()
if not content:
return None
content = _scan_context_content(content, "SOUL.md")
content = _truncate_content(content, "SOUL.md")
return content
except Exception as e:
logger.debug("Could not read SOUL.md from %s: %s", soul_path, e)
return None
def _load_hermes_md(cwd_path: Path) -> str:
""".hermes.md / HERMES.md — walk to git root."""
hermes_md_path = _find_hermes_md(cwd_path)
if not hermes_md_path:
return ""
try:
content = hermes_md_path.read_text(encoding="utf-8").strip()
if not content:
return ""
content = _strip_yaml_frontmatter(content)
rel = hermes_md_path.name
try:
rel = str(hermes_md_path.relative_to(cwd_path))
except ValueError:
pass
content = _scan_context_content(content, rel)
result = f"## {rel}\n\n{content}"
return _truncate_content(result, ".hermes.md")
except Exception as e:
logger.debug("Could not read %s: %s", hermes_md_path, e)
return ""
def _load_agents_md(cwd_path: Path) -> str:
"""AGENTS.md — hierarchical, recursive directory walk."""
top_level_agents = None
for name in ["AGENTS.md", "agents.md"]:
candidate = cwd_path / name
if candidate.exists():
top_level_agents = candidate
break
if not top_level_agents:
return ""
agents_files = []
for root, dirs, files in os.walk(cwd_path):
dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ('node_modules', '__pycache__', 'venv', '.venv')]
for f in files:
if f.lower() == "agents.md":
agents_files.append(Path(root) / f)
agents_files.sort(key=lambda p: len(p.parts))
total_content = ""
for agents_path in agents_files:
try:
content = agents_path.read_text(encoding="utf-8").strip()
if content:
rel_path = agents_path.relative_to(cwd_path)
content = _scan_context_content(content, str(rel_path))
total_content += f"## {rel_path}\n\n{content}\n\n"
except Exception as e:
logger.debug("Could not read %s: %s", agents_path, e)
if not total_content:
return ""
return _truncate_content(total_content, "AGENTS.md")
def _load_claude_md(cwd_path: Path) -> str:
"""CLAUDE.md / claude.md — cwd only."""
for name in ["CLAUDE.md", "claude.md"]:
candidate = cwd_path / name
if candidate.exists():
try:
content = candidate.read_text(encoding="utf-8").strip()
if content:
content = _scan_context_content(content, name)
result = f"## {name}\n\n{content}"
return _truncate_content(result, "CLAUDE.md")
except Exception as e:
logger.debug("Could not read %s: %s", candidate, e)
return ""
def _load_cursorrules(cwd_path: Path) -> str:
""".cursorrules + .cursor/rules/*.mdc — cwd only."""
cursorrules_content = ""
cursorrules_file = cwd_path / ".cursorrules"
if cursorrules_file.exists():
try:
content = cursorrules_file.read_text(encoding="utf-8").strip()
if content:
content = _scan_context_content(content, ".cursorrules")
cursorrules_content += f"## .cursorrules\n\n{content}\n\n"
except Exception as e:
logger.debug("Could not read .cursorrules: %s", e)
cursor_rules_dir = cwd_path / ".cursor" / "rules"
if cursor_rules_dir.exists() and cursor_rules_dir.is_dir():
mdc_files = sorted(cursor_rules_dir.glob("*.mdc"))
for mdc_file in mdc_files:
try:
content = mdc_file.read_text(encoding="utf-8").strip()
if content:
content = _scan_context_content(content, f".cursor/rules/{mdc_file.name}")
cursorrules_content += f"## .cursor/rules/{mdc_file.name}\n\n{content}\n\n"
except Exception as e:
logger.debug("Could not read %s: %s", mdc_file, e)
if not cursorrules_content:
return ""
return _truncate_content(cursorrules_content, ".cursorrules")
def build_context_files_prompt(cwd: Optional[str] = None, skip_soul: bool = False) -> str:
"""Discover and load context files for the system prompt.
Priority (first found wins only ONE project context type is loaded):
1. .hermes.md / HERMES.md (walk to git root)
2. AGENTS.md / agents.md (recursive directory walk)
3. CLAUDE.md / claude.md (cwd only)
4. .cursorrules / .cursor/rules/*.mdc (cwd only)
SOUL.md from HERMES_HOME is independent and always included when present.
Each context source is capped at 20,000 chars.
When *skip_soul* is True, SOUL.md is not included here (it was already
loaded via ``load_soul_md()`` for the identity slot).
"""
if cwd is None:
cwd = os.getcwd()
cwd_path = Path(cwd).resolve()
sections = []
# Priority-based project context: first match wins
project_context = (
_load_hermes_md(cwd_path)
or _load_agents_md(cwd_path)
or _load_claude_md(cwd_path)
or _load_cursorrules(cwd_path)
)
if project_context:
sections.append(project_context)
# SOUL.md from HERMES_HOME only — skip when already loaded as identity
if not skip_soul:
soul_content = load_soul_md()
if soul_content:
sections.append(soul_content)
if not sections:
return ""
return "# Project Context\n\nThe following project context files have been loaded and should be followed:\n\n" + "\n".join(sections)

View file

@ -0,0 +1,72 @@
"""Anthropic prompt caching (system_and_3 strategy).
Reduces input token costs by ~75% on multi-turn conversations by caching
the conversation prefix. Uses 4 cache_control breakpoints (Anthropic max):
1. System prompt (stable across all turns)
2-4. Last 3 non-system messages (rolling window)
Pure functions -- no class state, no AIAgent dependency.
"""
import copy
from typing import Any, Dict, List
def _apply_cache_marker(msg: dict, cache_marker: dict, native_anthropic: bool = False) -> None:
"""Add cache_control to a single message, handling all format variations."""
role = msg.get("role", "")
content = msg.get("content")
if role == "tool":
if native_anthropic:
msg["cache_control"] = cache_marker
return
if content is None or content == "":
msg["cache_control"] = cache_marker
return
if isinstance(content, str):
msg["content"] = [
{"type": "text", "text": content, "cache_control": cache_marker}
]
return
if isinstance(content, list) and content:
last = content[-1]
if isinstance(last, dict):
last["cache_control"] = cache_marker
def apply_anthropic_cache_control(
api_messages: List[Dict[str, Any]],
cache_ttl: str = "5m",
native_anthropic: bool = False,
) -> List[Dict[str, Any]]:
"""Apply system_and_3 caching strategy to messages for Anthropic models.
Places up to 4 cache_control breakpoints: system prompt + last 3 non-system messages.
Returns:
Deep copy of messages with cache_control breakpoints injected.
"""
messages = copy.deepcopy(api_messages)
if not messages:
return messages
marker = {"type": "ephemeral"}
if cache_ttl == "1h":
marker["ttl"] = "1h"
breakpoints_used = 0
if messages[0].get("role") == "system":
_apply_cache_marker(messages[0], marker, native_anthropic=native_anthropic)
breakpoints_used += 1
remaining = 4 - breakpoints_used
non_sys = [i for i in range(len(messages)) if messages[i].get("role") != "system"]
for idx in non_sys[-remaining:]:
_apply_cache_marker(messages[idx], marker, native_anthropic=native_anthropic)
return messages

165
hermes_code/agent/redact.py Normal file
View file

@ -0,0 +1,165 @@
"""Regex-based secret redaction for logs and tool output.
Applies pattern matching to mask API keys, tokens, and credentials
before they reach log files, verbose output, or gateway logs.
Short tokens (< 18 chars) are fully masked. Longer tokens preserve
the first 6 and last 4 characters for debuggability.
"""
import logging
import os
import re
logger = logging.getLogger(__name__)
# Known API key prefixes -- match the prefix + contiguous token chars
_PREFIX_PATTERNS = [
r"sk-[A-Za-z0-9_-]{10,}", # OpenAI / OpenRouter / Anthropic (sk-ant-*)
r"ghp_[A-Za-z0-9]{10,}", # GitHub PAT (classic)
r"github_pat_[A-Za-z0-9_]{10,}", # GitHub PAT (fine-grained)
r"xox[baprs]-[A-Za-z0-9-]{10,}", # Slack tokens
r"AIza[A-Za-z0-9_-]{30,}", # Google API keys
r"pplx-[A-Za-z0-9]{10,}", # Perplexity
r"fal_[A-Za-z0-9_-]{10,}", # Fal.ai
r"fc-[A-Za-z0-9]{10,}", # Firecrawl
r"bb_live_[A-Za-z0-9_-]{10,}", # BrowserBase
r"gAAAA[A-Za-z0-9_=-]{20,}", # Codex encrypted tokens
r"AKIA[A-Z0-9]{16}", # AWS Access Key ID
r"sk_live_[A-Za-z0-9]{10,}", # Stripe secret key (live)
r"sk_test_[A-Za-z0-9]{10,}", # Stripe secret key (test)
r"rk_live_[A-Za-z0-9]{10,}", # Stripe restricted key
r"SG\.[A-Za-z0-9_-]{10,}", # SendGrid API key
r"hf_[A-Za-z0-9]{10,}", # HuggingFace token
r"r8_[A-Za-z0-9]{10,}", # Replicate API token
r"npm_[A-Za-z0-9]{10,}", # npm access token
r"pypi-[A-Za-z0-9_-]{10,}", # PyPI API token
r"dop_v1_[A-Za-z0-9]{10,}", # DigitalOcean PAT
r"doo_v1_[A-Za-z0-9]{10,}", # DigitalOcean OAuth
r"am_[A-Za-z0-9_-]{10,}", # AgentMail API key
]
# ENV assignment patterns: KEY=value where KEY contains a secret-like name
_SECRET_ENV_NAMES = r"(?:API_?KEY|TOKEN|SECRET|PASSWORD|PASSWD|CREDENTIAL|AUTH)"
_ENV_ASSIGN_RE = re.compile(
rf"([A-Z_]*{_SECRET_ENV_NAMES}[A-Z_]*)\s*=\s*(['\"]?)(\S+)\2",
re.IGNORECASE,
)
# JSON field patterns: "apiKey": "value", "token": "value", etc.
_JSON_KEY_NAMES = r"(?:api_?[Kk]ey|token|secret|password|access_token|refresh_token|auth_token|bearer|secret_value|raw_secret|secret_input|key_material)"
_JSON_FIELD_RE = re.compile(
rf'("{_JSON_KEY_NAMES}")\s*:\s*"([^"]+)"',
re.IGNORECASE,
)
# Authorization headers
_AUTH_HEADER_RE = re.compile(
r"(Authorization:\s*Bearer\s+)(\S+)",
re.IGNORECASE,
)
# Telegram bot tokens: bot<digits>:<token> or <digits>:<token>,
# where token part is restricted to [-A-Za-z0-9_] and length >= 30
_TELEGRAM_RE = re.compile(
r"(bot)?(\d{8,}):([-A-Za-z0-9_]{30,})",
)
# Private key blocks: -----BEGIN RSA PRIVATE KEY----- ... -----END RSA PRIVATE KEY-----
_PRIVATE_KEY_RE = re.compile(
r"-----BEGIN[A-Z ]*PRIVATE KEY-----[\s\S]*?-----END[A-Z ]*PRIVATE KEY-----"
)
# Database connection strings: protocol://user:PASSWORD@host
# Catches postgres, mysql, mongodb, redis, amqp URLs and redacts the password
_DB_CONNSTR_RE = re.compile(
r"((?:postgres(?:ql)?|mysql|mongodb(?:\+srv)?|redis|amqp)://[^:]+:)([^@]+)(@)",
re.IGNORECASE,
)
# E.164 phone numbers: +<country><number>, 7-15 digits
# Negative lookahead prevents matching hex strings or identifiers
_SIGNAL_PHONE_RE = re.compile(r"(\+[1-9]\d{6,14})(?![A-Za-z0-9])")
# Compile known prefix patterns into one alternation
_PREFIX_RE = re.compile(
r"(?<![A-Za-z0-9_-])(" + "|".join(_PREFIX_PATTERNS) + r")(?![A-Za-z0-9_-])"
)
def _mask_token(token: str) -> str:
"""Mask a token, preserving prefix for long tokens."""
if len(token) < 18:
return "***"
return f"{token[:6]}...{token[-4:]}"
def redact_sensitive_text(text: str) -> str:
"""Apply all redaction patterns to a block of text.
Safe to call on any string -- non-matching text passes through unchanged.
Disabled when security.redact_secrets is false in config.yaml.
"""
if text is None:
return None
if not isinstance(text, str):
text = str(text)
if not text:
return text
if os.getenv("HERMES_REDACT_SECRETS", "").lower() in ("0", "false", "no", "off"):
return text
# Known prefixes (sk-, ghp_, etc.)
text = _PREFIX_RE.sub(lambda m: _mask_token(m.group(1)), text)
# ENV assignments: OPENAI_API_KEY=sk-abc...
def _redact_env(m):
name, quote, value = m.group(1), m.group(2), m.group(3)
return f"{name}={quote}{_mask_token(value)}{quote}"
text = _ENV_ASSIGN_RE.sub(_redact_env, text)
# JSON fields: "apiKey": "value"
def _redact_json(m):
key, value = m.group(1), m.group(2)
return f'{key}: "{_mask_token(value)}"'
text = _JSON_FIELD_RE.sub(_redact_json, text)
# Authorization headers
text = _AUTH_HEADER_RE.sub(
lambda m: m.group(1) + _mask_token(m.group(2)),
text,
)
# Telegram bot tokens
def _redact_telegram(m):
prefix = m.group(1) or ""
digits = m.group(2)
return f"{prefix}{digits}:***"
text = _TELEGRAM_RE.sub(_redact_telegram, text)
# Private key blocks
text = _PRIVATE_KEY_RE.sub("[REDACTED PRIVATE KEY]", text)
# Database connection string passwords
text = _DB_CONNSTR_RE.sub(lambda m: f"{m.group(1)}***{m.group(3)}", text)
# E.164 phone numbers (Signal, WhatsApp)
def _redact_phone(m):
phone = m.group(1)
if len(phone) <= 8:
return phone[:2] + "****" + phone[-2:]
return phone[:4] + "****" + phone[-4:]
text = _SIGNAL_PHONE_RE.sub(_redact_phone, text)
return text
class RedactingFormatter(logging.Formatter):
"""Log formatter that redacts secrets from all log messages."""
def __init__(self, fmt=None, datefmt=None, style='%', **kwargs):
super().__init__(fmt, datefmt, style, **kwargs)
def format(self, record: logging.LogRecord) -> str:
original = super().format(record)
return redact_sensitive_text(original)

View file

@ -0,0 +1,282 @@
"""Shared slash command helpers for skills and built-in prompt-style modes.
Shared between CLI (cli.py) and gateway (gateway/run.py) so both surfaces
can invoke skills via /skill-name commands and prompt-only built-ins like
/plan.
"""
import json
import logging
import re
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional
logger = logging.getLogger(__name__)
_skill_commands: Dict[str, Dict[str, Any]] = {}
_PLAN_SLUG_RE = re.compile(r"[^a-z0-9]+")
def build_plan_path(
user_instruction: str = "",
*,
now: datetime | None = None,
) -> Path:
"""Return the default workspace-relative markdown path for a /plan invocation.
Relative paths are intentional: file tools are task/backend-aware and resolve
them against the active working directory for local, docker, ssh, modal,
daytona, and similar terminal backends. That keeps the plan with the active
workspace instead of the Hermes host's global home directory.
"""
slug_source = (user_instruction or "").strip().splitlines()[0] if user_instruction else ""
slug = _PLAN_SLUG_RE.sub("-", slug_source.lower()).strip("-")
if slug:
slug = "-".join(part for part in slug.split("-")[:8] if part)[:48].strip("-")
slug = slug or "conversation-plan"
timestamp = (now or datetime.now()).strftime("%Y-%m-%d_%H%M%S")
return Path(".hermes") / "plans" / f"{timestamp}-{slug}.md"
def _load_skill_payload(skill_identifier: str, task_id: str | None = None) -> tuple[dict[str, Any], Path | None, str] | None:
"""Load a skill by name/path and return (loaded_payload, skill_dir, display_name)."""
raw_identifier = (skill_identifier or "").strip()
if not raw_identifier:
return None
try:
from tools.skills_tool import SKILLS_DIR, skill_view
identifier_path = Path(raw_identifier).expanduser()
if identifier_path.is_absolute():
try:
normalized = str(identifier_path.resolve().relative_to(SKILLS_DIR.resolve()))
except Exception:
normalized = raw_identifier
else:
normalized = raw_identifier.lstrip("/")
loaded_skill = json.loads(skill_view(normalized, task_id=task_id))
except Exception:
return None
if not loaded_skill.get("success"):
return None
skill_name = str(loaded_skill.get("name") or normalized)
skill_path = str(loaded_skill.get("path") or "")
skill_dir = None
if skill_path:
try:
skill_dir = SKILLS_DIR / Path(skill_path).parent
except Exception:
skill_dir = None
return loaded_skill, skill_dir, skill_name
def _build_skill_message(
loaded_skill: dict[str, Any],
skill_dir: Path | None,
activation_note: str,
user_instruction: str = "",
runtime_note: str = "",
) -> str:
"""Format a loaded skill into a user/system message payload."""
from tools.skills_tool import SKILLS_DIR
content = str(loaded_skill.get("content") or "")
parts = [activation_note, "", content.strip()]
if loaded_skill.get("setup_skipped"):
parts.extend(
[
"",
"[Skill setup note: Required environment setup was skipped. Continue loading the skill and explain any reduced functionality if it matters.]",
]
)
elif loaded_skill.get("gateway_setup_hint"):
parts.extend(
[
"",
f"[Skill setup note: {loaded_skill['gateway_setup_hint']}]",
]
)
elif loaded_skill.get("setup_needed") and loaded_skill.get("setup_note"):
parts.extend(
[
"",
f"[Skill setup note: {loaded_skill['setup_note']}]",
]
)
supporting = []
linked_files = loaded_skill.get("linked_files") or {}
for entries in linked_files.values():
if isinstance(entries, list):
supporting.extend(entries)
if not supporting and skill_dir:
for subdir in ("references", "templates", "scripts", "assets"):
subdir_path = skill_dir / subdir
if subdir_path.exists():
for f in sorted(subdir_path.rglob("*")):
if f.is_file():
rel = str(f.relative_to(skill_dir))
supporting.append(rel)
if supporting and skill_dir:
skill_view_target = str(skill_dir.relative_to(SKILLS_DIR))
parts.append("")
parts.append("[This skill has supporting files you can load with the skill_view tool:]")
for sf in supporting:
parts.append(f"- {sf}")
parts.append(
f'\nTo view any of these, use: skill_view(name="{skill_view_target}", file_path="<path>")'
)
if user_instruction:
parts.append("")
parts.append(f"The user has provided the following instruction alongside the skill invocation: {user_instruction}")
if runtime_note:
parts.append("")
parts.append(f"[Runtime note: {runtime_note}]")
return "\n".join(parts)
def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
"""Scan ~/.hermes/skills/ and return a mapping of /command -> skill info.
Returns:
Dict mapping "/skill-name" to {name, description, skill_md_path, skill_dir}.
"""
global _skill_commands
_skill_commands = {}
try:
from tools.skills_tool import SKILLS_DIR, _parse_frontmatter, skill_matches_platform, _get_disabled_skill_names
if not SKILLS_DIR.exists():
return _skill_commands
disabled = _get_disabled_skill_names()
for skill_md in SKILLS_DIR.rglob("SKILL.md"):
if any(part in ('.git', '.github', '.hub') for part in skill_md.parts):
continue
try:
content = skill_md.read_text(encoding='utf-8')
frontmatter, body = _parse_frontmatter(content)
# Skip skills incompatible with the current OS platform
if not skill_matches_platform(frontmatter):
continue
name = frontmatter.get('name', skill_md.parent.name)
# Respect user's disabled skills config
if name in disabled:
continue
description = frontmatter.get('description', '')
if not description:
for line in body.strip().split('\n'):
line = line.strip()
if line and not line.startswith('#'):
description = line[:80]
break
cmd_name = name.lower().replace(' ', '-').replace('_', '-')
_skill_commands[f"/{cmd_name}"] = {
"name": name,
"description": description or f"Invoke the {name} skill",
"skill_md_path": str(skill_md),
"skill_dir": str(skill_md.parent),
}
except Exception:
continue
except Exception:
pass
return _skill_commands
def get_skill_commands() -> Dict[str, Dict[str, Any]]:
"""Return the current skill commands mapping (scan first if empty)."""
if not _skill_commands:
scan_skill_commands()
return _skill_commands
def build_skill_invocation_message(
cmd_key: str,
user_instruction: str = "",
task_id: str | None = None,
runtime_note: str = "",
) -> Optional[str]:
"""Build the user message content for a skill slash command invocation.
Args:
cmd_key: The command key including leading slash (e.g., "/gif-search").
user_instruction: Optional text the user typed after the command.
Returns:
The formatted message string, or None if the skill wasn't found.
"""
commands = get_skill_commands()
skill_info = commands.get(cmd_key)
if not skill_info:
return None
loaded = _load_skill_payload(skill_info["skill_dir"], task_id=task_id)
if not loaded:
return f"[Failed to load skill: {skill_info['name']}]"
loaded_skill, skill_dir, skill_name = loaded
activation_note = (
f'[SYSTEM: The user has invoked the "{skill_name}" skill, indicating they want '
"you to follow its instructions. The full skill content is loaded below.]"
)
return _build_skill_message(
loaded_skill,
skill_dir,
activation_note,
user_instruction=user_instruction,
runtime_note=runtime_note,
)
def build_preloaded_skills_prompt(
skill_identifiers: list[str],
task_id: str | None = None,
) -> tuple[str, list[str], list[str]]:
"""Load one or more skills for session-wide CLI preloading.
Returns (prompt_text, loaded_skill_names, missing_identifiers).
"""
prompt_parts: list[str] = []
loaded_names: list[str] = []
missing: list[str] = []
seen: set[str] = set()
for raw_identifier in skill_identifiers:
identifier = (raw_identifier or "").strip()
if not identifier or identifier in seen:
continue
seen.add(identifier)
loaded = _load_skill_payload(identifier, task_id=task_id)
if not loaded:
missing.append(identifier)
continue
loaded_skill, skill_dir, skill_name = loaded
activation_note = (
f'[SYSTEM: The user launched this CLI session with the "{skill_name}" skill '
"preloaded. Treat its instructions as active guidance for the duration of this "
"session unless the user overrides them.]"
)
prompt_parts.append(
_build_skill_message(
loaded_skill,
skill_dir,
activation_note,
)
)
loaded_names.append(skill_name)
return "\n\n".join(prompt_parts), loaded_names, missing

View file

@ -0,0 +1,196 @@
"""Helpers for optional cheap-vs-strong model routing."""
from __future__ import annotations
import os
import re
from typing import Any, Dict, Optional
_COMPLEX_KEYWORDS = {
"debug",
"debugging",
"implement",
"implementation",
"refactor",
"patch",
"traceback",
"stacktrace",
"exception",
"error",
"analyze",
"analysis",
"investigate",
"architecture",
"design",
"compare",
"benchmark",
"optimize",
"optimise",
"review",
"terminal",
"shell",
"tool",
"tools",
"pytest",
"test",
"tests",
"plan",
"planning",
"delegate",
"subagent",
"cron",
"docker",
"kubernetes",
}
_URL_RE = re.compile(r"https?://|www\.", re.IGNORECASE)
def _coerce_bool(value: Any, default: bool = False) -> bool:
if value is None:
return default
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.strip().lower() in {"1", "true", "yes", "on"}
return bool(value)
def _coerce_int(value: Any, default: int) -> int:
try:
return int(value)
except (TypeError, ValueError):
return default
def choose_cheap_model_route(user_message: str, routing_config: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""Return the configured cheap-model route when a message looks simple.
Conservative by design: if the message has signs of code/tool/debugging/
long-form work, keep the primary model.
"""
cfg = routing_config or {}
if not _coerce_bool(cfg.get("enabled"), False):
return None
cheap_model = cfg.get("cheap_model") or {}
if not isinstance(cheap_model, dict):
return None
provider = str(cheap_model.get("provider") or "").strip().lower()
model = str(cheap_model.get("model") or "").strip()
if not provider or not model:
return None
text = (user_message or "").strip()
if not text:
return None
max_chars = _coerce_int(cfg.get("max_simple_chars"), 160)
max_words = _coerce_int(cfg.get("max_simple_words"), 28)
if len(text) > max_chars:
return None
if len(text.split()) > max_words:
return None
if text.count("\n") > 1:
return None
if "```" in text or "`" in text:
return None
if _URL_RE.search(text):
return None
lowered = text.lower()
words = {token.strip(".,:;!?()[]{}\"'`") for token in lowered.split()}
if words & _COMPLEX_KEYWORDS:
return None
route = dict(cheap_model)
route["provider"] = provider
route["model"] = model
route["routing_reason"] = "simple_turn"
return route
def resolve_turn_route(user_message: str, routing_config: Optional[Dict[str, Any]], primary: Dict[str, Any]) -> Dict[str, Any]:
"""Resolve the effective model/runtime for one turn.
Returns a dict with model/runtime/signature/label fields.
"""
route = choose_cheap_model_route(user_message, routing_config)
if not route:
return {
"model": primary.get("model"),
"runtime": {
"api_key": primary.get("api_key"),
"base_url": primary.get("base_url"),
"provider": primary.get("provider"),
"api_mode": primary.get("api_mode"),
"command": primary.get("command"),
"args": list(primary.get("args") or []),
},
"label": None,
"signature": (
primary.get("model"),
primary.get("provider"),
primary.get("base_url"),
primary.get("api_mode"),
primary.get("command"),
tuple(primary.get("args") or ()),
),
}
from hermes_cli.runtime_provider import resolve_runtime_provider
explicit_api_key = None
api_key_env = str(route.get("api_key_env") or "").strip()
if api_key_env:
explicit_api_key = os.getenv(api_key_env) or None
try:
runtime = resolve_runtime_provider(
requested=route.get("provider"),
explicit_api_key=explicit_api_key,
explicit_base_url=route.get("base_url"),
)
except Exception:
return {
"model": primary.get("model"),
"runtime": {
"api_key": primary.get("api_key"),
"base_url": primary.get("base_url"),
"provider": primary.get("provider"),
"api_mode": primary.get("api_mode"),
"command": primary.get("command"),
"args": list(primary.get("args") or []),
},
"label": None,
"signature": (
primary.get("model"),
primary.get("provider"),
primary.get("base_url"),
primary.get("api_mode"),
primary.get("command"),
tuple(primary.get("args") or ()),
),
}
return {
"model": route.get("model"),
"runtime": {
"api_key": runtime.get("api_key"),
"base_url": runtime.get("base_url"),
"provider": runtime.get("provider"),
"api_mode": runtime.get("api_mode"),
"command": runtime.get("command"),
"args": list(runtime.get("args") or []),
},
"label": f"smart route → {route.get('model')} ({runtime.get('provider')})",
"signature": (
route.get("model"),
runtime.get("provider"),
runtime.get("base_url"),
runtime.get("api_mode"),
runtime.get("command"),
tuple(runtime.get("args") or ()),
),
}

View file

@ -0,0 +1,125 @@
"""Auto-generate short session titles from the first user/assistant exchange.
Runs asynchronously after the first response is delivered so it never
adds latency to the user-facing reply.
"""
import logging
import threading
from typing import Optional
from agent.auxiliary_client import call_llm
logger = logging.getLogger(__name__)
_TITLE_PROMPT = (
"Generate a short, descriptive title (3-7 words) for a conversation that starts with the "
"following exchange. The title should capture the main topic or intent. "
"Return ONLY the title text, nothing else. No quotes, no punctuation at the end, no prefixes."
)
def generate_title(user_message: str, assistant_response: str, timeout: float = 15.0) -> Optional[str]:
"""Generate a session title from the first exchange.
Uses the auxiliary LLM client (cheapest/fastest available model).
Returns the title string or None on failure.
"""
# Truncate long messages to keep the request small
user_snippet = user_message[:500] if user_message else ""
assistant_snippet = assistant_response[:500] if assistant_response else ""
messages = [
{"role": "system", "content": _TITLE_PROMPT},
{"role": "user", "content": f"User: {user_snippet}\n\nAssistant: {assistant_snippet}"},
]
try:
response = call_llm(
task="compression", # reuse compression task config (cheap/fast model)
messages=messages,
max_tokens=30,
temperature=0.3,
timeout=timeout,
)
title = (response.choices[0].message.content or "").strip()
# Clean up: remove quotes, trailing punctuation, prefixes like "Title: "
title = title.strip('"\'')
if title.lower().startswith("title:"):
title = title[6:].strip()
# Enforce reasonable length
if len(title) > 80:
title = title[:77] + "..."
return title if title else None
except Exception as e:
logger.debug("Title generation failed: %s", e)
return None
def auto_title_session(
session_db,
session_id: str,
user_message: str,
assistant_response: str,
) -> None:
"""Generate and set a session title if one doesn't already exist.
Called in a background thread after the first exchange completes.
Silently skips if:
- session_db is None
- session already has a title (user-set or previously auto-generated)
- title generation fails
"""
if not session_db or not session_id:
return
# Check if title already exists (user may have set one via /title before first response)
try:
existing = session_db.get_session_title(session_id)
if existing:
return
except Exception:
return
title = generate_title(user_message, assistant_response)
if not title:
return
try:
session_db.set_session_title(session_id, title)
logger.debug("Auto-generated session title: %s", title)
except Exception as e:
logger.debug("Failed to set auto-generated title: %s", e)
def maybe_auto_title(
session_db,
session_id: str,
user_message: str,
assistant_response: str,
conversation_history: list,
) -> None:
"""Fire-and-forget title generation after the first exchange.
Only generates a title when:
- This appears to be the first userassistant exchange
- No title is already set
"""
if not session_db or not session_id or not user_message or not assistant_response:
return
# Count user messages in history to detect first exchange.
# conversation_history includes the exchange that just happened,
# so for a first exchange we expect exactly 1 user message
# (or 2 counting system). Be generous: generate on first 2 exchanges.
user_msg_count = sum(1 for m in (conversation_history or []) if m.get("role") == "user")
if user_msg_count > 2:
return
thread = threading.Thread(
target=auto_title_session,
args=(session_db, session_id, user_message, assistant_response),
daemon=True,
name="auto-title",
)
thread.start()

View file

@ -0,0 +1,56 @@
"""Trajectory saving utilities and static helpers.
_convert_to_trajectory_format stays as an AIAgent method (batch_runner.py
calls agent._convert_to_trajectory_format). Only the static helpers and
the file-write logic live here.
"""
import json
import logging
from datetime import datetime
from typing import Any, Dict, List
logger = logging.getLogger(__name__)
def convert_scratchpad_to_think(content: str) -> str:
"""Convert <REASONING_SCRATCHPAD> tags to <think> tags."""
if not content or "<REASONING_SCRATCHPAD>" not in content:
return content
return content.replace("<REASONING_SCRATCHPAD>", "<think>").replace("</REASONING_SCRATCHPAD>", "</think>")
def has_incomplete_scratchpad(content: str) -> bool:
"""Check if content has an opening <REASONING_SCRATCHPAD> without a closing tag."""
if not content:
return False
return "<REASONING_SCRATCHPAD>" in content and "</REASONING_SCRATCHPAD>" not in content
def save_trajectory(trajectory: List[Dict[str, Any]], model: str,
completed: bool, filename: str = None):
"""Append a trajectory entry to a JSONL file.
Args:
trajectory: The ShareGPT-format conversation list.
model: Model name for metadata.
completed: Whether the conversation completed successfully.
filename: Override output filename. Defaults to trajectory_samples.jsonl
or failed_trajectories.jsonl based on ``completed``.
"""
if filename is None:
filename = "trajectory_samples.jsonl" if completed else "failed_trajectories.jsonl"
entry = {
"conversations": trajectory,
"timestamp": datetime.now().isoformat(),
"model": model,
"completed": completed,
}
try:
with open(filename, "a", encoding="utf-8") as f:
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
logger.info("Trajectory saved to %s", filename)
except Exception as e:
logger.warning("Failed to save trajectory: %s", e)

View file

@ -0,0 +1,655 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
from decimal import Decimal
from typing import Any, Dict, Literal, Optional
from agent.model_metadata import fetch_endpoint_model_metadata, fetch_model_metadata
DEFAULT_PRICING = {"input": 0.0, "output": 0.0}
_ZERO = Decimal("0")
_ONE_MILLION = Decimal("1000000")
CostStatus = Literal["actual", "estimated", "included", "unknown"]
CostSource = Literal[
"provider_cost_api",
"provider_generation_api",
"provider_models_api",
"official_docs_snapshot",
"user_override",
"custom_contract",
"none",
]
@dataclass(frozen=True)
class CanonicalUsage:
input_tokens: int = 0
output_tokens: int = 0
cache_read_tokens: int = 0
cache_write_tokens: int = 0
reasoning_tokens: int = 0
request_count: int = 1
raw_usage: Optional[dict[str, Any]] = None
@property
def prompt_tokens(self) -> int:
return self.input_tokens + self.cache_read_tokens + self.cache_write_tokens
@property
def total_tokens(self) -> int:
return self.prompt_tokens + self.output_tokens
@dataclass(frozen=True)
class BillingRoute:
provider: str
model: str
base_url: str = ""
billing_mode: str = "unknown"
@dataclass(frozen=True)
class PricingEntry:
input_cost_per_million: Optional[Decimal] = None
output_cost_per_million: Optional[Decimal] = None
cache_read_cost_per_million: Optional[Decimal] = None
cache_write_cost_per_million: Optional[Decimal] = None
request_cost: Optional[Decimal] = None
source: CostSource = "none"
source_url: Optional[str] = None
pricing_version: Optional[str] = None
fetched_at: Optional[datetime] = None
@dataclass(frozen=True)
class CostResult:
amount_usd: Optional[Decimal]
status: CostStatus
source: CostSource
label: str
fetched_at: Optional[datetime] = None
pricing_version: Optional[str] = None
notes: tuple[str, ...] = ()
_UTC_NOW = lambda: datetime.now(timezone.utc)
# Official docs snapshot entries. Models whose published pricing and cache
# semantics are stable enough to encode exactly.
_OFFICIAL_DOCS_PRICING: Dict[tuple[str, str], PricingEntry] = {
(
"anthropic",
"claude-opus-4-20250514",
): PricingEntry(
input_cost_per_million=Decimal("15.00"),
output_cost_per_million=Decimal("75.00"),
cache_read_cost_per_million=Decimal("1.50"),
cache_write_cost_per_million=Decimal("18.75"),
source="official_docs_snapshot",
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
pricing_version="anthropic-prompt-caching-2026-03-16",
),
(
"anthropic",
"claude-sonnet-4-20250514",
): PricingEntry(
input_cost_per_million=Decimal("3.00"),
output_cost_per_million=Decimal("15.00"),
cache_read_cost_per_million=Decimal("0.30"),
cache_write_cost_per_million=Decimal("3.75"),
source="official_docs_snapshot",
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
pricing_version="anthropic-prompt-caching-2026-03-16",
),
# OpenAI
(
"openai",
"gpt-4o",
): PricingEntry(
input_cost_per_million=Decimal("2.50"),
output_cost_per_million=Decimal("10.00"),
cache_read_cost_per_million=Decimal("1.25"),
source="official_docs_snapshot",
source_url="https://openai.com/api/pricing/",
pricing_version="openai-pricing-2026-03-16",
),
(
"openai",
"gpt-4o-mini",
): PricingEntry(
input_cost_per_million=Decimal("0.15"),
output_cost_per_million=Decimal("0.60"),
cache_read_cost_per_million=Decimal("0.075"),
source="official_docs_snapshot",
source_url="https://openai.com/api/pricing/",
pricing_version="openai-pricing-2026-03-16",
),
(
"openai",
"gpt-4.1",
): PricingEntry(
input_cost_per_million=Decimal("2.00"),
output_cost_per_million=Decimal("8.00"),
cache_read_cost_per_million=Decimal("0.50"),
source="official_docs_snapshot",
source_url="https://openai.com/api/pricing/",
pricing_version="openai-pricing-2026-03-16",
),
(
"openai",
"gpt-4.1-mini",
): PricingEntry(
input_cost_per_million=Decimal("0.40"),
output_cost_per_million=Decimal("1.60"),
cache_read_cost_per_million=Decimal("0.10"),
source="official_docs_snapshot",
source_url="https://openai.com/api/pricing/",
pricing_version="openai-pricing-2026-03-16",
),
(
"openai",
"gpt-4.1-nano",
): PricingEntry(
input_cost_per_million=Decimal("0.10"),
output_cost_per_million=Decimal("0.40"),
cache_read_cost_per_million=Decimal("0.025"),
source="official_docs_snapshot",
source_url="https://openai.com/api/pricing/",
pricing_version="openai-pricing-2026-03-16",
),
(
"openai",
"o3",
): PricingEntry(
input_cost_per_million=Decimal("10.00"),
output_cost_per_million=Decimal("40.00"),
cache_read_cost_per_million=Decimal("2.50"),
source="official_docs_snapshot",
source_url="https://openai.com/api/pricing/",
pricing_version="openai-pricing-2026-03-16",
),
(
"openai",
"o3-mini",
): PricingEntry(
input_cost_per_million=Decimal("1.10"),
output_cost_per_million=Decimal("4.40"),
cache_read_cost_per_million=Decimal("0.55"),
source="official_docs_snapshot",
source_url="https://openai.com/api/pricing/",
pricing_version="openai-pricing-2026-03-16",
),
# Anthropic older models (pre-4.6 generation)
(
"anthropic",
"claude-3-5-sonnet-20241022",
): PricingEntry(
input_cost_per_million=Decimal("3.00"),
output_cost_per_million=Decimal("15.00"),
cache_read_cost_per_million=Decimal("0.30"),
cache_write_cost_per_million=Decimal("3.75"),
source="official_docs_snapshot",
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
pricing_version="anthropic-pricing-2026-03-16",
),
(
"anthropic",
"claude-3-5-haiku-20241022",
): PricingEntry(
input_cost_per_million=Decimal("0.80"),
output_cost_per_million=Decimal("4.00"),
cache_read_cost_per_million=Decimal("0.08"),
cache_write_cost_per_million=Decimal("1.00"),
source="official_docs_snapshot",
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
pricing_version="anthropic-pricing-2026-03-16",
),
(
"anthropic",
"claude-3-opus-20240229",
): PricingEntry(
input_cost_per_million=Decimal("15.00"),
output_cost_per_million=Decimal("75.00"),
cache_read_cost_per_million=Decimal("1.50"),
cache_write_cost_per_million=Decimal("18.75"),
source="official_docs_snapshot",
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
pricing_version="anthropic-pricing-2026-03-16",
),
(
"anthropic",
"claude-3-haiku-20240307",
): PricingEntry(
input_cost_per_million=Decimal("0.25"),
output_cost_per_million=Decimal("1.25"),
cache_read_cost_per_million=Decimal("0.03"),
cache_write_cost_per_million=Decimal("0.30"),
source="official_docs_snapshot",
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
pricing_version="anthropic-pricing-2026-03-16",
),
# DeepSeek
(
"deepseek",
"deepseek-chat",
): PricingEntry(
input_cost_per_million=Decimal("0.14"),
output_cost_per_million=Decimal("0.28"),
source="official_docs_snapshot",
source_url="https://api-docs.deepseek.com/quick_start/pricing",
pricing_version="deepseek-pricing-2026-03-16",
),
(
"deepseek",
"deepseek-reasoner",
): PricingEntry(
input_cost_per_million=Decimal("0.55"),
output_cost_per_million=Decimal("2.19"),
source="official_docs_snapshot",
source_url="https://api-docs.deepseek.com/quick_start/pricing",
pricing_version="deepseek-pricing-2026-03-16",
),
# Google Gemini
(
"google",
"gemini-2.5-pro",
): PricingEntry(
input_cost_per_million=Decimal("1.25"),
output_cost_per_million=Decimal("10.00"),
source="official_docs_snapshot",
source_url="https://ai.google.dev/pricing",
pricing_version="google-pricing-2026-03-16",
),
(
"google",
"gemini-2.5-flash",
): PricingEntry(
input_cost_per_million=Decimal("0.15"),
output_cost_per_million=Decimal("0.60"),
source="official_docs_snapshot",
source_url="https://ai.google.dev/pricing",
pricing_version="google-pricing-2026-03-16",
),
(
"google",
"gemini-2.0-flash",
): PricingEntry(
input_cost_per_million=Decimal("0.10"),
output_cost_per_million=Decimal("0.40"),
source="official_docs_snapshot",
source_url="https://ai.google.dev/pricing",
pricing_version="google-pricing-2026-03-16",
),
}
def _to_decimal(value: Any) -> Optional[Decimal]:
if value is None:
return None
try:
return Decimal(str(value))
except Exception:
return None
def _to_int(value: Any) -> int:
try:
return int(value or 0)
except Exception:
return 0
def resolve_billing_route(
model_name: str,
provider: Optional[str] = None,
base_url: Optional[str] = None,
) -> BillingRoute:
provider_name = (provider or "").strip().lower()
base = (base_url or "").strip().lower()
model = (model_name or "").strip()
if not provider_name and "/" in model:
inferred_provider, bare_model = model.split("/", 1)
if inferred_provider in {"anthropic", "openai", "google"}:
provider_name = inferred_provider
model = bare_model
if provider_name == "openai-codex":
return BillingRoute(provider="openai-codex", model=model, base_url=base_url or "", billing_mode="subscription_included")
if provider_name == "openrouter" or "openrouter.ai" in base:
return BillingRoute(provider="openrouter", model=model, base_url=base_url or "", billing_mode="official_models_api")
if provider_name == "anthropic":
return BillingRoute(provider="anthropic", model=model.split("/")[-1], base_url=base_url or "", billing_mode="official_docs_snapshot")
if provider_name == "openai":
return BillingRoute(provider="openai", model=model.split("/")[-1], base_url=base_url or "", billing_mode="official_docs_snapshot")
if provider_name in {"custom", "local"} or (base and "localhost" in base):
return BillingRoute(provider=provider_name or "custom", model=model, base_url=base_url or "", billing_mode="unknown")
return BillingRoute(provider=provider_name or "unknown", model=model.split("/")[-1] if model else "", base_url=base_url or "", billing_mode="unknown")
def _lookup_official_docs_pricing(route: BillingRoute) -> Optional[PricingEntry]:
return _OFFICIAL_DOCS_PRICING.get((route.provider, route.model.lower()))
def _openrouter_pricing_entry(route: BillingRoute) -> Optional[PricingEntry]:
return _pricing_entry_from_metadata(
fetch_model_metadata(),
route.model,
source_url="https://openrouter.ai/docs/api/api-reference/models/get-models",
pricing_version="openrouter-models-api",
)
def _pricing_entry_from_metadata(
metadata: Dict[str, Dict[str, Any]],
model_id: str,
*,
source_url: str,
pricing_version: str,
) -> Optional[PricingEntry]:
if model_id not in metadata:
return None
pricing = metadata[model_id].get("pricing") or {}
prompt = _to_decimal(pricing.get("prompt"))
completion = _to_decimal(pricing.get("completion"))
request = _to_decimal(pricing.get("request"))
cache_read = _to_decimal(
pricing.get("cache_read")
or pricing.get("cached_prompt")
or pricing.get("input_cache_read")
)
cache_write = _to_decimal(
pricing.get("cache_write")
or pricing.get("cache_creation")
or pricing.get("input_cache_write")
)
if prompt is None and completion is None and request is None:
return None
def _per_token_to_per_million(value: Optional[Decimal]) -> Optional[Decimal]:
if value is None:
return None
return value * _ONE_MILLION
return PricingEntry(
input_cost_per_million=_per_token_to_per_million(prompt),
output_cost_per_million=_per_token_to_per_million(completion),
cache_read_cost_per_million=_per_token_to_per_million(cache_read),
cache_write_cost_per_million=_per_token_to_per_million(cache_write),
request_cost=request,
source="provider_models_api",
source_url=source_url,
pricing_version=pricing_version,
fetched_at=_UTC_NOW(),
)
def get_pricing_entry(
model_name: str,
provider: Optional[str] = None,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> Optional[PricingEntry]:
route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
if route.billing_mode == "subscription_included":
return PricingEntry(
input_cost_per_million=_ZERO,
output_cost_per_million=_ZERO,
cache_read_cost_per_million=_ZERO,
cache_write_cost_per_million=_ZERO,
source="none",
pricing_version="included-route",
)
if route.provider == "openrouter":
return _openrouter_pricing_entry(route)
if route.base_url:
entry = _pricing_entry_from_metadata(
fetch_endpoint_model_metadata(route.base_url, api_key=api_key or ""),
route.model,
source_url=f"{route.base_url.rstrip('/')}/models",
pricing_version="openai-compatible-models-api",
)
if entry:
return entry
return _lookup_official_docs_pricing(route)
def normalize_usage(
response_usage: Any,
*,
provider: Optional[str] = None,
api_mode: Optional[str] = None,
) -> CanonicalUsage:
"""Normalize raw API response usage into canonical token buckets.
Handles three API shapes:
- Anthropic: input_tokens/output_tokens/cache_read_input_tokens/cache_creation_input_tokens
- Codex Responses: input_tokens includes cache tokens; input_tokens_details.cached_tokens separates them
- OpenAI Chat Completions: prompt_tokens includes cache tokens; prompt_tokens_details.cached_tokens separates them
In both Codex and OpenAI modes, input_tokens is derived by subtracting cache
tokens from the total the API contract is that input/prompt totals include
cached tokens and the details object breaks them out.
"""
if not response_usage:
return CanonicalUsage()
provider_name = (provider or "").strip().lower()
mode = (api_mode or "").strip().lower()
if mode == "anthropic_messages" or provider_name == "anthropic":
input_tokens = _to_int(getattr(response_usage, "input_tokens", 0))
output_tokens = _to_int(getattr(response_usage, "output_tokens", 0))
cache_read_tokens = _to_int(getattr(response_usage, "cache_read_input_tokens", 0))
cache_write_tokens = _to_int(getattr(response_usage, "cache_creation_input_tokens", 0))
elif mode == "codex_responses":
input_total = _to_int(getattr(response_usage, "input_tokens", 0))
output_tokens = _to_int(getattr(response_usage, "output_tokens", 0))
details = getattr(response_usage, "input_tokens_details", None)
cache_read_tokens = _to_int(getattr(details, "cached_tokens", 0) if details else 0)
cache_write_tokens = _to_int(
getattr(details, "cache_creation_tokens", 0) if details else 0
)
input_tokens = max(0, input_total - cache_read_tokens - cache_write_tokens)
else:
prompt_total = _to_int(getattr(response_usage, "prompt_tokens", 0))
output_tokens = _to_int(getattr(response_usage, "completion_tokens", 0))
details = getattr(response_usage, "prompt_tokens_details", None)
cache_read_tokens = _to_int(getattr(details, "cached_tokens", 0) if details else 0)
cache_write_tokens = _to_int(
getattr(details, "cache_write_tokens", 0) if details else 0
)
input_tokens = max(0, prompt_total - cache_read_tokens - cache_write_tokens)
reasoning_tokens = 0
output_details = getattr(response_usage, "output_tokens_details", None)
if output_details:
reasoning_tokens = _to_int(getattr(output_details, "reasoning_tokens", 0))
return CanonicalUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_read_tokens=cache_read_tokens,
cache_write_tokens=cache_write_tokens,
reasoning_tokens=reasoning_tokens,
)
def estimate_usage_cost(
model_name: str,
usage: CanonicalUsage,
*,
provider: Optional[str] = None,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> CostResult:
route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
if route.billing_mode == "subscription_included":
return CostResult(
amount_usd=_ZERO,
status="included",
source="none",
label="included",
pricing_version="included-route",
)
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key)
if not entry:
return CostResult(amount_usd=None, status="unknown", source="none", label="n/a")
notes: list[str] = []
amount = _ZERO
if usage.input_tokens and entry.input_cost_per_million is None:
return CostResult(amount_usd=None, status="unknown", source=entry.source, label="n/a")
if usage.output_tokens and entry.output_cost_per_million is None:
return CostResult(amount_usd=None, status="unknown", source=entry.source, label="n/a")
if usage.cache_read_tokens:
if entry.cache_read_cost_per_million is None:
return CostResult(
amount_usd=None,
status="unknown",
source=entry.source,
label="n/a",
notes=("cache-read pricing unavailable for route",),
)
if usage.cache_write_tokens:
if entry.cache_write_cost_per_million is None:
return CostResult(
amount_usd=None,
status="unknown",
source=entry.source,
label="n/a",
notes=("cache-write pricing unavailable for route",),
)
if entry.input_cost_per_million is not None:
amount += Decimal(usage.input_tokens) * entry.input_cost_per_million / _ONE_MILLION
if entry.output_cost_per_million is not None:
amount += Decimal(usage.output_tokens) * entry.output_cost_per_million / _ONE_MILLION
if entry.cache_read_cost_per_million is not None:
amount += Decimal(usage.cache_read_tokens) * entry.cache_read_cost_per_million / _ONE_MILLION
if entry.cache_write_cost_per_million is not None:
amount += Decimal(usage.cache_write_tokens) * entry.cache_write_cost_per_million / _ONE_MILLION
if entry.request_cost is not None and usage.request_count:
amount += Decimal(usage.request_count) * entry.request_cost
status: CostStatus = "estimated"
label = f"~${amount:.2f}"
if entry.source == "none" and amount == _ZERO:
status = "included"
label = "included"
if route.provider == "openrouter":
notes.append("OpenRouter cost is estimated from the models API until reconciled.")
return CostResult(
amount_usd=amount,
status=status,
source=entry.source,
label=label,
fetched_at=entry.fetched_at,
pricing_version=entry.pricing_version,
notes=tuple(notes),
)
def has_known_pricing(
model_name: str,
provider: Optional[str] = None,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> bool:
"""Check whether we have pricing data for this model+route.
Uses direct lookup instead of routing through the full estimation
pipeline avoids creating dummy usage objects just to check status.
"""
route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
if route.billing_mode == "subscription_included":
return True
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key)
return entry is not None
def get_pricing(
model_name: str,
provider: Optional[str] = None,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> Dict[str, float]:
"""Backward-compatible thin wrapper for legacy callers.
Returns only non-cache input/output fields when a pricing entry exists.
Unknown routes return zeroes.
"""
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key)
if not entry:
return {"input": 0.0, "output": 0.0}
return {
"input": float(entry.input_cost_per_million or _ZERO),
"output": float(entry.output_cost_per_million or _ZERO),
}
def estimate_cost_usd(
model: str,
input_tokens: int,
output_tokens: int,
*,
provider: Optional[str] = None,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> float:
"""Backward-compatible helper for legacy callers.
This uses non-cached input/output only. New code should call
`estimate_usage_cost()` with canonical usage buckets.
"""
result = estimate_usage_cost(
model,
CanonicalUsage(input_tokens=input_tokens, output_tokens=output_tokens),
provider=provider,
base_url=base_url,
api_key=api_key,
)
return float(result.amount_usd or _ZERO)
def format_duration_compact(seconds: float) -> str:
if seconds < 60:
return f"{seconds:.0f}s"
minutes = seconds / 60
if minutes < 60:
return f"{minutes:.0f}m"
hours = minutes / 60
if hours < 24:
remaining_min = int(minutes % 60)
return f"{int(hours)}h {remaining_min}m" if remaining_min else f"{int(hours)}h"
days = hours / 24
return f"{days:.1f}d"
def format_token_count_compact(value: int) -> str:
abs_value = abs(int(value))
if abs_value < 1_000:
return str(int(value))
sign = "-" if value < 0 else ""
units = ((1_000_000_000, "B"), (1_000_000, "M"), (1_000, "K"))
for threshold, suffix in units:
if abs_value >= threshold:
scaled = abs_value / threshold
if scaled < 10:
text = f"{scaled:.2f}"
elif scaled < 100:
text = f"{scaled:.1f}"
else:
text = f"{scaled:.0f}"
text = text.rstrip("0").rstrip(".")
return f"{sign}{text}{suffix}"
return f"{value:,}"