merge: resolve conflicts with origin/main

This commit is contained in:
teknium1 2026-03-17 04:30:37 -07:00
commit 0897e4350e
100 changed files with 11637 additions and 1337 deletions

View file

@ -65,10 +65,15 @@ OPENCODE_GO_API_KEY=
# TOOL API KEYS
# =============================================================================
# Parallel API Key - AI-native web search and extract
# Get at: https://parallel.ai
PARALLEL_API_KEY=
# Firecrawl API Key - Web search, extract, and crawl
# Get at: https://firecrawl.dev/
FIRECRAWL_API_KEY=
# FAL.ai API Key - Image generation
# Get at: https://fal.ai/
FAL_KEY=

View file

@ -44,7 +44,7 @@ hermes-agent/
│ ├── terminal_tool.py # Terminal orchestration
│ ├── process_registry.py # Background process management
│ ├── file_tools.py # File read/write/search/patch
│ ├── web_tools.py # Firecrawl search/extract
│ ├── web_tools.py # Web search/extract (Parallel + Firecrawl)
│ ├── browser_tool.py # Browserbase browser automation
│ ├── code_execution_tool.py # execute_code sandbox
│ ├── delegate_tool.py # Subagent delegation

View file

@ -147,7 +147,7 @@ hermes-agent/
│ ├── approval.py # Dangerous command detection + per-session approval
│ ├── terminal_tool.py # Terminal orchestration (sudo, env lifecycle, backends)
│ ├── file_operations.py # read_file, write_file, search, patch, etc.
│ ├── web_tools.py # web_search, web_extract (Firecrawl + Gemini summarization)
│ ├── web_tools.py # web_search, web_extract (Parallel/Firecrawl + Gemini summarization)
│ ├── vision_tools.py # Image analysis via multimodal models
│ ├── delegate_tool.py # Subagent spawning and parallel task execution
│ ├── code_execution_tool.py # Sandboxed Python with RPC tool access

View file

@ -963,8 +963,12 @@ def convert_messages_to_anthropic(
elif isinstance(prev_blocks, str) and isinstance(curr_blocks, str):
fixed[-1]["content"] = prev_blocks + "\n" + curr_blocks
else:
# Keep the later message
fixed[-1] = m
# Mixed types — normalize both to list and merge
if isinstance(prev_blocks, str):
prev_blocks = [{"type": "text", "text": prev_blocks}]
if isinstance(curr_blocks, str):
curr_blocks = [{"type": "text", "text": curr_blocks}]
fixed[-1]["content"] = prev_blocks + curr_blocks
else:
fixed.append(m)
result = fixed
@ -1049,7 +1053,8 @@ def build_anthropic_kwargs(
elif tool_choice == "required":
kwargs["tool_choice"] = {"type": "any"}
elif tool_choice == "none":
pass # Don't send tool_choice — Anthropic will use tools if needed
# Anthropic has no tool_choice "none" — omit tools entirely to prevent use
kwargs.pop("tools", None)
elif isinstance(tool_choice, str):
# Specific tool name
kwargs["tool_choice"] = {"type": "tool", "name": tool_choice}

View file

@ -39,6 +39,7 @@ custom OpenAI-compatible endpoint without touching the main model settings.
import json
import logging
import os
import threading
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple
@ -705,6 +706,8 @@ def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[st
def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]:
"""Full auto-detection chain: OpenRouter → Nous → custom → Codex → API-key → None."""
global auxiliary_is_nous
auxiliary_is_nous = False # Reset — _try_nous() will set True if it wins
for try_fn in (_try_openrouter, _try_nous, _try_custom_endpoint,
_try_codex, _resolve_api_key_provider):
client, model = try_fn()
@ -1171,6 +1174,7 @@ def auxiliary_max_tokens_param(value: int) -> dict:
# Client cache: (provider, async_mode, base_url, api_key) -> (client, default_model)
_client_cache: Dict[tuple, tuple] = {}
_client_cache_lock = threading.Lock()
def _get_cached_client(
@ -1182,9 +1186,11 @@ def _get_cached_client(
) -> Tuple[Optional[Any], Optional[str]]:
"""Get or create a cached client for the given provider."""
cache_key = (provider, async_mode, base_url or "", api_key or "")
if cache_key in _client_cache:
cached_client, cached_default = _client_cache[cache_key]
return cached_client, model or cached_default
with _client_cache_lock:
if cache_key in _client_cache:
cached_client, cached_default = _client_cache[cache_key]
return cached_client, model or cached_default
# Build outside the lock
client, default_model = resolve_provider_client(
provider,
model,
@ -1193,7 +1199,11 @@ def _get_cached_client(
explicit_api_key=api_key,
)
if client is not None:
_client_cache[cache_key] = (client, default_model)
with _client_cache_lock:
if cache_key not in _client_cache:
_client_cache[cache_key] = (client, default_model)
else:
client, default_model = _client_cache[cache_key]
return client, model or default_model

View file

@ -313,7 +313,19 @@ Write only the summary body. Do not include any preamble or prefix; the system w
if summary:
last_head_role = messages[compress_start - 1].get("role", "user") if compress_start > 0 else "user"
summary_role = "user" if last_head_role in ("assistant", "tool") else "assistant"
first_tail_role = messages[compress_end].get("role", "user") if compress_end < n_messages else "user"
# Pick a role that avoids consecutive same-role with both neighbors.
# Priority: avoid colliding with head (already committed), then tail.
if last_head_role in ("assistant", "tool"):
summary_role = "user"
else:
summary_role = "assistant"
# If the chosen role collides with the tail AND flipping wouldn't
# collide with the head, flip it.
if summary_role == first_tail_role:
flipped = "assistant" if summary_role == "user" else "user"
if flipped != last_head_role:
summary_role = flipped
compressed.append({"role": summary_role, "content": summary})
else:
if not self.quiet_mode:

View file

@ -22,14 +22,21 @@ from collections import Counter, defaultdict
from datetime import datetime
from typing import Any, Dict, List
from agent.usage_pricing import DEFAULT_PRICING, estimate_cost_usd, format_duration_compact, get_pricing, has_known_pricing
from agent.usage_pricing import (
CanonicalUsage,
DEFAULT_PRICING,
estimate_usage_cost,
format_duration_compact,
get_pricing,
has_known_pricing,
)
_DEFAULT_PRICING = DEFAULT_PRICING
def _has_known_pricing(model_name: str) -> bool:
def _has_known_pricing(model_name: str, provider: str = None, base_url: str = None) -> bool:
"""Check if a model has known pricing (vs unknown/custom endpoint)."""
return has_known_pricing(model_name)
return has_known_pricing(model_name, provider=provider, base_url=base_url)
def _get_pricing(model_name: str) -> Dict[str, float]:
@ -41,9 +48,43 @@ def _get_pricing(model_name: str) -> Dict[str, float]:
return get_pricing(model_name)
def _estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
"""Estimate the USD cost for a given model and token counts."""
return estimate_cost_usd(model, input_tokens, output_tokens)
def _estimate_cost(
session_or_model: Dict[str, Any] | str,
input_tokens: int = 0,
output_tokens: int = 0,
*,
cache_read_tokens: int = 0,
cache_write_tokens: int = 0,
provider: str = None,
base_url: str = None,
) -> tuple[float, str]:
"""Estimate the USD cost for a session row or a model/token tuple."""
if isinstance(session_or_model, dict):
session = session_or_model
model = session.get("model") or ""
usage = CanonicalUsage(
input_tokens=session.get("input_tokens") or 0,
output_tokens=session.get("output_tokens") or 0,
cache_read_tokens=session.get("cache_read_tokens") or 0,
cache_write_tokens=session.get("cache_write_tokens") or 0,
)
provider = session.get("billing_provider")
base_url = session.get("billing_base_url")
else:
model = session_or_model or ""
usage = CanonicalUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_read_tokens=cache_read_tokens,
cache_write_tokens=cache_write_tokens,
)
result = estimate_usage_cost(
model,
usage,
provider=provider,
base_url=base_url,
)
return float(result.amount_usd or 0.0), result.status
def _format_duration(seconds: float) -> str:
@ -135,7 +176,10 @@ class InsightsEngine:
# Columns we actually need (skip system_prompt, model_config blobs)
_SESSION_COLS = ("id, source, model, started_at, ended_at, "
"message_count, tool_call_count, input_tokens, output_tokens")
"message_count, tool_call_count, input_tokens, output_tokens, "
"cache_read_tokens, cache_write_tokens, billing_provider, "
"billing_base_url, billing_mode, estimated_cost_usd, "
"actual_cost_usd, cost_status, cost_source")
def _get_sessions(self, cutoff: float, source: str = None) -> List[Dict]:
"""Fetch sessions within the time window."""
@ -287,21 +331,30 @@ class InsightsEngine:
"""Compute high-level overview statistics."""
total_input = sum(s.get("input_tokens") or 0 for s in sessions)
total_output = sum(s.get("output_tokens") or 0 for s in sessions)
total_tokens = total_input + total_output
total_cache_read = sum(s.get("cache_read_tokens") or 0 for s in sessions)
total_cache_write = sum(s.get("cache_write_tokens") or 0 for s in sessions)
total_tokens = total_input + total_output + total_cache_read + total_cache_write
total_tool_calls = sum(s.get("tool_call_count") or 0 for s in sessions)
total_messages = sum(s.get("message_count") or 0 for s in sessions)
# Cost estimation (weighted by model)
total_cost = 0.0
actual_cost = 0.0
models_with_pricing = set()
models_without_pricing = set()
unknown_cost_sessions = 0
included_cost_sessions = 0
for s in sessions:
model = s.get("model") or ""
inp = s.get("input_tokens") or 0
out = s.get("output_tokens") or 0
total_cost += _estimate_cost(model, inp, out)
estimated, status = _estimate_cost(s)
total_cost += estimated
actual_cost += s.get("actual_cost_usd") or 0.0
display = model.split("/")[-1] if "/" in model else (model or "unknown")
if _has_known_pricing(model):
if status == "included":
included_cost_sessions += 1
elif status == "unknown":
unknown_cost_sessions += 1
if _has_known_pricing(model, s.get("billing_provider"), s.get("billing_base_url")):
models_with_pricing.add(display)
else:
models_without_pricing.add(display)
@ -328,8 +381,11 @@ class InsightsEngine:
"total_tool_calls": total_tool_calls,
"total_input_tokens": total_input,
"total_output_tokens": total_output,
"total_cache_read_tokens": total_cache_read,
"total_cache_write_tokens": total_cache_write,
"total_tokens": total_tokens,
"estimated_cost": total_cost,
"actual_cost": actual_cost,
"total_hours": total_hours,
"avg_session_duration": avg_duration,
"avg_messages_per_session": total_messages / len(sessions) if sessions else 0,
@ -341,12 +397,15 @@ class InsightsEngine:
"date_range_end": date_range_end,
"models_with_pricing": sorted(models_with_pricing),
"models_without_pricing": sorted(models_without_pricing),
"unknown_cost_sessions": unknown_cost_sessions,
"included_cost_sessions": included_cost_sessions,
}
def _compute_model_breakdown(self, sessions: List[Dict]) -> List[Dict]:
"""Break down usage by model."""
model_data = defaultdict(lambda: {
"sessions": 0, "input_tokens": 0, "output_tokens": 0,
"cache_read_tokens": 0, "cache_write_tokens": 0,
"total_tokens": 0, "tool_calls": 0, "cost": 0.0,
})
@ -358,12 +417,18 @@ class InsightsEngine:
d["sessions"] += 1
inp = s.get("input_tokens") or 0
out = s.get("output_tokens") or 0
cache_read = s.get("cache_read_tokens") or 0
cache_write = s.get("cache_write_tokens") or 0
d["input_tokens"] += inp
d["output_tokens"] += out
d["total_tokens"] += inp + out
d["cache_read_tokens"] += cache_read
d["cache_write_tokens"] += cache_write
d["total_tokens"] += inp + out + cache_read + cache_write
d["tool_calls"] += s.get("tool_call_count") or 0
d["cost"] += _estimate_cost(model, inp, out)
d["has_pricing"] = _has_known_pricing(model)
estimate, status = _estimate_cost(s)
d["cost"] += estimate
d["has_pricing"] = _has_known_pricing(model, s.get("billing_provider"), s.get("billing_base_url"))
d["cost_status"] = status
result = [
{"model": model, **data}
@ -377,7 +442,8 @@ class InsightsEngine:
"""Break down usage by platform/source."""
platform_data = defaultdict(lambda: {
"sessions": 0, "messages": 0, "input_tokens": 0,
"output_tokens": 0, "total_tokens": 0, "tool_calls": 0,
"output_tokens": 0, "cache_read_tokens": 0,
"cache_write_tokens": 0, "total_tokens": 0, "tool_calls": 0,
})
for s in sessions:
@ -387,9 +453,13 @@ class InsightsEngine:
d["messages"] += s.get("message_count") or 0
inp = s.get("input_tokens") or 0
out = s.get("output_tokens") or 0
cache_read = s.get("cache_read_tokens") or 0
cache_write = s.get("cache_write_tokens") or 0
d["input_tokens"] += inp
d["output_tokens"] += out
d["total_tokens"] += inp + out
d["cache_read_tokens"] += cache_read
d["cache_write_tokens"] += cache_write
d["total_tokens"] += inp + out + cache_read + cache_write
d["tool_calls"] += s.get("tool_call_count") or 0
result = [

View file

@ -266,8 +266,10 @@ def get_model_context_length(model: str, base_url: str = "") -> int:
if model in metadata:
return metadata[model].get("context_length", 128000)
# 3. Hardcoded defaults (fuzzy match)
for default_model, length in DEFAULT_CONTEXT_LENGTHS.items():
# 3. Hardcoded defaults (fuzzy match — longest key first for specificity)
for default_model, length in sorted(
DEFAULT_CONTEXT_LENGTHS.items(), key=lambda x: len(x[0]), reverse=True
):
if default_model in model or model in default_model:
return length

View file

@ -212,16 +212,15 @@ PLATFORM_HINTS = {
"the scheduled destination, put it directly in your final response. Use "
"send_message only for additional or different targets."
),
"sms": (
"You are communicating via SMS text messaging. Keep responses concise "
"and plain text only -- no markdown, no formatting. SMS has a 1600 "
"character limit per message (10 segments). Longer replies are split "
"across multiple messages. Be brief and direct."
),
"cli": (
"You are a CLI AI Agent. Try not to use markdown but simple text "
"renderable inside a terminal."
),
"sms": (
"You are communicating via SMS. Keep responses concise and use plain text "
"only — no markdown, no formatting. SMS messages are limited to ~1600 "
"characters, so be brief and direct."
),
}
CONTEXT_FILE_MAX_CHARS = 20_000

View file

@ -1,101 +1,593 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
from decimal import Decimal
from typing import Dict
from typing import Any, Dict, Literal, Optional
MODEL_PRICING = {
"gpt-4o": {"input": 2.50, "output": 10.00},
"gpt-4o-mini": {"input": 0.15, "output": 0.60},
"gpt-4.1": {"input": 2.00, "output": 8.00},
"gpt-4.1-mini": {"input": 0.40, "output": 1.60},
"gpt-4.1-nano": {"input": 0.10, "output": 0.40},
"gpt-4.5-preview": {"input": 75.00, "output": 150.00},
"gpt-5": {"input": 10.00, "output": 30.00},
"gpt-5.4": {"input": 10.00, "output": 30.00},
"o3": {"input": 10.00, "output": 40.00},
"o3-mini": {"input": 1.10, "output": 4.40},
"o4-mini": {"input": 1.10, "output": 4.40},
"claude-opus-4-20250514": {"input": 15.00, "output": 75.00},
"claude-sonnet-4-20250514": {"input": 3.00, "output": 15.00},
"claude-3-5-sonnet-20241022": {"input": 3.00, "output": 15.00},
"claude-3-5-haiku-20241022": {"input": 0.80, "output": 4.00},
"claude-3-opus-20240229": {"input": 15.00, "output": 75.00},
"claude-3-haiku-20240307": {"input": 0.25, "output": 1.25},
"deepseek-chat": {"input": 0.14, "output": 0.28},
"deepseek-reasoner": {"input": 0.55, "output": 2.19},
"gemini-2.5-pro": {"input": 1.25, "output": 10.00},
"gemini-2.5-flash": {"input": 0.15, "output": 0.60},
"gemini-2.0-flash": {"input": 0.10, "output": 0.40},
"llama-4-maverick": {"input": 0.50, "output": 0.70},
"llama-4-scout": {"input": 0.20, "output": 0.30},
"glm-5": {"input": 0.0, "output": 0.0},
"glm-4.7": {"input": 0.0, "output": 0.0},
"glm-4.5": {"input": 0.0, "output": 0.0},
"glm-4.5-flash": {"input": 0.0, "output": 0.0},
"kimi-k2.5": {"input": 0.0, "output": 0.0},
"kimi-k2-thinking": {"input": 0.0, "output": 0.0},
"kimi-k2-turbo-preview": {"input": 0.0, "output": 0.0},
"kimi-k2-0905-preview": {"input": 0.0, "output": 0.0},
"MiniMax-M2.5": {"input": 0.0, "output": 0.0},
"MiniMax-M2.5-highspeed": {"input": 0.0, "output": 0.0},
"MiniMax-M2.1": {"input": 0.0, "output": 0.0},
}
from agent.model_metadata import fetch_model_metadata
DEFAULT_PRICING = {"input": 0.0, "output": 0.0}
_ZERO = Decimal("0")
_ONE_MILLION = Decimal("1000000")
def get_pricing(model_name: str) -> Dict[str, float]:
if not model_name:
return DEFAULT_PRICING
bare = model_name.split("/")[-1].lower()
if bare in MODEL_PRICING:
return MODEL_PRICING[bare]
best_match = None
best_len = 0
for key, price in MODEL_PRICING.items():
if bare.startswith(key) and len(key) > best_len:
best_match = price
best_len = len(key)
if best_match:
return best_match
if "opus" in bare:
return {"input": 15.00, "output": 75.00}
if "sonnet" in bare:
return {"input": 3.00, "output": 15.00}
if "haiku" in bare:
return {"input": 0.80, "output": 4.00}
if "gpt-4o-mini" in bare:
return {"input": 0.15, "output": 0.60}
if "gpt-4o" in bare:
return {"input": 2.50, "output": 10.00}
if "gpt-5" in bare:
return {"input": 10.00, "output": 30.00}
if "deepseek" in bare:
return {"input": 0.14, "output": 0.28}
if "gemini" in bare:
return {"input": 0.15, "output": 0.60}
return DEFAULT_PRICING
CostStatus = Literal["actual", "estimated", "included", "unknown"]
CostSource = Literal[
"provider_cost_api",
"provider_generation_api",
"provider_models_api",
"official_docs_snapshot",
"user_override",
"custom_contract",
"none",
]
def has_known_pricing(model_name: str) -> bool:
pricing = get_pricing(model_name)
return pricing is not DEFAULT_PRICING and any(
float(value) > 0 for value in pricing.values()
@dataclass(frozen=True)
class CanonicalUsage:
input_tokens: int = 0
output_tokens: int = 0
cache_read_tokens: int = 0
cache_write_tokens: int = 0
reasoning_tokens: int = 0
request_count: int = 1
raw_usage: Optional[dict[str, Any]] = None
@property
def prompt_tokens(self) -> int:
return self.input_tokens + self.cache_read_tokens + self.cache_write_tokens
@property
def total_tokens(self) -> int:
return self.prompt_tokens + self.output_tokens
@dataclass(frozen=True)
class BillingRoute:
provider: str
model: str
base_url: str = ""
billing_mode: str = "unknown"
@dataclass(frozen=True)
class PricingEntry:
input_cost_per_million: Optional[Decimal] = None
output_cost_per_million: Optional[Decimal] = None
cache_read_cost_per_million: Optional[Decimal] = None
cache_write_cost_per_million: Optional[Decimal] = None
request_cost: Optional[Decimal] = None
source: CostSource = "none"
source_url: Optional[str] = None
pricing_version: Optional[str] = None
fetched_at: Optional[datetime] = None
@dataclass(frozen=True)
class CostResult:
amount_usd: Optional[Decimal]
status: CostStatus
source: CostSource
label: str
fetched_at: Optional[datetime] = None
pricing_version: Optional[str] = None
notes: tuple[str, ...] = ()
_UTC_NOW = lambda: datetime.now(timezone.utc)
# Official docs snapshot entries. Models whose published pricing and cache
# semantics are stable enough to encode exactly.
_OFFICIAL_DOCS_PRICING: Dict[tuple[str, str], PricingEntry] = {
(
"anthropic",
"claude-opus-4-20250514",
): PricingEntry(
input_cost_per_million=Decimal("15.00"),
output_cost_per_million=Decimal("75.00"),
cache_read_cost_per_million=Decimal("1.50"),
cache_write_cost_per_million=Decimal("18.75"),
source="official_docs_snapshot",
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
pricing_version="anthropic-prompt-caching-2026-03-16",
),
(
"anthropic",
"claude-sonnet-4-20250514",
): PricingEntry(
input_cost_per_million=Decimal("3.00"),
output_cost_per_million=Decimal("15.00"),
cache_read_cost_per_million=Decimal("0.30"),
cache_write_cost_per_million=Decimal("3.75"),
source="official_docs_snapshot",
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
pricing_version="anthropic-prompt-caching-2026-03-16",
),
# OpenAI
(
"openai",
"gpt-4o",
): PricingEntry(
input_cost_per_million=Decimal("2.50"),
output_cost_per_million=Decimal("10.00"),
cache_read_cost_per_million=Decimal("1.25"),
source="official_docs_snapshot",
source_url="https://openai.com/api/pricing/",
pricing_version="openai-pricing-2026-03-16",
),
(
"openai",
"gpt-4o-mini",
): PricingEntry(
input_cost_per_million=Decimal("0.15"),
output_cost_per_million=Decimal("0.60"),
cache_read_cost_per_million=Decimal("0.075"),
source="official_docs_snapshot",
source_url="https://openai.com/api/pricing/",
pricing_version="openai-pricing-2026-03-16",
),
(
"openai",
"gpt-4.1",
): PricingEntry(
input_cost_per_million=Decimal("2.00"),
output_cost_per_million=Decimal("8.00"),
cache_read_cost_per_million=Decimal("0.50"),
source="official_docs_snapshot",
source_url="https://openai.com/api/pricing/",
pricing_version="openai-pricing-2026-03-16",
),
(
"openai",
"gpt-4.1-mini",
): PricingEntry(
input_cost_per_million=Decimal("0.40"),
output_cost_per_million=Decimal("1.60"),
cache_read_cost_per_million=Decimal("0.10"),
source="official_docs_snapshot",
source_url="https://openai.com/api/pricing/",
pricing_version="openai-pricing-2026-03-16",
),
(
"openai",
"gpt-4.1-nano",
): PricingEntry(
input_cost_per_million=Decimal("0.10"),
output_cost_per_million=Decimal("0.40"),
cache_read_cost_per_million=Decimal("0.025"),
source="official_docs_snapshot",
source_url="https://openai.com/api/pricing/",
pricing_version="openai-pricing-2026-03-16",
),
(
"openai",
"o3",
): PricingEntry(
input_cost_per_million=Decimal("10.00"),
output_cost_per_million=Decimal("40.00"),
cache_read_cost_per_million=Decimal("2.50"),
source="official_docs_snapshot",
source_url="https://openai.com/api/pricing/",
pricing_version="openai-pricing-2026-03-16",
),
(
"openai",
"o3-mini",
): PricingEntry(
input_cost_per_million=Decimal("1.10"),
output_cost_per_million=Decimal("4.40"),
cache_read_cost_per_million=Decimal("0.55"),
source="official_docs_snapshot",
source_url="https://openai.com/api/pricing/",
pricing_version="openai-pricing-2026-03-16",
),
# Anthropic older models (pre-4.6 generation)
(
"anthropic",
"claude-3-5-sonnet-20241022",
): PricingEntry(
input_cost_per_million=Decimal("3.00"),
output_cost_per_million=Decimal("15.00"),
cache_read_cost_per_million=Decimal("0.30"),
cache_write_cost_per_million=Decimal("3.75"),
source="official_docs_snapshot",
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
pricing_version="anthropic-pricing-2026-03-16",
),
(
"anthropic",
"claude-3-5-haiku-20241022",
): PricingEntry(
input_cost_per_million=Decimal("0.80"),
output_cost_per_million=Decimal("4.00"),
cache_read_cost_per_million=Decimal("0.08"),
cache_write_cost_per_million=Decimal("1.00"),
source="official_docs_snapshot",
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
pricing_version="anthropic-pricing-2026-03-16",
),
(
"anthropic",
"claude-3-opus-20240229",
): PricingEntry(
input_cost_per_million=Decimal("15.00"),
output_cost_per_million=Decimal("75.00"),
cache_read_cost_per_million=Decimal("1.50"),
cache_write_cost_per_million=Decimal("18.75"),
source="official_docs_snapshot",
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
pricing_version="anthropic-pricing-2026-03-16",
),
(
"anthropic",
"claude-3-haiku-20240307",
): PricingEntry(
input_cost_per_million=Decimal("0.25"),
output_cost_per_million=Decimal("1.25"),
cache_read_cost_per_million=Decimal("0.03"),
cache_write_cost_per_million=Decimal("0.30"),
source="official_docs_snapshot",
source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching",
pricing_version="anthropic-pricing-2026-03-16",
),
# DeepSeek
(
"deepseek",
"deepseek-chat",
): PricingEntry(
input_cost_per_million=Decimal("0.14"),
output_cost_per_million=Decimal("0.28"),
source="official_docs_snapshot",
source_url="https://api-docs.deepseek.com/quick_start/pricing",
pricing_version="deepseek-pricing-2026-03-16",
),
(
"deepseek",
"deepseek-reasoner",
): PricingEntry(
input_cost_per_million=Decimal("0.55"),
output_cost_per_million=Decimal("2.19"),
source="official_docs_snapshot",
source_url="https://api-docs.deepseek.com/quick_start/pricing",
pricing_version="deepseek-pricing-2026-03-16",
),
# Google Gemini
(
"google",
"gemini-2.5-pro",
): PricingEntry(
input_cost_per_million=Decimal("1.25"),
output_cost_per_million=Decimal("10.00"),
source="official_docs_snapshot",
source_url="https://ai.google.dev/pricing",
pricing_version="google-pricing-2026-03-16",
),
(
"google",
"gemini-2.5-flash",
): PricingEntry(
input_cost_per_million=Decimal("0.15"),
output_cost_per_million=Decimal("0.60"),
source="official_docs_snapshot",
source_url="https://ai.google.dev/pricing",
pricing_version="google-pricing-2026-03-16",
),
(
"google",
"gemini-2.0-flash",
): PricingEntry(
input_cost_per_million=Decimal("0.10"),
output_cost_per_million=Decimal("0.40"),
source="official_docs_snapshot",
source_url="https://ai.google.dev/pricing",
pricing_version="google-pricing-2026-03-16",
),
}
def _to_decimal(value: Any) -> Optional[Decimal]:
if value is None:
return None
try:
return Decimal(str(value))
except Exception:
return None
def _to_int(value: Any) -> int:
try:
return int(value or 0)
except Exception:
return 0
def resolve_billing_route(
model_name: str,
provider: Optional[str] = None,
base_url: Optional[str] = None,
) -> BillingRoute:
provider_name = (provider or "").strip().lower()
base = (base_url or "").strip().lower()
model = (model_name or "").strip()
if not provider_name and "/" in model:
inferred_provider, bare_model = model.split("/", 1)
if inferred_provider in {"anthropic", "openai", "google"}:
provider_name = inferred_provider
model = bare_model
if provider_name == "openai-codex":
return BillingRoute(provider="openai-codex", model=model, base_url=base_url or "", billing_mode="subscription_included")
if provider_name == "openrouter" or "openrouter.ai" in base:
return BillingRoute(provider="openrouter", model=model, base_url=base_url or "", billing_mode="official_models_api")
if provider_name == "anthropic":
return BillingRoute(provider="anthropic", model=model.split("/")[-1], base_url=base_url or "", billing_mode="official_docs_snapshot")
if provider_name == "openai":
return BillingRoute(provider="openai", model=model.split("/")[-1], base_url=base_url or "", billing_mode="official_docs_snapshot")
if provider_name in {"custom", "local"} or (base and "localhost" in base):
return BillingRoute(provider=provider_name or "custom", model=model, base_url=base_url or "", billing_mode="unknown")
return BillingRoute(provider=provider_name or "unknown", model=model.split("/")[-1] if model else "", base_url=base_url or "", billing_mode="unknown")
def _lookup_official_docs_pricing(route: BillingRoute) -> Optional[PricingEntry]:
return _OFFICIAL_DOCS_PRICING.get((route.provider, route.model.lower()))
def _openrouter_pricing_entry(route: BillingRoute) -> Optional[PricingEntry]:
metadata = fetch_model_metadata()
model_id = route.model
if model_id not in metadata:
return None
pricing = metadata[model_id].get("pricing") or {}
prompt = _to_decimal(pricing.get("prompt"))
completion = _to_decimal(pricing.get("completion"))
request = _to_decimal(pricing.get("request"))
cache_read = _to_decimal(
pricing.get("cache_read")
or pricing.get("cached_prompt")
or pricing.get("input_cache_read")
)
cache_write = _to_decimal(
pricing.get("cache_write")
or pricing.get("cache_creation")
or pricing.get("input_cache_write")
)
if prompt is None and completion is None and request is None:
return None
def _per_token_to_per_million(value: Optional[Decimal]) -> Optional[Decimal]:
if value is None:
return None
return value * _ONE_MILLION
return PricingEntry(
input_cost_per_million=_per_token_to_per_million(prompt),
output_cost_per_million=_per_token_to_per_million(completion),
cache_read_cost_per_million=_per_token_to_per_million(cache_read),
cache_write_cost_per_million=_per_token_to_per_million(cache_write),
request_cost=request,
source="provider_models_api",
source_url="https://openrouter.ai/docs/api/api-reference/models/get-models",
pricing_version="openrouter-models-api",
fetched_at=_UTC_NOW(),
)
def estimate_cost_usd(model: str, input_tokens: int, output_tokens: int) -> float:
pricing = get_pricing(model)
total = (
Decimal(input_tokens) * Decimal(str(pricing["input"]))
+ Decimal(output_tokens) * Decimal(str(pricing["output"]))
) / Decimal("1000000")
return float(total)
def get_pricing_entry(
model_name: str,
provider: Optional[str] = None,
base_url: Optional[str] = None,
) -> Optional[PricingEntry]:
route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
if route.billing_mode == "subscription_included":
return PricingEntry(
input_cost_per_million=_ZERO,
output_cost_per_million=_ZERO,
cache_read_cost_per_million=_ZERO,
cache_write_cost_per_million=_ZERO,
source="none",
pricing_version="included-route",
)
if route.provider == "openrouter":
return _openrouter_pricing_entry(route)
return _lookup_official_docs_pricing(route)
def normalize_usage(
response_usage: Any,
*,
provider: Optional[str] = None,
api_mode: Optional[str] = None,
) -> CanonicalUsage:
"""Normalize raw API response usage into canonical token buckets.
Handles three API shapes:
- Anthropic: input_tokens/output_tokens/cache_read_input_tokens/cache_creation_input_tokens
- Codex Responses: input_tokens includes cache tokens; input_tokens_details.cached_tokens separates them
- OpenAI Chat Completions: prompt_tokens includes cache tokens; prompt_tokens_details.cached_tokens separates them
In both Codex and OpenAI modes, input_tokens is derived by subtracting cache
tokens from the total the API contract is that input/prompt totals include
cached tokens and the details object breaks them out.
"""
if not response_usage:
return CanonicalUsage()
provider_name = (provider or "").strip().lower()
mode = (api_mode or "").strip().lower()
if mode == "anthropic_messages" or provider_name == "anthropic":
input_tokens = _to_int(getattr(response_usage, "input_tokens", 0))
output_tokens = _to_int(getattr(response_usage, "output_tokens", 0))
cache_read_tokens = _to_int(getattr(response_usage, "cache_read_input_tokens", 0))
cache_write_tokens = _to_int(getattr(response_usage, "cache_creation_input_tokens", 0))
elif mode == "codex_responses":
input_total = _to_int(getattr(response_usage, "input_tokens", 0))
output_tokens = _to_int(getattr(response_usage, "output_tokens", 0))
details = getattr(response_usage, "input_tokens_details", None)
cache_read_tokens = _to_int(getattr(details, "cached_tokens", 0) if details else 0)
cache_write_tokens = _to_int(
getattr(details, "cache_creation_tokens", 0) if details else 0
)
input_tokens = max(0, input_total - cache_read_tokens - cache_write_tokens)
else:
prompt_total = _to_int(getattr(response_usage, "prompt_tokens", 0))
output_tokens = _to_int(getattr(response_usage, "completion_tokens", 0))
details = getattr(response_usage, "prompt_tokens_details", None)
cache_read_tokens = _to_int(getattr(details, "cached_tokens", 0) if details else 0)
cache_write_tokens = _to_int(
getattr(details, "cache_write_tokens", 0) if details else 0
)
input_tokens = max(0, prompt_total - cache_read_tokens - cache_write_tokens)
reasoning_tokens = 0
output_details = getattr(response_usage, "output_tokens_details", None)
if output_details:
reasoning_tokens = _to_int(getattr(output_details, "reasoning_tokens", 0))
return CanonicalUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_read_tokens=cache_read_tokens,
cache_write_tokens=cache_write_tokens,
reasoning_tokens=reasoning_tokens,
)
def estimate_usage_cost(
model_name: str,
usage: CanonicalUsage,
*,
provider: Optional[str] = None,
base_url: Optional[str] = None,
) -> CostResult:
route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
if route.billing_mode == "subscription_included":
return CostResult(
amount_usd=_ZERO,
status="included",
source="none",
label="included",
pricing_version="included-route",
)
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url)
if not entry:
return CostResult(amount_usd=None, status="unknown", source="none", label="n/a")
notes: list[str] = []
amount = _ZERO
if usage.input_tokens and entry.input_cost_per_million is None:
return CostResult(amount_usd=None, status="unknown", source=entry.source, label="n/a")
if usage.output_tokens and entry.output_cost_per_million is None:
return CostResult(amount_usd=None, status="unknown", source=entry.source, label="n/a")
if usage.cache_read_tokens:
if entry.cache_read_cost_per_million is None:
return CostResult(
amount_usd=None,
status="unknown",
source=entry.source,
label="n/a",
notes=("cache-read pricing unavailable for route",),
)
if usage.cache_write_tokens:
if entry.cache_write_cost_per_million is None:
return CostResult(
amount_usd=None,
status="unknown",
source=entry.source,
label="n/a",
notes=("cache-write pricing unavailable for route",),
)
if entry.input_cost_per_million is not None:
amount += Decimal(usage.input_tokens) * entry.input_cost_per_million / _ONE_MILLION
if entry.output_cost_per_million is not None:
amount += Decimal(usage.output_tokens) * entry.output_cost_per_million / _ONE_MILLION
if entry.cache_read_cost_per_million is not None:
amount += Decimal(usage.cache_read_tokens) * entry.cache_read_cost_per_million / _ONE_MILLION
if entry.cache_write_cost_per_million is not None:
amount += Decimal(usage.cache_write_tokens) * entry.cache_write_cost_per_million / _ONE_MILLION
if entry.request_cost is not None and usage.request_count:
amount += Decimal(usage.request_count) * entry.request_cost
status: CostStatus = "estimated"
label = f"~${amount:.2f}"
if entry.source == "none" and amount == _ZERO:
status = "included"
label = "included"
if route.provider == "openrouter":
notes.append("OpenRouter cost is estimated from the models API until reconciled.")
return CostResult(
amount_usd=amount,
status=status,
source=entry.source,
label=label,
fetched_at=entry.fetched_at,
pricing_version=entry.pricing_version,
notes=tuple(notes),
)
def has_known_pricing(
model_name: str,
provider: Optional[str] = None,
base_url: Optional[str] = None,
) -> bool:
"""Check whether we have pricing data for this model+route.
Uses direct lookup instead of routing through the full estimation
pipeline avoids creating dummy usage objects just to check status.
"""
route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
if route.billing_mode == "subscription_included":
return True
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url)
return entry is not None
def get_pricing(
model_name: str,
provider: Optional[str] = None,
base_url: Optional[str] = None,
) -> Dict[str, float]:
"""Backward-compatible thin wrapper for legacy callers.
Returns only non-cache input/output fields when a pricing entry exists.
Unknown routes return zeroes.
"""
entry = get_pricing_entry(model_name, provider=provider, base_url=base_url)
if not entry:
return {"input": 0.0, "output": 0.0}
return {
"input": float(entry.input_cost_per_million or _ZERO),
"output": float(entry.output_cost_per_million or _ZERO),
}
def estimate_cost_usd(
model: str,
input_tokens: int,
output_tokens: int,
*,
provider: Optional[str] = None,
base_url: Optional[str] = None,
) -> float:
"""Backward-compatible helper for legacy callers.
This uses non-cached input/output only. New code should call
`estimate_usage_cost()` with canonical usage buckets.
"""
result = estimate_usage_cost(
model,
CanonicalUsage(input_tokens=input_tokens, output_tokens=output_tokens),
provider=provider,
base_url=base_url,
)
return float(result.amount_usd or _ZERO)
def format_duration_compact(seconds: float) -> str:

106
cli.py
View file

@ -58,7 +58,12 @@ except (ImportError, AttributeError):
import threading
import queue
from agent.usage_pricing import estimate_cost_usd, format_duration_compact, format_token_count_compact, has_known_pricing
from agent.usage_pricing import (
CanonicalUsage,
estimate_usage_cost,
format_duration_compact,
format_token_count_compact,
)
from hermes_cli.banner import _format_context_length
_COMMAND_SPINNER_FRAMES = ("", "", "", "", "", "", "", "", "", "")
@ -212,7 +217,7 @@ def load_cli_config() -> Dict[str, Any]:
"resume_display": "full",
"show_reasoning": False,
"streaming": False,
"show_cost": False,
"skin": "default",
"theme_mode": "auto",
},
@ -1034,8 +1039,7 @@ class HermesCLI:
self.bell_on_complete = CLI_CONFIG["display"].get("bell_on_complete", False)
# show_reasoning: display model thinking/reasoning before the response
self.show_reasoning = CLI_CONFIG["display"].get("show_reasoning", False)
# show_cost: display $ cost in the status bar (off by default)
self.show_cost = CLI_CONFIG["display"].get("show_cost", False)
self.verbose = verbose if verbose is not None else (self.tool_progress_mode == "verbose")
# streaming: stream tokens to the terminal as they arrive (display.streaming in config.yaml)
@ -1260,12 +1264,14 @@ class HermesCLI:
"context_tokens": 0,
"context_length": None,
"context_percent": None,
"session_input_tokens": 0,
"session_output_tokens": 0,
"session_cache_read_tokens": 0,
"session_cache_write_tokens": 0,
"session_prompt_tokens": 0,
"session_completion_tokens": 0,
"session_total_tokens": 0,
"session_api_calls": 0,
"session_cost": 0.0,
"pricing_known": has_known_pricing(model_name),
"compressions": 0,
}
@ -1273,15 +1279,14 @@ class HermesCLI:
if not agent:
return snapshot
snapshot["session_input_tokens"] = getattr(agent, "session_input_tokens", 0) or 0
snapshot["session_output_tokens"] = getattr(agent, "session_output_tokens", 0) or 0
snapshot["session_cache_read_tokens"] = getattr(agent, "session_cache_read_tokens", 0) or 0
snapshot["session_cache_write_tokens"] = getattr(agent, "session_cache_write_tokens", 0) or 0
snapshot["session_prompt_tokens"] = getattr(agent, "session_prompt_tokens", 0) or 0
snapshot["session_completion_tokens"] = getattr(agent, "session_completion_tokens", 0) or 0
snapshot["session_total_tokens"] = getattr(agent, "session_total_tokens", 0) or 0
snapshot["session_api_calls"] = getattr(agent, "session_api_calls", 0) or 0
snapshot["session_cost"] = estimate_cost_usd(
model_name,
snapshot["session_prompt_tokens"],
snapshot["session_completion_tokens"],
)
compressor = getattr(agent, "context_compressor", None)
if compressor:
@ -1302,19 +1307,11 @@ class HermesCLI:
percent = snapshot["context_percent"]
percent_label = f"{percent}%" if percent is not None else "--"
duration_label = snapshot["duration"]
show_cost = getattr(self, "show_cost", False)
if show_cost:
cost_label = f"${snapshot['session_cost']:.2f}" if snapshot["pricing_known"] else "cost n/a"
else:
cost_label = None
if width < 52:
return f"{snapshot['model_short']} · {duration_label}"
if width < 76:
parts = [f"{snapshot['model_short']}", percent_label]
if cost_label:
parts.append(cost_label)
parts.append(duration_label)
return " · ".join(parts)
@ -1326,8 +1323,6 @@ class HermesCLI:
context_label = "ctx --"
parts = [f"{snapshot['model_short']}", context_label, percent_label]
if cost_label:
parts.append(cost_label)
parts.append(duration_label)
return "".join(parts)
except Exception:
@ -1338,12 +1333,6 @@ class HermesCLI:
snapshot = self._get_status_bar_snapshot()
width = shutil.get_terminal_size((80, 24)).columns
duration_label = snapshot["duration"]
show_cost = getattr(self, "show_cost", False)
if show_cost:
cost_label = f"${snapshot['session_cost']:.2f}" if snapshot["pricing_known"] else "cost n/a"
else:
cost_label = None
if width < 52:
return [
@ -1363,11 +1352,6 @@ class HermesCLI:
("class:status-bar-dim", " · "),
(self._status_bar_context_style(percent), percent_label),
]
if cost_label:
frags.extend([
("class:status-bar-dim", " · "),
("class:status-bar-dim", cost_label),
])
frags.extend([
("class:status-bar-dim", " · "),
("class:status-bar-dim", duration_label),
@ -1393,11 +1377,6 @@ class HermesCLI:
("class:status-bar-dim", " "),
(bar_style, percent_label),
]
if cost_label:
frags.extend([
("class:status-bar-dim", ""),
("class:status-bar-dim", cost_label),
])
frags.extend([
("class:status-bar-dim", ""),
("class:status-bar-dim", duration_label),
@ -3653,8 +3632,17 @@ class HermesCLI:
self.console.print(f"[bold red]Quick command error: {e}[/]")
else:
self.console.print(f"[bold red]Quick command '{base_cmd}' has no command defined[/]")
elif qcmd.get("type") == "alias":
target = qcmd.get("target", "").strip()
if target:
target = target if target.startswith("/") else f"/{target}"
user_args = cmd_original[len(base_cmd):].strip()
aliased_command = f"{target} {user_args}".strip()
return self.process_command(aliased_command)
else:
self.console.print(f"[bold red]Quick command '{base_cmd}' has no target defined[/]")
else:
self.console.print(f"[bold red]Quick command '{base_cmd}' has unsupported type (only 'exec' is supported)[/]")
self.console.print(f"[bold red]Quick command '{base_cmd}' has unsupported type (supported: 'exec', 'alias')[/]")
# Check for skill slash commands (/gif-search, /axolotl, etc.)
elif base_cmd in _skill_commands:
user_instruction = cmd_original[len(base_cmd):].strip()
@ -4242,6 +4230,10 @@ class HermesCLI:
return
agent = self.agent
input_tokens = getattr(agent, "session_input_tokens", 0) or 0
output_tokens = getattr(agent, "session_output_tokens", 0) or 0
cache_read_tokens = getattr(agent, "session_cache_read_tokens", 0) or 0
cache_write_tokens = getattr(agent, "session_cache_write_tokens", 0) or 0
prompt = agent.session_prompt_tokens
completion = agent.session_completion_tokens
total = agent.session_total_tokens
@ -4259,33 +4251,45 @@ class HermesCLI:
compressions = compressor.compression_count
msg_count = len(self.conversation_history)
cost = estimate_cost_usd(agent.model, prompt, completion)
prompt_cost = estimate_cost_usd(agent.model, prompt, 0)
completion_cost = estimate_cost_usd(agent.model, 0, completion)
pricing_known = has_known_pricing(agent.model)
cost_result = estimate_usage_cost(
agent.model,
CanonicalUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_read_tokens=cache_read_tokens,
cache_write_tokens=cache_write_tokens,
),
provider=getattr(agent, "provider", None),
base_url=getattr(agent, "base_url", None),
)
elapsed = format_duration_compact((datetime.now() - self.session_start).total_seconds())
print(f" 📊 Session Token Usage")
print(f" {'' * 40}")
print(f" Model: {agent.model}")
print(f" Prompt tokens (input): {prompt:>10,}")
print(f" Completion tokens (output): {completion:>9,}")
print(f" Input tokens: {input_tokens:>10,}")
print(f" Cache read tokens: {cache_read_tokens:>10,}")
print(f" Cache write tokens: {cache_write_tokens:>10,}")
print(f" Output tokens: {output_tokens:>10,}")
print(f" Prompt tokens (total): {prompt:>10,}")
print(f" Completion tokens: {completion:>10,}")
print(f" Total tokens: {total:>10,}")
print(f" API calls: {calls:>10,}")
print(f" Session duration: {elapsed:>10}")
if pricing_known:
print(f" Input cost: ${prompt_cost:>10.4f}")
print(f" Output cost: ${completion_cost:>10.4f}")
print(f" Total cost: ${cost:>10.4f}")
print(f" Cost status: {cost_result.status:>10}")
print(f" Cost source: {cost_result.source:>10}")
if cost_result.amount_usd is not None:
prefix = "~" if cost_result.status == "estimated" else ""
print(f" Total cost: {prefix}${float(cost_result.amount_usd):>10.4f}")
elif cost_result.status == "included":
print(f" Total cost: {'included':>10}")
else:
print(f" Input cost: {'n/a':>10}")
print(f" Output cost: {'n/a':>10}")
print(f" Total cost: {'n/a':>10}")
print(f" {'' * 40}")
print(f" Current context: {last_prompt:,} / {ctx_len:,} ({pct:.0f}%)")
print(f" Messages: {msg_count}")
print(f" Compressions: {compressions}")
if not pricing_known:
if cost_result.status == "unknown":
print(f" Note: Pricing unknown for {agent.model}")
if self.verbose:

View file

@ -5,6 +5,7 @@ Jobs are stored in ~/.hermes/cron/jobs.json
Output is saved to ~/.hermes/cron/output/{job_id}/{timestamp}.md
"""
import copy
import json
import logging
import tempfile
@ -167,6 +168,10 @@ def parse_schedule(schedule: str) -> Dict[str, Any]:
try:
# Parse and validate
dt = datetime.fromisoformat(schedule.replace('Z', '+00:00'))
# Make naive timestamps timezone-aware at parse time so the stored
# value doesn't depend on the system timezone matching at check time.
if dt.tzinfo is None:
dt = dt.astimezone() # Interpret as local timezone
return {
"kind": "once",
"run_at": dt.isoformat(),
@ -539,8 +544,8 @@ def get_due_jobs() -> List[Dict[str, Any]]:
immediately. This prevents a burst of missed jobs on gateway restart.
"""
now = _hermes_now()
jobs = [_apply_skill_fields(j) for j in load_jobs()]
raw_jobs = load_jobs() # For saving updates
raw_jobs = load_jobs()
jobs = [_apply_skill_fields(j) for j in copy.deepcopy(raw_jobs)]
due = []
needs_save = False

View file

@ -0,0 +1,608 @@
# Pricing Accuracy Architecture
Date: 2026-03-16
## Goal
Hermes should only show dollar costs when they are backed by an official source for the user's actual billing path.
This design replaces the current static, heuristic pricing flow in:
- `run_agent.py`
- `agent/usage_pricing.py`
- `agent/insights.py`
- `cli.py`
with a provider-aware pricing system that:
- handles cache billing correctly
- distinguishes `actual` vs `estimated` vs `included` vs `unknown`
- reconciles post-hoc costs when providers expose authoritative billing data
- supports direct providers, OpenRouter, subscriptions, enterprise pricing, and custom endpoints
## Problems In The Current Design
Current Hermes behavior has four structural issues:
1. It stores only `prompt_tokens` and `completion_tokens`, which is insufficient for providers that bill cache reads and cache writes separately.
2. It uses a static model price table and fuzzy heuristics, which can drift from current official pricing.
3. It assumes public API list pricing matches the user's real billing path.
4. It has no distinction between live estimates and reconciled billed cost.
## Design Principles
1. Normalize usage before pricing.
2. Never fold cached tokens into plain input cost.
3. Track certainty explicitly.
4. Treat the billing path as part of the model identity.
5. Prefer official machine-readable sources over scraped docs.
6. Use post-hoc provider cost APIs when available.
7. Show `n/a` rather than inventing precision.
## High-Level Architecture
The new system has four layers:
1. `usage_normalization`
Converts raw provider usage into a canonical usage record.
2. `pricing_source_resolution`
Determines the billing path, source of truth, and applicable pricing source.
3. `cost_estimation_and_reconciliation`
Produces an immediate estimate when possible, then replaces or annotates it with actual billed cost later.
4. `presentation`
`/usage`, `/insights`, and the status bar display cost with certainty metadata.
## Canonical Usage Record
Add a canonical usage model that every provider path maps into before any pricing math happens.
Suggested structure:
```python
@dataclass
class CanonicalUsage:
provider: str
billing_provider: str
model: str
billing_route: str
input_tokens: int = 0
output_tokens: int = 0
cache_read_tokens: int = 0
cache_write_tokens: int = 0
reasoning_tokens: int = 0
request_count: int = 1
raw_usage: dict[str, Any] | None = None
raw_usage_fields: dict[str, str] | None = None
computed_fields: set[str] | None = None
provider_request_id: str | None = None
provider_generation_id: str | None = None
provider_response_id: str | None = None
```
Rules:
- `input_tokens` means non-cached input only.
- `cache_read_tokens` and `cache_write_tokens` are never merged into `input_tokens`.
- `output_tokens` excludes cache metrics.
- `reasoning_tokens` is telemetry unless a provider officially bills it separately.
This is the same normalization pattern used by `opencode`, extended with provenance and reconciliation ids.
## Provider Normalization Rules
### OpenAI Direct
Source usage fields:
- `prompt_tokens`
- `completion_tokens`
- `prompt_tokens_details.cached_tokens`
Normalization:
- `cache_read_tokens = cached_tokens`
- `input_tokens = prompt_tokens - cached_tokens`
- `cache_write_tokens = 0` unless OpenAI exposes it in the relevant route
- `output_tokens = completion_tokens`
### Anthropic Direct
Source usage fields:
- `input_tokens`
- `output_tokens`
- `cache_read_input_tokens`
- `cache_creation_input_tokens`
Normalization:
- `input_tokens = input_tokens`
- `output_tokens = output_tokens`
- `cache_read_tokens = cache_read_input_tokens`
- `cache_write_tokens = cache_creation_input_tokens`
### OpenRouter
Estimate-time usage normalization should use the response usage payload with the same rules as the underlying provider when possible.
Reconciliation-time records should also store:
- OpenRouter generation id
- native token fields when available
- `total_cost`
- `cache_discount`
- `upstream_inference_cost`
- `is_byok`
### Gemini / Vertex
Use official Gemini or Vertex usage fields where available.
If cached content tokens are exposed:
- map them to `cache_read_tokens`
If a route exposes no cache creation metric:
- store `cache_write_tokens = 0`
- preserve the raw usage payload for later extension
### DeepSeek And Other Direct Providers
Normalize only the fields that are officially exposed.
If a provider does not expose cache buckets:
- do not infer them unless the provider explicitly documents how to derive them
### Subscription / Included-Cost Routes
These still use the canonical usage model.
Tokens are tracked normally. Cost depends on billing mode, not on whether usage exists.
## Billing Route Model
Hermes must stop keying pricing solely by `model`.
Introduce a billing route descriptor:
```python
@dataclass
class BillingRoute:
provider: str
base_url: str | None
model: str
billing_mode: str
organization_hint: str | None = None
```
`billing_mode` values:
- `official_cost_api`
- `official_generation_api`
- `official_models_api`
- `official_docs_snapshot`
- `subscription_included`
- `user_override`
- `custom_contract`
- `unknown`
Examples:
- OpenAI direct API with Costs API access: `official_cost_api`
- Anthropic direct API with Usage & Cost API access: `official_cost_api`
- OpenRouter request before reconciliation: `official_models_api`
- OpenRouter request after generation lookup: `official_generation_api`
- GitHub Copilot style subscription route: `subscription_included`
- local OpenAI-compatible server: `unknown`
- enterprise contract with configured rates: `custom_contract`
## Cost Status Model
Every displayed cost should have:
```python
@dataclass
class CostResult:
amount_usd: Decimal | None
status: Literal["actual", "estimated", "included", "unknown"]
source: Literal[
"provider_cost_api",
"provider_generation_api",
"provider_models_api",
"official_docs_snapshot",
"user_override",
"custom_contract",
"none",
]
label: str
fetched_at: datetime | None
pricing_version: str | None
notes: list[str]
```
Presentation rules:
- `actual`: show dollar amount as final
- `estimated`: show dollar amount with estimate labeling
- `included`: show `included` or `$0.00 (included)` depending on UX choice
- `unknown`: show `n/a`
## Official Source Hierarchy
Resolve cost using this order:
1. Request-level or account-level official billed cost
2. Official machine-readable model pricing
3. Official docs snapshot
4. User override or custom contract
5. Unknown
The system must never skip to a lower level if a higher-confidence source exists for the current billing route.
## Provider-Specific Truth Rules
### OpenAI Direct
Preferred truth:
1. Costs API for reconciled spend
2. Official pricing page for live estimate
### Anthropic Direct
Preferred truth:
1. Usage & Cost API for reconciled spend
2. Official pricing docs for live estimate
### OpenRouter
Preferred truth:
1. `GET /api/v1/generation` for reconciled `total_cost`
2. `GET /api/v1/models` pricing for live estimate
Do not use underlying provider public pricing as the source of truth for OpenRouter billing.
### Gemini / Vertex
Preferred truth:
1. official billing export or billing API for reconciled spend when available for the route
2. official pricing docs for estimate
### DeepSeek
Preferred truth:
1. official machine-readable cost source if available in the future
2. official pricing docs snapshot today
### Subscription-Included Routes
Preferred truth:
1. explicit route config marking the model as included in subscription
These should display `included`, not an API list-price estimate.
### Custom Endpoint / Local Model
Preferred truth:
1. user override
2. custom contract config
3. unknown
These should default to `unknown`.
## Pricing Catalog
Replace the current `MODEL_PRICING` dict with a richer pricing catalog.
Suggested record:
```python
@dataclass
class PricingEntry:
provider: str
route_pattern: str
model_pattern: str
input_cost_per_million: Decimal | None = None
output_cost_per_million: Decimal | None = None
cache_read_cost_per_million: Decimal | None = None
cache_write_cost_per_million: Decimal | None = None
request_cost: Decimal | None = None
image_cost: Decimal | None = None
source: str = "official_docs_snapshot"
source_url: str | None = None
fetched_at: datetime | None = None
pricing_version: str | None = None
```
The catalog should be route-aware:
- `openai:gpt-5`
- `anthropic:claude-opus-4-6`
- `openrouter:anthropic/claude-opus-4.6`
- `copilot:gpt-4o`
This avoids conflating direct-provider billing with aggregator billing.
## Pricing Sync Architecture
Introduce a pricing sync subsystem instead of manually maintaining a single hardcoded table.
Suggested modules:
- `agent/pricing/catalog.py`
- `agent/pricing/sources.py`
- `agent/pricing/sync.py`
- `agent/pricing/reconcile.py`
- `agent/pricing/types.py`
### Sync Sources
- OpenRouter models API
- official provider docs snapshots where no API exists
- user overrides from config
### Sync Output
Cache pricing entries locally with:
- source URL
- fetch timestamp
- version/hash
- confidence/source type
### Sync Frequency
- startup warm cache
- background refresh every 6 to 24 hours depending on source
- manual `hermes pricing sync`
## Reconciliation Architecture
Live requests may produce only an estimate initially. Hermes should reconcile them later when a provider exposes actual billed cost.
Suggested flow:
1. Agent call completes.
2. Hermes stores canonical usage plus reconciliation ids.
3. Hermes computes an immediate estimate if a pricing source exists.
4. A reconciliation worker fetches actual cost when supported.
5. Session and message records are updated with `actual` cost.
This can run:
- inline for cheap lookups
- asynchronously for delayed provider accounting
## Persistence Changes
Session storage should stop storing only aggregate prompt/completion totals.
Add fields for both usage and cost certainty:
- `input_tokens`
- `output_tokens`
- `cache_read_tokens`
- `cache_write_tokens`
- `reasoning_tokens`
- `estimated_cost_usd`
- `actual_cost_usd`
- `cost_status`
- `cost_source`
- `pricing_version`
- `billing_provider`
- `billing_mode`
If schema expansion is too large for one PR, add a new pricing events table:
```text
session_cost_events
id
session_id
request_id
provider
model
billing_mode
input_tokens
output_tokens
cache_read_tokens
cache_write_tokens
estimated_cost_usd
actual_cost_usd
cost_status
cost_source
pricing_version
created_at
updated_at
```
## Hermes Touchpoints
### `run_agent.py`
Current responsibility:
- parse raw provider usage
- update session token counters
New responsibility:
- build `CanonicalUsage`
- update canonical counters
- store reconciliation ids
- emit usage event to pricing subsystem
### `agent/usage_pricing.py`
Current responsibility:
- static lookup table
- direct cost arithmetic
New responsibility:
- move or replace with pricing catalog facade
- no fuzzy model-family heuristics
- no direct pricing without billing-route context
### `cli.py`
Current responsibility:
- compute session cost directly from prompt/completion totals
New responsibility:
- display `CostResult`
- show status badges:
- `actual`
- `estimated`
- `included`
- `n/a`
### `agent/insights.py`
Current responsibility:
- recompute historical estimates from static pricing
New responsibility:
- aggregate stored pricing events
- prefer actual cost over estimate
- surface estimates only when reconciliation is unavailable
## UX Rules
### Status Bar
Show one of:
- `$1.42`
- `~$1.42`
- `included`
- `cost n/a`
Where:
- `$1.42` means `actual`
- `~$1.42` means `estimated`
- `included` means subscription-backed or explicitly zero-cost route
- `cost n/a` means unknown
### `/usage`
Show:
- token buckets
- estimated cost
- actual cost if available
- cost status
- pricing source
### `/insights`
Aggregate:
- actual cost totals
- estimated-only totals
- unknown-cost sessions count
- included-cost sessions count
## Config And Overrides
Add user-configurable pricing overrides in config:
```yaml
pricing:
mode: hybrid
sync_on_startup: true
sync_interval_hours: 12
overrides:
- provider: openrouter
model: anthropic/claude-opus-4.6
billing_mode: custom_contract
input_cost_per_million: 4.25
output_cost_per_million: 22.0
cache_read_cost_per_million: 0.5
cache_write_cost_per_million: 6.0
included_routes:
- provider: copilot
model: "*"
- provider: codex-subscription
model: "*"
```
Overrides must win over catalog defaults for the matching billing route.
## Rollout Plan
### Phase 1
- add canonical usage model
- split cache token buckets in `run_agent.py`
- stop pricing cache-inflated prompt totals
- preserve current UI with improved backend math
### Phase 2
- add route-aware pricing catalog
- integrate OpenRouter models API sync
- add `estimated` vs `included` vs `unknown`
### Phase 3
- add reconciliation for OpenRouter generation cost
- add actual cost persistence
- update `/insights` to prefer actual cost
### Phase 4
- add direct OpenAI and Anthropic reconciliation paths
- add user overrides and contract pricing
- add pricing sync CLI command
## Testing Strategy
Add tests for:
- OpenAI cached token subtraction
- Anthropic cache read/write separation
- OpenRouter estimated vs actual reconciliation
- subscription-backed models showing `included`
- custom endpoints showing `n/a`
- override precedence
- stale catalog fallback behavior
Current tests that assume heuristic pricing should be replaced with route-aware expectations.
## Non-Goals
- exact enterprise billing reconstruction without an official source or user override
- backfilling perfect historical cost for old sessions that lack cache bucket data
- scraping arbitrary provider web pages at request time
## Recommendation
Do not expand the existing `MODEL_PRICING` dict.
That path cannot satisfy the product requirement. Hermes should instead migrate to:
- canonical usage normalization
- route-aware pricing sources
- estimate-then-reconcile cost lifecycle
- explicit certainty states in the UI
This is the minimum architecture that makes the statement "Hermes pricing is backed by official sources where possible, and otherwise clearly labeled" defensible.

View file

@ -40,9 +40,12 @@ class Platform(Enum):
WHATSAPP = "whatsapp"
SLACK = "slack"
SIGNAL = "signal"
MATTERMOST = "mattermost"
MATRIX = "matrix"
HOMEASSISTANT = "homeassistant"
EMAIL = "email"
SMS = "sms"
DINGTALK = "dingtalk"
@dataclass
@ -226,15 +229,15 @@ class GatewayConfig:
# WhatsApp uses enabled flag only (bridge handles auth)
elif platform == Platform.WHATSAPP:
connected.append(platform)
# SMS uses api_key from env (checked via extra or env var)
elif platform == Platform.SMS and os.getenv("TELNYX_API_KEY"):
connected.append(platform)
# Signal uses extra dict for config (http_url + account)
elif platform == Platform.SIGNAL and config.extra.get("http_url"):
connected.append(platform)
# Email uses extra dict for config (address + imap_host + smtp_host)
elif platform == Platform.EMAIL and config.extra.get("address"):
connected.append(platform)
# SMS uses api_key (Twilio auth token) — SID checked via env
elif platform == Platform.SMS and os.getenv("TWILIO_ACCOUNT_SID"):
connected.append(platform)
return connected
def get_home_channel(self, platform: Platform) -> Optional[HomeChannel]:
@ -441,6 +444,8 @@ def load_gateway_config() -> GatewayConfig:
Platform.TELEGRAM: "TELEGRAM_BOT_TOKEN",
Platform.DISCORD: "DISCORD_BOT_TOKEN",
Platform.SLACK: "SLACK_BOT_TOKEN",
Platform.MATTERMOST: "MATTERMOST_TOKEN",
Platform.MATRIX: "MATRIX_ACCESS_TOKEN",
}
for platform, pconfig in config.platforms.items():
if not pconfig.enabled:
@ -534,6 +539,53 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
name=os.getenv("SIGNAL_HOME_CHANNEL_NAME", "Home"),
)
# Mattermost
mattermost_token = os.getenv("MATTERMOST_TOKEN")
if mattermost_token:
mattermost_url = os.getenv("MATTERMOST_URL", "")
if not mattermost_url:
logger.warning("MATTERMOST_TOKEN set but MATTERMOST_URL is missing")
if Platform.MATTERMOST not in config.platforms:
config.platforms[Platform.MATTERMOST] = PlatformConfig()
config.platforms[Platform.MATTERMOST].enabled = True
config.platforms[Platform.MATTERMOST].token = mattermost_token
config.platforms[Platform.MATTERMOST].extra["url"] = mattermost_url
mattermost_home = os.getenv("MATTERMOST_HOME_CHANNEL")
if mattermost_home:
config.platforms[Platform.MATTERMOST].home_channel = HomeChannel(
platform=Platform.MATTERMOST,
chat_id=mattermost_home,
name=os.getenv("MATTERMOST_HOME_CHANNEL_NAME", "Home"),
)
# Matrix
matrix_token = os.getenv("MATRIX_ACCESS_TOKEN")
matrix_homeserver = os.getenv("MATRIX_HOMESERVER", "")
if matrix_token or os.getenv("MATRIX_PASSWORD"):
if not matrix_homeserver:
logger.warning("MATRIX_ACCESS_TOKEN/MATRIX_PASSWORD set but MATRIX_HOMESERVER is missing")
if Platform.MATRIX not in config.platforms:
config.platforms[Platform.MATRIX] = PlatformConfig()
config.platforms[Platform.MATRIX].enabled = True
if matrix_token:
config.platforms[Platform.MATRIX].token = matrix_token
config.platforms[Platform.MATRIX].extra["homeserver"] = matrix_homeserver
matrix_user = os.getenv("MATRIX_USER_ID", "")
if matrix_user:
config.platforms[Platform.MATRIX].extra["user_id"] = matrix_user
matrix_password = os.getenv("MATRIX_PASSWORD", "")
if matrix_password:
config.platforms[Platform.MATRIX].extra["password"] = matrix_password
matrix_e2ee = os.getenv("MATRIX_ENCRYPTION", "").lower() in ("true", "1", "yes")
config.platforms[Platform.MATRIX].extra["encryption"] = matrix_e2ee
matrix_home = os.getenv("MATRIX_HOME_ROOM")
if matrix_home:
config.platforms[Platform.MATRIX].home_channel = HomeChannel(
platform=Platform.MATRIX,
chat_id=matrix_home,
name=os.getenv("MATRIX_HOME_ROOM_NAME", "Home"),
)
# Home Assistant
hass_token = os.getenv("HASS_TOKEN")
if hass_token:
@ -567,13 +619,13 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
name=os.getenv("EMAIL_HOME_ADDRESS_NAME", "Home"),
)
# SMS (Telnyx)
telnyx_key = os.getenv("TELNYX_API_KEY")
if telnyx_key:
# SMS (Twilio)
twilio_sid = os.getenv("TWILIO_ACCOUNT_SID")
if twilio_sid:
if Platform.SMS not in config.platforms:
config.platforms[Platform.SMS] = PlatformConfig()
config.platforms[Platform.SMS].enabled = True
config.platforms[Platform.SMS].api_key = telnyx_key
config.platforms[Platform.SMS].api_key = os.getenv("TWILIO_AUTH_TOKEN", "")
sms_home = os.getenv("SMS_HOME_CHANNEL")
if sms_home:
config.platforms[Platform.SMS].home_channel = HomeChannel(

View file

@ -8,8 +8,9 @@ Hooks are discovered from ~/.hermes/hooks/ directories, each containing:
Events:
- gateway:startup -- Gateway process starts
- session:start -- New session created
- session:reset -- User ran /new or /reset
- session:start -- New session created (first message of a new session)
- session:end -- Session ends (user ran /new or /reset)
- session:reset -- Session reset completed (new session entry created)
- agent:start -- Agent begins processing a message
- agent:step -- Each turn in the tool-calling loop
- agent:end -- Agent finishes processing

View file

@ -0,0 +1,340 @@
"""
DingTalk platform adapter using Stream Mode.
Uses dingtalk-stream SDK for real-time message reception without webhooks.
Responses are sent via DingTalk's session webhook (markdown format).
Requires:
pip install dingtalk-stream httpx
DINGTALK_CLIENT_ID and DINGTALK_CLIENT_SECRET env vars
Configuration in config.yaml:
platforms:
dingtalk:
enabled: true
extra:
client_id: "your-app-key" # or DINGTALK_CLIENT_ID env var
client_secret: "your-secret" # or DINGTALK_CLIENT_SECRET env var
"""
import asyncio
import logging
import os
import time
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, Optional
try:
import dingtalk_stream
from dingtalk_stream import ChatbotHandler, ChatbotMessage
DINGTALK_STREAM_AVAILABLE = True
except ImportError:
DINGTALK_STREAM_AVAILABLE = False
dingtalk_stream = None # type: ignore[assignment]
try:
import httpx
HTTPX_AVAILABLE = True
except ImportError:
HTTPX_AVAILABLE = False
httpx = None # type: ignore[assignment]
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
MessageType,
SendResult,
)
logger = logging.getLogger(__name__)
MAX_MESSAGE_LENGTH = 20000
DEDUP_WINDOW_SECONDS = 300
DEDUP_MAX_SIZE = 1000
RECONNECT_BACKOFF = [2, 5, 10, 30, 60]
def check_dingtalk_requirements() -> bool:
"""Check if DingTalk dependencies are available and configured."""
if not DINGTALK_STREAM_AVAILABLE or not HTTPX_AVAILABLE:
return False
if not os.getenv("DINGTALK_CLIENT_ID") or not os.getenv("DINGTALK_CLIENT_SECRET"):
return False
return True
class DingTalkAdapter(BasePlatformAdapter):
"""DingTalk chatbot adapter using Stream Mode.
The dingtalk-stream SDK maintains a long-lived WebSocket connection.
Incoming messages arrive via a ChatbotHandler callback. Replies are
sent via the incoming message's session_webhook URL using httpx.
"""
MAX_MESSAGE_LENGTH = MAX_MESSAGE_LENGTH
def __init__(self, config: PlatformConfig):
super().__init__(config, Platform.DINGTALK)
extra = config.extra or {}
self._client_id: str = extra.get("client_id") or os.getenv("DINGTALK_CLIENT_ID", "")
self._client_secret: str = extra.get("client_secret") or os.getenv("DINGTALK_CLIENT_SECRET", "")
self._stream_client: Any = None
self._stream_task: Optional[asyncio.Task] = None
self._http_client: Optional["httpx.AsyncClient"] = None
# Message deduplication: msg_id -> timestamp
self._seen_messages: Dict[str, float] = {}
# Map chat_id -> session_webhook for reply routing
self._session_webhooks: Dict[str, str] = {}
# -- Connection lifecycle -----------------------------------------------
async def connect(self) -> bool:
"""Connect to DingTalk via Stream Mode."""
if not DINGTALK_STREAM_AVAILABLE:
logger.warning("[%s] dingtalk-stream not installed. Run: pip install dingtalk-stream", self.name)
return False
if not HTTPX_AVAILABLE:
logger.warning("[%s] httpx not installed. Run: pip install httpx", self.name)
return False
if not self._client_id or not self._client_secret:
logger.warning("[%s] DINGTALK_CLIENT_ID and DINGTALK_CLIENT_SECRET required", self.name)
return False
try:
self._http_client = httpx.AsyncClient(timeout=30.0)
credential = dingtalk_stream.Credential(self._client_id, self._client_secret)
self._stream_client = dingtalk_stream.DingTalkStreamClient(credential)
# Capture the current event loop for cross-thread dispatch
loop = asyncio.get_running_loop()
handler = _IncomingHandler(self, loop)
self._stream_client.register_callback_handler(
dingtalk_stream.ChatbotMessage.TOPIC, handler
)
self._stream_task = asyncio.create_task(self._run_stream())
self._mark_connected()
logger.info("[%s] Connected via Stream Mode", self.name)
return True
except Exception as e:
logger.error("[%s] Failed to connect: %s", self.name, e)
return False
async def _run_stream(self) -> None:
"""Run the blocking stream client with auto-reconnection."""
backoff_idx = 0
while self._running:
try:
logger.debug("[%s] Starting stream client...", self.name)
await asyncio.to_thread(self._stream_client.start)
except asyncio.CancelledError:
return
except Exception as e:
if not self._running:
return
logger.warning("[%s] Stream client error: %s", self.name, e)
if not self._running:
return
delay = RECONNECT_BACKOFF[min(backoff_idx, len(RECONNECT_BACKOFF) - 1)]
logger.info("[%s] Reconnecting in %ds...", self.name, delay)
await asyncio.sleep(delay)
backoff_idx += 1
async def disconnect(self) -> None:
"""Disconnect from DingTalk."""
self._running = False
self._mark_disconnected()
if self._stream_task:
self._stream_task.cancel()
try:
await self._stream_task
except asyncio.CancelledError:
pass
self._stream_task = None
if self._http_client:
await self._http_client.aclose()
self._http_client = None
self._stream_client = None
self._session_webhooks.clear()
self._seen_messages.clear()
logger.info("[%s] Disconnected", self.name)
# -- Inbound message processing -----------------------------------------
async def _on_message(self, message: "ChatbotMessage") -> None:
"""Process an incoming DingTalk chatbot message."""
msg_id = getattr(message, "message_id", None) or uuid.uuid4().hex
if self._is_duplicate(msg_id):
logger.debug("[%s] Duplicate message %s, skipping", self.name, msg_id)
return
text = self._extract_text(message)
if not text:
logger.debug("[%s] Empty message, skipping", self.name)
return
# Chat context
conversation_id = getattr(message, "conversation_id", "") or ""
conversation_type = getattr(message, "conversation_type", "1")
is_group = str(conversation_type) == "2"
sender_id = getattr(message, "sender_id", "") or ""
sender_nick = getattr(message, "sender_nick", "") or sender_id
sender_staff_id = getattr(message, "sender_staff_id", "") or ""
chat_id = conversation_id or sender_id
chat_type = "group" if is_group else "dm"
# Store session webhook for reply routing
session_webhook = getattr(message, "session_webhook", None) or ""
if session_webhook and chat_id:
self._session_webhooks[chat_id] = session_webhook
source = self.build_source(
chat_id=chat_id,
chat_name=getattr(message, "conversation_title", None),
chat_type=chat_type,
user_id=sender_id,
user_name=sender_nick,
user_id_alt=sender_staff_id if sender_staff_id else None,
)
# Parse timestamp
create_at = getattr(message, "create_at", None)
try:
timestamp = datetime.fromtimestamp(int(create_at) / 1000, tz=timezone.utc) if create_at else datetime.now(tz=timezone.utc)
except (ValueError, OSError, TypeError):
timestamp = datetime.now(tz=timezone.utc)
event = MessageEvent(
text=text,
message_type=MessageType.TEXT,
source=source,
message_id=msg_id,
raw_message=message,
timestamp=timestamp,
)
logger.debug("[%s] Message from %s in %s: %s",
self.name, sender_nick, chat_id[:20] if chat_id else "?", text[:50])
await self.handle_message(event)
@staticmethod
def _extract_text(message: "ChatbotMessage") -> str:
"""Extract plain text from a DingTalk chatbot message."""
text = getattr(message, "text", None) or ""
if isinstance(text, dict):
content = text.get("content", "").strip()
else:
content = str(text).strip()
# Fall back to rich text if present
if not content:
rich_text = getattr(message, "rich_text", None)
if rich_text and isinstance(rich_text, list):
parts = [item["text"] for item in rich_text
if isinstance(item, dict) and item.get("text")]
content = " ".join(parts).strip()
return content
# -- Deduplication ------------------------------------------------------
def _is_duplicate(self, msg_id: str) -> bool:
"""Check and record a message ID. Returns True if already seen."""
now = time.time()
if len(self._seen_messages) > DEDUP_MAX_SIZE:
cutoff = now - DEDUP_WINDOW_SECONDS
self._seen_messages = {k: v for k, v in self._seen_messages.items() if v > cutoff}
if msg_id in self._seen_messages:
return True
self._seen_messages[msg_id] = now
return False
# -- Outbound messaging -------------------------------------------------
async def send(
self,
chat_id: str,
content: str,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Send a markdown reply via DingTalk session webhook."""
metadata = metadata or {}
session_webhook = metadata.get("session_webhook") or self._session_webhooks.get(chat_id)
if not session_webhook:
return SendResult(success=False,
error="No session_webhook available. Reply must follow an incoming message.")
if not self._http_client:
return SendResult(success=False, error="HTTP client not initialized")
payload = {
"msgtype": "markdown",
"markdown": {"title": "Hermes", "text": content[:self.MAX_MESSAGE_LENGTH]},
}
try:
resp = await self._http_client.post(session_webhook, json=payload, timeout=15.0)
if resp.status_code < 300:
return SendResult(success=True, message_id=uuid.uuid4().hex[:12])
body = resp.text
logger.warning("[%s] Send failed HTTP %d: %s", self.name, resp.status_code, body[:200])
return SendResult(success=False, error=f"HTTP {resp.status_code}: {body[:200]}")
except httpx.TimeoutException:
return SendResult(success=False, error="Timeout sending message to DingTalk")
except Exception as e:
logger.error("[%s] Send error: %s", self.name, e)
return SendResult(success=False, error=str(e))
async def send_typing(self, chat_id: str, metadata=None) -> None:
"""DingTalk does not support typing indicators."""
pass
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
"""Return basic info about a DingTalk conversation."""
return {"name": chat_id, "type": "group" if "group" in chat_id.lower() else "dm"}
# ---------------------------------------------------------------------------
# Internal stream handler
# ---------------------------------------------------------------------------
class _IncomingHandler(ChatbotHandler if DINGTALK_STREAM_AVAILABLE else object):
"""dingtalk-stream ChatbotHandler that forwards messages to the adapter."""
def __init__(self, adapter: DingTalkAdapter, loop: asyncio.AbstractEventLoop):
if DINGTALK_STREAM_AVAILABLE:
super().__init__()
self._adapter = adapter
self._loop = loop
def process(self, message: "ChatbotMessage"):
"""Called by dingtalk-stream in its thread when a message arrives.
Schedules the async handler on the main event loop.
"""
loop = self._loop
if loop is None or loop.is_closed():
logger.error("[DingTalk] Event loop unavailable, cannot dispatch message")
return dingtalk_stream.AckMessage.STATUS_OK, "OK"
future = asyncio.run_coroutine_threadsafe(self._adapter._on_message(message), loop)
try:
future.result(timeout=60)
except Exception:
logger.exception("[DingTalk] Error processing incoming message")
return dingtalk_stream.AckMessage.STATUS_OK, "OK"

842
gateway/platforms/matrix.py Normal file
View file

@ -0,0 +1,842 @@
"""Matrix gateway adapter.
Connects to any Matrix homeserver (self-hosted or matrix.org) via the
matrix-nio Python SDK. Supports optional end-to-end encryption (E2EE)
when installed with ``pip install "matrix-nio[e2e]"``.
Environment variables:
MATRIX_HOMESERVER Homeserver URL (e.g. https://matrix.example.org)
MATRIX_ACCESS_TOKEN Access token (preferred auth method)
MATRIX_USER_ID Full user ID (@bot:server) required for password login
MATRIX_PASSWORD Password (alternative to access token)
MATRIX_ENCRYPTION Set "true" to enable E2EE
MATRIX_ALLOWED_USERS Comma-separated Matrix user IDs (@user:server)
MATRIX_HOME_ROOM Room ID for cron/notification delivery
"""
from __future__ import annotations
import asyncio
import json
import logging
import mimetypes
import os
import re
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
MessageType,
SendResult,
)
logger = logging.getLogger(__name__)
# Matrix message size limit (4000 chars practical, spec has no hard limit
# but clients render poorly above this).
MAX_MESSAGE_LENGTH = 4000
# Store directory for E2EE keys and sync state.
_STORE_DIR = Path.home() / ".hermes" / "matrix" / "store"
# Grace period: ignore messages older than this many seconds before startup.
_STARTUP_GRACE_SECONDS = 5
def check_matrix_requirements() -> bool:
"""Return True if the Matrix adapter can be used."""
token = os.getenv("MATRIX_ACCESS_TOKEN", "")
password = os.getenv("MATRIX_PASSWORD", "")
homeserver = os.getenv("MATRIX_HOMESERVER", "")
if not token and not password:
logger.debug("Matrix: neither MATRIX_ACCESS_TOKEN nor MATRIX_PASSWORD set")
return False
if not homeserver:
logger.warning("Matrix: MATRIX_HOMESERVER not set")
return False
try:
import nio # noqa: F401
return True
except ImportError:
logger.warning(
"Matrix: matrix-nio not installed. "
"Run: pip install 'matrix-nio[e2e]'"
)
return False
class MatrixAdapter(BasePlatformAdapter):
"""Gateway adapter for Matrix (any homeserver)."""
def __init__(self, config: PlatformConfig):
super().__init__(config, Platform.MATRIX)
self._homeserver: str = (
config.extra.get("homeserver", "")
or os.getenv("MATRIX_HOMESERVER", "")
).rstrip("/")
self._access_token: str = config.token or os.getenv("MATRIX_ACCESS_TOKEN", "")
self._user_id: str = (
config.extra.get("user_id", "")
or os.getenv("MATRIX_USER_ID", "")
)
self._password: str = (
config.extra.get("password", "")
or os.getenv("MATRIX_PASSWORD", "")
)
self._encryption: bool = config.extra.get(
"encryption",
os.getenv("MATRIX_ENCRYPTION", "").lower() in ("true", "1", "yes"),
)
self._client: Any = None # nio.AsyncClient
self._sync_task: Optional[asyncio.Task] = None
self._closing = False
self._startup_ts: float = 0.0
# Cache: room_id → bool (is DM)
self._dm_rooms: Dict[str, bool] = {}
# Set of room IDs we've joined
self._joined_rooms: Set[str] = set()
# ------------------------------------------------------------------
# Required overrides
# ------------------------------------------------------------------
async def connect(self) -> bool:
"""Connect to the Matrix homeserver and start syncing."""
import nio
if not self._homeserver:
logger.error("Matrix: homeserver URL not configured")
return False
# Determine store path and ensure it exists.
store_path = str(_STORE_DIR)
_STORE_DIR.mkdir(parents=True, exist_ok=True)
# Create the client.
if self._encryption:
try:
client = nio.AsyncClient(
self._homeserver,
self._user_id or "",
store_path=store_path,
)
logger.info("Matrix: E2EE enabled (store: %s)", store_path)
except Exception as exc:
logger.warning(
"Matrix: failed to create E2EE client (%s), "
"falling back to plain client. Install: "
"pip install 'matrix-nio[e2e]'",
exc,
)
client = nio.AsyncClient(self._homeserver, self._user_id or "")
else:
client = nio.AsyncClient(self._homeserver, self._user_id or "")
self._client = client
# Authenticate.
if self._access_token:
client.access_token = self._access_token
# Resolve user_id if not set.
if not self._user_id:
resp = await client.whoami()
if isinstance(resp, nio.WhoamiResponse):
self._user_id = resp.user_id
client.user_id = resp.user_id
logger.info("Matrix: authenticated as %s", self._user_id)
else:
logger.error(
"Matrix: whoami failed — check MATRIX_ACCESS_TOKEN and MATRIX_HOMESERVER"
)
await client.close()
return False
else:
client.user_id = self._user_id
logger.info("Matrix: using access token for %s", self._user_id)
elif self._password and self._user_id:
resp = await client.login(
self._password,
device_name="Hermes Agent",
)
if isinstance(resp, nio.LoginResponse):
logger.info("Matrix: logged in as %s", self._user_id)
else:
logger.error("Matrix: login failed — %s", getattr(resp, "message", resp))
await client.close()
return False
else:
logger.error("Matrix: need MATRIX_ACCESS_TOKEN or MATRIX_USER_ID + MATRIX_PASSWORD")
await client.close()
return False
# If E2EE is enabled, load the crypto store.
if self._encryption and hasattr(client, "olm"):
try:
if client.should_upload_keys:
await client.keys_upload()
logger.info("Matrix: E2EE crypto initialized")
except Exception as exc:
logger.warning("Matrix: crypto init issue: %s", exc)
# Register event callbacks.
client.add_event_callback(self._on_room_message, nio.RoomMessageText)
client.add_event_callback(self._on_room_message_media, nio.RoomMessageMedia)
client.add_event_callback(self._on_room_message_media, nio.RoomMessageImage)
client.add_event_callback(self._on_room_message_media, nio.RoomMessageAudio)
client.add_event_callback(self._on_room_message_media, nio.RoomMessageVideo)
client.add_event_callback(self._on_room_message_media, nio.RoomMessageFile)
client.add_event_callback(self._on_invite, nio.InviteMemberEvent)
# If E2EE: handle encrypted events.
if self._encryption and hasattr(client, "olm"):
client.add_event_callback(
self._on_room_message, nio.MegolmEvent
)
# Initial sync to catch up, then start background sync.
self._startup_ts = time.time()
self._closing = False
# Do an initial sync to populate room state.
resp = await client.sync(timeout=10000, full_state=True)
if isinstance(resp, nio.SyncResponse):
self._joined_rooms = set(resp.rooms.join.keys())
logger.info(
"Matrix: initial sync complete, joined %d rooms",
len(self._joined_rooms),
)
# Build DM room cache from m.direct account data.
await self._refresh_dm_cache()
else:
logger.warning("Matrix: initial sync returned %s", type(resp).__name__)
# Start the sync loop.
self._sync_task = asyncio.create_task(self._sync_loop())
self._mark_connected()
return True
async def disconnect(self) -> None:
"""Disconnect from Matrix."""
self._closing = True
if self._sync_task and not self._sync_task.done():
self._sync_task.cancel()
try:
await self._sync_task
except (asyncio.CancelledError, Exception):
pass
if self._client:
await self._client.close()
self._client = None
logger.info("Matrix: disconnected")
async def send(
self,
chat_id: str,
content: str,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Send a message to a Matrix room."""
import nio
if not content:
return SendResult(success=True)
formatted = self.format_message(content)
chunks = self.truncate_message(formatted, MAX_MESSAGE_LENGTH)
last_event_id = None
for chunk in chunks:
msg_content: Dict[str, Any] = {
"msgtype": "m.text",
"body": chunk,
}
# Convert markdown to HTML for rich rendering.
html = self._markdown_to_html(chunk)
if html and html != chunk:
msg_content["format"] = "org.matrix.custom.html"
msg_content["formatted_body"] = html
# Reply-to support.
if reply_to:
msg_content["m.relates_to"] = {
"m.in_reply_to": {"event_id": reply_to}
}
# Thread support: if metadata has thread_id, send as threaded reply.
thread_id = (metadata or {}).get("thread_id")
if thread_id:
relates_to = msg_content.get("m.relates_to", {})
relates_to["rel_type"] = "m.thread"
relates_to["event_id"] = thread_id
relates_to["is_falling_back"] = True
if reply_to and "m.in_reply_to" not in relates_to:
relates_to["m.in_reply_to"] = {"event_id": reply_to}
msg_content["m.relates_to"] = relates_to
resp = await self._client.room_send(
chat_id,
"m.room.message",
msg_content,
)
if isinstance(resp, nio.RoomSendResponse):
last_event_id = resp.event_id
else:
err = getattr(resp, "message", str(resp))
logger.error("Matrix: failed to send to %s: %s", chat_id, err)
return SendResult(success=False, error=err)
return SendResult(success=True, message_id=last_event_id)
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
"""Return room name and type (dm/group)."""
name = chat_id
chat_type = "group"
if self._client:
room = self._client.rooms.get(chat_id)
if room:
name = room.display_name or room.canonical_alias or chat_id
# Use DM cache.
if self._dm_rooms.get(chat_id, False):
chat_type = "dm"
elif room.member_count == 2:
chat_type = "dm"
return {"name": name, "type": chat_type}
# ------------------------------------------------------------------
# Optional overrides
# ------------------------------------------------------------------
async def send_typing(
self, chat_id: str, metadata: Optional[Dict[str, Any]] = None
) -> None:
"""Send a typing indicator."""
if self._client:
try:
await self._client.room_typing(chat_id, typing_state=True, timeout=30000)
except Exception:
pass
async def edit_message(
self, chat_id: str, message_id: str, content: str
) -> SendResult:
"""Edit an existing message (via m.replace)."""
import nio
formatted = self.format_message(content)
msg_content: Dict[str, Any] = {
"msgtype": "m.text",
"body": f"* {formatted}",
"m.new_content": {
"msgtype": "m.text",
"body": formatted,
},
"m.relates_to": {
"rel_type": "m.replace",
"event_id": message_id,
},
}
html = self._markdown_to_html(formatted)
if html and html != formatted:
msg_content["m.new_content"]["format"] = "org.matrix.custom.html"
msg_content["m.new_content"]["formatted_body"] = html
msg_content["format"] = "org.matrix.custom.html"
msg_content["formatted_body"] = f"* {html}"
resp = await self._client.room_send(chat_id, "m.room.message", msg_content)
if isinstance(resp, nio.RoomSendResponse):
return SendResult(success=True, message_id=resp.event_id)
return SendResult(success=False, error=getattr(resp, "message", str(resp)))
async def send_image(
self,
chat_id: str,
image_url: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Download an image URL and upload it to Matrix."""
try:
# Try aiohttp first (always available), fall back to httpx
try:
import aiohttp as _aiohttp
async with _aiohttp.ClientSession() as http:
async with http.get(image_url, timeout=_aiohttp.ClientTimeout(total=30)) as resp:
resp.raise_for_status()
data = await resp.read()
ct = resp.content_type or "image/png"
fname = image_url.rsplit("/", 1)[-1].split("?")[0] or "image.png"
except ImportError:
import httpx
async with httpx.AsyncClient() as http:
resp = await http.get(image_url, follow_redirects=True, timeout=30)
resp.raise_for_status()
data = resp.content
ct = resp.headers.get("content-type", "image/png")
fname = image_url.rsplit("/", 1)[-1].split("?")[0] or "image.png"
except Exception as exc:
logger.warning("Matrix: failed to download image %s: %s", image_url, exc)
return await self.send(chat_id, f"{caption or ''}\n{image_url}".strip(), reply_to)
return await self._upload_and_send(chat_id, data, fname, ct, "m.image", caption, reply_to, metadata)
async def send_image_file(
self,
chat_id: str,
image_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Upload a local image file to Matrix."""
return await self._send_local_file(chat_id, image_path, "m.image", caption, reply_to, metadata=metadata)
async def send_document(
self,
chat_id: str,
file_path: str,
caption: Optional[str] = None,
file_name: Optional[str] = None,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Upload a local file as a document."""
return await self._send_local_file(chat_id, file_path, "m.file", caption, reply_to, file_name, metadata)
async def send_voice(
self,
chat_id: str,
audio_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Upload an audio file as a voice message."""
return await self._send_local_file(chat_id, audio_path, "m.audio", caption, reply_to, metadata=metadata)
async def send_video(
self,
chat_id: str,
video_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Upload a video file."""
return await self._send_local_file(chat_id, video_path, "m.video", caption, reply_to, metadata=metadata)
def format_message(self, content: str) -> str:
"""Pass-through — Matrix supports standard Markdown natively."""
# Strip image markdown; media is uploaded separately.
content = re.sub(r"!\[([^\]]*)\]\(([^)]+)\)", r"\2", content)
return content
# ------------------------------------------------------------------
# File helpers
# ------------------------------------------------------------------
async def _upload_and_send(
self,
room_id: str,
data: bytes,
filename: str,
content_type: str,
msgtype: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Upload bytes to Matrix and send as a media message."""
import nio
# Upload to homeserver.
resp = await self._client.upload(
data,
content_type=content_type,
filename=filename,
)
if not isinstance(resp, nio.UploadResponse):
err = getattr(resp, "message", str(resp))
logger.error("Matrix: upload failed: %s", err)
return SendResult(success=False, error=err)
mxc_url = resp.content_uri
# Build media message content.
msg_content: Dict[str, Any] = {
"msgtype": msgtype,
"body": caption or filename,
"url": mxc_url,
"info": {
"mimetype": content_type,
"size": len(data),
},
}
if reply_to:
msg_content["m.relates_to"] = {
"m.in_reply_to": {"event_id": reply_to}
}
thread_id = (metadata or {}).get("thread_id")
if thread_id:
relates_to = msg_content.get("m.relates_to", {})
relates_to["rel_type"] = "m.thread"
relates_to["event_id"] = thread_id
relates_to["is_falling_back"] = True
msg_content["m.relates_to"] = relates_to
resp2 = await self._client.room_send(room_id, "m.room.message", msg_content)
if isinstance(resp2, nio.RoomSendResponse):
return SendResult(success=True, message_id=resp2.event_id)
return SendResult(success=False, error=getattr(resp2, "message", str(resp2)))
async def _send_local_file(
self,
room_id: str,
file_path: str,
msgtype: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
file_name: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Read a local file and upload it."""
p = Path(file_path)
if not p.exists():
return await self.send(
room_id, f"{caption or ''}\n(file not found: {file_path})", reply_to
)
fname = file_name or p.name
ct = mimetypes.guess_type(fname)[0] or "application/octet-stream"
data = p.read_bytes()
return await self._upload_and_send(room_id, data, fname, ct, msgtype, caption, reply_to, metadata)
# ------------------------------------------------------------------
# Sync loop
# ------------------------------------------------------------------
async def _sync_loop(self) -> None:
"""Continuously sync with the homeserver."""
while not self._closing:
try:
await self._client.sync(timeout=30000)
except asyncio.CancelledError:
return
except Exception as exc:
if self._closing:
return
logger.warning("Matrix: sync error: %s — retrying in 5s", exc)
await asyncio.sleep(5)
# ------------------------------------------------------------------
# Event callbacks
# ------------------------------------------------------------------
async def _on_room_message(self, room: Any, event: Any) -> None:
"""Handle incoming text messages (and decrypted megolm events)."""
import nio
# Ignore own messages.
if event.sender == self._user_id:
return
# Startup grace: ignore old messages from initial sync.
event_ts = getattr(event, "server_timestamp", 0) / 1000.0
if event_ts and event_ts < self._startup_ts - _STARTUP_GRACE_SECONDS:
return
# Handle decrypted MegolmEvents — extract the inner event.
if isinstance(event, nio.MegolmEvent):
# Failed to decrypt.
logger.warning(
"Matrix: could not decrypt event %s in %s",
event.event_id, room.room_id,
)
return
# Skip edits (m.replace relation).
source_content = getattr(event, "source", {}).get("content", {})
relates_to = source_content.get("m.relates_to", {})
if relates_to.get("rel_type") == "m.replace":
return
body = getattr(event, "body", "") or ""
if not body:
return
# Determine chat type.
is_dm = self._dm_rooms.get(room.room_id, False)
if not is_dm and room.member_count == 2:
is_dm = True
chat_type = "dm" if is_dm else "group"
# Thread support.
thread_id = None
if relates_to.get("rel_type") == "m.thread":
thread_id = relates_to.get("event_id")
# Reply-to detection.
reply_to = None
in_reply_to = relates_to.get("m.in_reply_to", {})
if in_reply_to:
reply_to = in_reply_to.get("event_id")
# Strip reply fallback from body (Matrix prepends "> ..." lines).
if reply_to and body.startswith("> "):
lines = body.split("\n")
stripped = []
past_fallback = False
for line in lines:
if not past_fallback:
if line.startswith("> ") or line == ">":
continue
if line == "":
past_fallback = True
continue
past_fallback = True
stripped.append(line)
body = "\n".join(stripped) if stripped else body
# Message type.
msg_type = MessageType.TEXT
if body.startswith("!") or body.startswith("/"):
msg_type = MessageType.COMMAND
source = self.build_source(
chat_id=room.room_id,
chat_type=chat_type,
user_id=event.sender,
user_name=self._get_display_name(room, event.sender),
thread_id=thread_id,
)
msg_event = MessageEvent(
text=body,
message_type=msg_type,
source=source,
raw_message=getattr(event, "source", {}),
message_id=event.event_id,
reply_to=reply_to,
)
await self.handle_message(msg_event)
async def _on_room_message_media(self, room: Any, event: Any) -> None:
"""Handle incoming media messages (images, audio, video, files)."""
import nio
# Ignore own messages.
if event.sender == self._user_id:
return
# Startup grace.
event_ts = getattr(event, "server_timestamp", 0) / 1000.0
if event_ts and event_ts < self._startup_ts - _STARTUP_GRACE_SECONDS:
return
body = getattr(event, "body", "") or ""
url = getattr(event, "url", "")
# Convert mxc:// to HTTP URL for downstream processing.
http_url = ""
if url and url.startswith("mxc://"):
http_url = self._mxc_to_http(url)
# Determine message type from event class.
media_type = "document"
msg_type = MessageType.DOCUMENT
if isinstance(event, nio.RoomMessageImage):
msg_type = MessageType.PHOTO
media_type = "image"
elif isinstance(event, nio.RoomMessageAudio):
msg_type = MessageType.AUDIO
media_type = "audio"
elif isinstance(event, nio.RoomMessageVideo):
msg_type = MessageType.VIDEO
media_type = "video"
is_dm = self._dm_rooms.get(room.room_id, False)
if not is_dm and room.member_count == 2:
is_dm = True
chat_type = "dm" if is_dm else "group"
# Thread/reply detection.
source_content = getattr(event, "source", {}).get("content", {})
relates_to = source_content.get("m.relates_to", {})
thread_id = None
if relates_to.get("rel_type") == "m.thread":
thread_id = relates_to.get("event_id")
source = self.build_source(
chat_id=room.room_id,
chat_type=chat_type,
user_id=event.sender,
user_name=self._get_display_name(room, event.sender),
thread_id=thread_id,
)
msg_event = MessageEvent(
text=body,
message_type=msg_type,
source=source,
raw_message=getattr(event, "source", {}),
message_id=event.event_id,
media_urls=[http_url] if http_url else None,
media_types=[media_type] if http_url else None,
)
await self.handle_message(msg_event)
async def _on_invite(self, room: Any, event: Any) -> None:
"""Auto-join rooms when invited."""
import nio
if not isinstance(event, nio.InviteMemberEvent):
return
# Only process invites directed at us.
if event.state_key != self._user_id:
return
if event.membership != "invite":
return
logger.info(
"Matrix: invited to %s by %s — joining",
room.room_id, event.sender,
)
try:
resp = await self._client.join(room.room_id)
if isinstance(resp, nio.JoinResponse):
self._joined_rooms.add(room.room_id)
logger.info("Matrix: joined %s", room.room_id)
# Refresh DM cache since new room may be a DM.
await self._refresh_dm_cache()
else:
logger.warning(
"Matrix: failed to join %s: %s",
room.room_id, getattr(resp, "message", resp),
)
except Exception as exc:
logger.warning("Matrix: error joining %s: %s", room.room_id, exc)
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
async def _refresh_dm_cache(self) -> None:
"""Refresh the DM room cache from m.direct account data.
Tries the account_data API first, then falls back to parsing
the sync response's account_data for robustness.
"""
if not self._client:
return
dm_data: Optional[Dict] = None
# Primary: try the dedicated account data endpoint.
try:
resp = await self._client.get_account_data("m.direct")
if hasattr(resp, "content"):
dm_data = resp.content
elif isinstance(resp, dict):
dm_data = resp
except Exception as exc:
logger.debug("Matrix: get_account_data('m.direct') failed: %s — trying sync fallback", exc)
# Fallback: parse from the client's account_data store (populated by sync).
if dm_data is None:
try:
# matrix-nio stores account data events on the client object
ad = getattr(self._client, "account_data", None)
if ad and isinstance(ad, dict) and "m.direct" in ad:
event = ad["m.direct"]
if hasattr(event, "content"):
dm_data = event.content
elif isinstance(event, dict):
dm_data = event
except Exception:
pass
if dm_data is None:
return
dm_room_ids: Set[str] = set()
for user_id, rooms in dm_data.items():
if isinstance(rooms, list):
dm_room_ids.update(rooms)
self._dm_rooms = {
rid: (rid in dm_room_ids)
for rid in self._joined_rooms
}
def _get_display_name(self, room: Any, user_id: str) -> str:
"""Get a user's display name in a room, falling back to user_id."""
if room and hasattr(room, "users"):
user = room.users.get(user_id)
if user and getattr(user, "display_name", None):
return user.display_name
# Strip the @...:server format to just the localpart.
if user_id.startswith("@") and ":" in user_id:
return user_id[1:].split(":")[0]
return user_id
def _mxc_to_http(self, mxc_url: str) -> str:
"""Convert mxc://server/media_id to an HTTP download URL."""
# mxc://matrix.org/abc123 → https://matrix.org/_matrix/client/v1/media/download/matrix.org/abc123
# Uses the authenticated client endpoint (spec v1.11+) instead of the
# deprecated /_matrix/media/v3/download/ path.
if not mxc_url.startswith("mxc://"):
return mxc_url
parts = mxc_url[6:] # strip mxc://
# Use our homeserver for download (federation handles the rest).
return f"{self._homeserver}/_matrix/client/v1/media/download/{parts}"
def _markdown_to_html(self, text: str) -> str:
"""Convert Markdown to Matrix-compatible HTML.
Uses a simple conversion for common patterns. For full fidelity
a markdown-it style library could be used, but this covers the
common cases without an extra dependency.
"""
try:
import markdown
html = markdown.markdown(
text,
extensions=["fenced_code", "tables", "nl2br"],
)
# Strip wrapping <p> tags for single-paragraph messages.
if html.count("<p>") == 1:
html = html.replace("<p>", "").replace("</p>", "")
return html
except ImportError:
pass
# Minimal fallback: just handle bold, italic, code.
html = text
html = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", html)
html = re.sub(r"\*(.+?)\*", r"<em>\1</em>", html)
html = re.sub(r"`([^`]+)`", r"<code>\1</code>", html)
html = re.sub(r"\n", r"<br>", html)
return html

View file

@ -0,0 +1,664 @@
"""Mattermost gateway adapter.
Connects to a self-hosted (or cloud) Mattermost instance via its REST API
(v4) and WebSocket for real-time events. No external Mattermost library
required uses aiohttp which is already a Hermes dependency.
Environment variables:
MATTERMOST_URL Server URL (e.g. https://mm.example.com)
MATTERMOST_TOKEN Bot token or personal-access token
MATTERMOST_ALLOWED_USERS Comma-separated user IDs
MATTERMOST_HOME_CHANNEL Channel ID for cron/notification delivery
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import re
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
MessageType,
SendResult,
)
logger = logging.getLogger(__name__)
# Mattermost post size limit (server default is 16383, but 4000 is the
# practical limit for readable messages — matching OpenClaw's choice).
MAX_POST_LENGTH = 4000
# Channel type codes returned by the Mattermost API.
_CHANNEL_TYPE_MAP = {
"D": "dm",
"G": "group",
"P": "group", # private channel → treat as group
"O": "channel",
}
# Reconnect parameters (exponential backoff).
_RECONNECT_BASE_DELAY = 2.0
_RECONNECT_MAX_DELAY = 60.0
_RECONNECT_JITTER = 0.2
def check_mattermost_requirements() -> bool:
"""Return True if the Mattermost adapter can be used."""
token = os.getenv("MATTERMOST_TOKEN", "")
url = os.getenv("MATTERMOST_URL", "")
if not token:
logger.debug("Mattermost: MATTERMOST_TOKEN not set")
return False
if not url:
logger.warning("Mattermost: MATTERMOST_URL not set")
return False
try:
import aiohttp # noqa: F401
return True
except ImportError:
logger.warning("Mattermost: aiohttp not installed")
return False
class MattermostAdapter(BasePlatformAdapter):
"""Gateway adapter for Mattermost (self-hosted or cloud)."""
def __init__(self, config: PlatformConfig):
super().__init__(config, Platform.MATTERMOST)
self._base_url: str = (
config.extra.get("url", "")
or os.getenv("MATTERMOST_URL", "")
).rstrip("/")
self._token: str = config.token or os.getenv("MATTERMOST_TOKEN", "")
self._bot_user_id: str = ""
self._bot_username: str = ""
# aiohttp session + websocket handle
self._session: Any = None # aiohttp.ClientSession
self._ws: Any = None # aiohttp.ClientWebSocketResponse
self._ws_task: Optional[asyncio.Task] = None
self._reconnect_task: Optional[asyncio.Task] = None
self._closing = False
# Reply mode: "thread" to nest replies, "off" for flat messages.
self._reply_mode: str = (
config.extra.get("reply_mode", "")
or os.getenv("MATTERMOST_REPLY_MODE", "off")
).lower()
# Dedup cache: post_id → timestamp (prevent reprocessing)
self._seen_posts: Dict[str, float] = {}
self._SEEN_MAX = 2000
self._SEEN_TTL = 300 # 5 minutes
# ------------------------------------------------------------------
# HTTP helpers
# ------------------------------------------------------------------
def _headers(self) -> Dict[str, str]:
return {
"Authorization": f"Bearer {self._token}",
"Content-Type": "application/json",
}
async def _api_get(self, path: str) -> Dict[str, Any]:
"""GET /api/v4/{path}."""
import aiohttp
url = f"{self._base_url}/api/v4/{path.lstrip('/')}"
try:
async with self._session.get(url, headers=self._headers()) as resp:
if resp.status >= 400:
body = await resp.text()
logger.error("MM API GET %s%s: %s", path, resp.status, body[:200])
return {}
return await resp.json()
except aiohttp.ClientError as exc:
logger.error("MM API GET %s network error: %s", path, exc)
return {}
async def _api_post(
self, path: str, payload: Dict[str, Any]
) -> Dict[str, Any]:
"""POST /api/v4/{path} with JSON body."""
import aiohttp
url = f"{self._base_url}/api/v4/{path.lstrip('/')}"
try:
async with self._session.post(
url, headers=self._headers(), json=payload
) as resp:
if resp.status >= 400:
body = await resp.text()
logger.error("MM API POST %s%s: %s", path, resp.status, body[:200])
return {}
return await resp.json()
except aiohttp.ClientError as exc:
logger.error("MM API POST %s network error: %s", path, exc)
return {}
async def _api_put(
self, path: str, payload: Dict[str, Any]
) -> Dict[str, Any]:
"""PUT /api/v4/{path} with JSON body."""
import aiohttp
url = f"{self._base_url}/api/v4/{path.lstrip('/')}"
try:
async with self._session.put(
url, headers=self._headers(), json=payload
) as resp:
if resp.status >= 400:
body = await resp.text()
logger.error("MM API PUT %s%s: %s", path, resp.status, body[:200])
return {}
return await resp.json()
except aiohttp.ClientError as exc:
logger.error("MM API PUT %s network error: %s", path, exc)
return {}
async def _upload_file(
self, channel_id: str, file_data: bytes, filename: str, content_type: str = "application/octet-stream"
) -> Optional[str]:
"""Upload a file and return its file ID, or None on failure."""
import aiohttp
url = f"{self._base_url}/api/v4/files"
form = aiohttp.FormData()
form.add_field("channel_id", channel_id)
form.add_field(
"files",
file_data,
filename=filename,
content_type=content_type,
)
headers = {"Authorization": f"Bearer {self._token}"}
async with self._session.post(url, headers=headers, data=form) as resp:
if resp.status >= 400:
body = await resp.text()
logger.error("MM file upload → %s: %s", resp.status, body[:200])
return None
data = await resp.json()
infos = data.get("file_infos", [])
return infos[0]["id"] if infos else None
# ------------------------------------------------------------------
# Required overrides
# ------------------------------------------------------------------
async def connect(self) -> bool:
"""Connect to Mattermost and start the WebSocket listener."""
import aiohttp
if not self._base_url or not self._token:
logger.error("Mattermost: URL or token not configured")
return False
self._session = aiohttp.ClientSession()
self._closing = False
# Verify credentials and fetch bot identity.
me = await self._api_get("users/me")
if not me or "id" not in me:
logger.error("Mattermost: failed to authenticate — check MATTERMOST_TOKEN and MATTERMOST_URL")
await self._session.close()
return False
self._bot_user_id = me["id"]
self._bot_username = me.get("username", "")
logger.info(
"Mattermost: authenticated as @%s (%s) on %s",
self._bot_username,
self._bot_user_id,
self._base_url,
)
# Start WebSocket in background.
self._ws_task = asyncio.create_task(self._ws_loop())
self._mark_connected()
return True
async def disconnect(self) -> None:
"""Disconnect from Mattermost."""
self._closing = True
if self._ws_task and not self._ws_task.done():
self._ws_task.cancel()
try:
await self._ws_task
except (asyncio.CancelledError, Exception):
pass
if self._reconnect_task and not self._reconnect_task.done():
self._reconnect_task.cancel()
if self._ws:
await self._ws.close()
self._ws = None
if self._session and not self._session.closed:
await self._session.close()
logger.info("Mattermost: disconnected")
async def send(
self,
chat_id: str,
content: str,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Send a message (or multiple chunks) to a channel."""
if not content:
return SendResult(success=True)
formatted = self.format_message(content)
chunks = self.truncate_message(formatted, MAX_POST_LENGTH)
last_id = None
for chunk in chunks:
payload: Dict[str, Any] = {
"channel_id": chat_id,
"message": chunk,
}
# Thread support: reply_to is the root post ID.
if reply_to and self._reply_mode == "thread":
payload["root_id"] = reply_to
data = await self._api_post("posts", payload)
if not data or "id" not in data:
return SendResult(success=False, error="Failed to create post")
last_id = data["id"]
return SendResult(success=True, message_id=last_id)
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
"""Return channel name and type."""
data = await self._api_get(f"channels/{chat_id}")
if not data:
return {"name": chat_id, "type": "channel"}
ch_type = _CHANNEL_TYPE_MAP.get(data.get("type", "O"), "channel")
display_name = data.get("display_name") or data.get("name") or chat_id
return {"name": display_name, "type": ch_type}
# ------------------------------------------------------------------
# Optional overrides
# ------------------------------------------------------------------
async def send_typing(
self, chat_id: str, metadata: Optional[Dict[str, Any]] = None
) -> None:
"""Send a typing indicator."""
await self._api_post(
f"users/{self._bot_user_id}/typing",
{"channel_id": chat_id},
)
async def edit_message(
self, chat_id: str, message_id: str, content: str
) -> SendResult:
"""Edit an existing post."""
formatted = self.format_message(content)
data = await self._api_put(
f"posts/{message_id}/patch",
{"message": formatted},
)
if not data or "id" not in data:
return SendResult(success=False, error="Failed to edit post")
return SendResult(success=True, message_id=data["id"])
async def send_image(
self,
chat_id: str,
image_url: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Download an image and upload it as a file attachment."""
return await self._send_url_as_file(
chat_id, image_url, caption, reply_to, "image"
)
async def send_image_file(
self,
chat_id: str,
image_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Upload a local image file."""
return await self._send_local_file(
chat_id, image_path, caption, reply_to
)
async def send_document(
self,
chat_id: str,
file_path: str,
caption: Optional[str] = None,
file_name: Optional[str] = None,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Upload a local file as a document."""
return await self._send_local_file(
chat_id, file_path, caption, reply_to, file_name
)
async def send_voice(
self,
chat_id: str,
audio_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Upload an audio file."""
return await self._send_local_file(
chat_id, audio_path, caption, reply_to
)
async def send_video(
self,
chat_id: str,
video_path: str,
caption: Optional[str] = None,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Upload a video file."""
return await self._send_local_file(
chat_id, video_path, caption, reply_to
)
def format_message(self, content: str) -> str:
"""Mattermost uses standard Markdown — mostly pass through.
Strip image markdown into plain links (files are uploaded separately).
"""
# Convert ![alt](url) to just the URL — Mattermost renders
# image URLs as inline previews automatically.
content = re.sub(r"!\[([^\]]*)\]\(([^)]+)\)", r"\2", content)
return content
# ------------------------------------------------------------------
# File helpers
# ------------------------------------------------------------------
async def _send_url_as_file(
self,
chat_id: str,
url: str,
caption: Optional[str],
reply_to: Optional[str],
kind: str = "file",
) -> SendResult:
"""Download a URL and upload it as a file attachment."""
import aiohttp
try:
async with self._session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
if resp.status >= 400:
# Fall back to sending the URL as text.
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
file_data = await resp.read()
ct = resp.content_type or "application/octet-stream"
# Derive filename from URL.
fname = url.rsplit("/", 1)[-1].split("?")[0] or f"{kind}.png"
except Exception as exc:
logger.warning("Mattermost: failed to download %s: %s", url, exc)
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
file_id = await self._upload_file(chat_id, file_data, fname, ct)
if not file_id:
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
payload: Dict[str, Any] = {
"channel_id": chat_id,
"message": caption or "",
"file_ids": [file_id],
}
if reply_to and self._reply_mode == "thread":
payload["root_id"] = reply_to
data = await self._api_post("posts", payload)
if not data or "id" not in data:
return SendResult(success=False, error="Failed to post with file")
return SendResult(success=True, message_id=data["id"])
async def _send_local_file(
self,
chat_id: str,
file_path: str,
caption: Optional[str],
reply_to: Optional[str],
file_name: Optional[str] = None,
) -> SendResult:
"""Upload a local file and attach it to a post."""
import mimetypes
p = Path(file_path)
if not p.exists():
return await self.send(
chat_id, f"{caption or ''}\n(file not found: {file_path})", reply_to
)
fname = file_name or p.name
ct = mimetypes.guess_type(fname)[0] or "application/octet-stream"
file_data = p.read_bytes()
file_id = await self._upload_file(chat_id, file_data, fname, ct)
if not file_id:
return SendResult(success=False, error="File upload failed")
payload: Dict[str, Any] = {
"channel_id": chat_id,
"message": caption or "",
"file_ids": [file_id],
}
if reply_to and self._reply_mode == "thread":
payload["root_id"] = reply_to
data = await self._api_post("posts", payload)
if not data or "id" not in data:
return SendResult(success=False, error="Failed to post with file")
return SendResult(success=True, message_id=data["id"])
# ------------------------------------------------------------------
# WebSocket
# ------------------------------------------------------------------
async def _ws_loop(self) -> None:
"""Connect to the WebSocket and listen for events, reconnecting on failure."""
delay = _RECONNECT_BASE_DELAY
while not self._closing:
try:
await self._ws_connect_and_listen()
# Clean disconnect — reset delay.
delay = _RECONNECT_BASE_DELAY
except asyncio.CancelledError:
return
except Exception as exc:
if self._closing:
return
logger.warning("Mattermost WS error: %s — reconnecting in %.0fs", exc, delay)
if self._closing:
return
# Exponential backoff with jitter.
import random
jitter = delay * _RECONNECT_JITTER * random.random()
await asyncio.sleep(delay + jitter)
delay = min(delay * 2, _RECONNECT_MAX_DELAY)
async def _ws_connect_and_listen(self) -> None:
"""Single WebSocket session: connect, authenticate, process events."""
# Build WS URL: https:// → wss://, http:// → ws://
ws_url = re.sub(r"^http", "ws", self._base_url) + "/api/v4/websocket"
logger.info("Mattermost: connecting to %s", ws_url)
self._ws = await self._session.ws_connect(ws_url, heartbeat=30.0)
# Authenticate via the WebSocket.
auth_msg = {
"seq": 1,
"action": "authentication_challenge",
"data": {"token": self._token},
}
await self._ws.send_json(auth_msg)
logger.info("Mattermost: WebSocket connected and authenticated")
async for raw_msg in self._ws:
if self._closing:
return
if raw_msg.type in (
raw_msg.type.TEXT,
raw_msg.type.BINARY,
):
try:
event = json.loads(raw_msg.data)
except (json.JSONDecodeError, TypeError):
continue
await self._handle_ws_event(event)
elif raw_msg.type in (
raw_msg.type.ERROR,
raw_msg.type.CLOSE,
raw_msg.type.CLOSING,
raw_msg.type.CLOSED,
):
logger.info("Mattermost: WebSocket closed (%s)", raw_msg.type)
break
async def _handle_ws_event(self, event: Dict[str, Any]) -> None:
"""Process a single WebSocket event."""
event_type = event.get("event")
if event_type != "posted":
return
data = event.get("data", {})
raw_post_str = data.get("post")
if not raw_post_str:
return
try:
post = json.loads(raw_post_str)
except (json.JSONDecodeError, TypeError):
return
# Ignore own messages.
if post.get("user_id") == self._bot_user_id:
return
# Ignore system posts.
if post.get("type"):
return
post_id = post.get("id", "")
# Dedup.
self._prune_seen()
if post_id in self._seen_posts:
return
self._seen_posts[post_id] = time.time()
# Build message event.
channel_id = post.get("channel_id", "")
channel_type_raw = data.get("channel_type", "O")
chat_type = _CHANNEL_TYPE_MAP.get(channel_type_raw, "channel")
# For DMs, user_id is sufficient. For channels, check for @mention.
message_text = post.get("message", "")
# Resolve sender info.
sender_id = post.get("user_id", "")
sender_name = data.get("sender_name", "").lstrip("@") or sender_id
# Thread support: if the post is in a thread, use root_id.
thread_id = post.get("root_id") or None
# Determine message type.
file_ids = post.get("file_ids") or []
msg_type = MessageType.TEXT
if message_text.startswith("/"):
msg_type = MessageType.COMMAND
# Download file attachments immediately (URLs require auth headers
# that downstream tools won't have).
media_urls: List[str] = []
media_types: List[str] = []
for fid in file_ids:
try:
file_info = await self._api_get(f"files/{fid}/info")
fname = file_info.get("name", f"file_{fid}")
ext = Path(fname).suffix or ""
mime = file_info.get("mime_type", "application/octet-stream")
import aiohttp
dl_url = f"{self._base_url}/api/v4/files/{fid}"
async with self._session.get(
dl_url,
headers={"Authorization": f"Bearer {self._token}"},
timeout=aiohttp.ClientTimeout(total=30),
) as resp:
if resp.status < 400:
file_data = await resp.read()
from gateway.platforms.base import cache_image_from_bytes, cache_document_from_bytes
if mime.startswith("image/"):
local_path = cache_image_from_bytes(file_data, ext or ".png")
media_urls.append(local_path)
media_types.append("image")
elif mime.startswith("audio/"):
from gateway.platforms.base import cache_audio_from_bytes
local_path = cache_audio_from_bytes(file_data, ext or ".ogg")
media_urls.append(local_path)
media_types.append("audio")
else:
local_path = cache_document_from_bytes(file_data, fname)
media_urls.append(local_path)
media_types.append("document")
else:
logger.warning("Mattermost: failed to download file %s: HTTP %s", fid, resp.status)
except Exception as exc:
logger.warning("Mattermost: error downloading file %s: %s", fid, exc)
source = self.build_source(
chat_id=channel_id,
chat_type=chat_type,
user_id=sender_id,
user_name=sender_name,
thread_id=thread_id,
)
msg_event = MessageEvent(
text=message_text,
message_type=msg_type,
source=source,
raw_message=post,
message_id=post_id,
media_urls=media_urls if media_urls else None,
media_types=media_types if media_types else None,
)
await self.handle_message(msg_event)
def _prune_seen(self) -> None:
"""Remove expired entries from the dedup cache."""
if len(self._seen_posts) < self._SEEN_MAX:
return
now = time.time()
self._seen_posts = {
pid: ts
for pid, ts in self._seen_posts.items()
if now - ts < self._SEEN_TTL
}

View file

@ -1,19 +1,27 @@
"""SMS (Telnyx) platform adapter.
"""SMS (Twilio) platform adapter.
Connects to the Telnyx REST API for outbound SMS and runs an aiohttp
Connects to the Twilio REST API for outbound SMS and runs an aiohttp
webhook server to receive inbound messages.
Requires:
- aiohttp installed: pip install 'hermes-agent[sms]'
- TELNYX_API_KEY environment variable set
- TELNYX_FROM_NUMBERS: comma-separated E.164 numbers (e.g. +15551234567)
Shares credentials with the optional telephony skill same env vars:
- TWILIO_ACCOUNT_SID
- TWILIO_AUTH_TOKEN
- TWILIO_PHONE_NUMBER (E.164 from-number, e.g. +15551234567)
Gateway-specific env vars:
- SMS_WEBHOOK_PORT (default 8080)
- SMS_ALLOWED_USERS (comma-separated E.164 phone numbers)
- SMS_ALLOW_ALL_USERS (true/false)
- SMS_HOME_CHANNEL (phone number for cron delivery)
"""
import asyncio
import base64
import json
import logging
import os
import re
import urllib.parse
from typing import Any, Dict, List, Optional
from gateway.config import Platform, PlatformConfig
@ -26,7 +34,7 @@ from gateway.platforms.base import (
logger = logging.getLogger(__name__)
TELNYX_BASE = "https://api.telnyx.com/v2"
TWILIO_API_BASE = "https://api.twilio.com/2010-04-01/Accounts"
MAX_SMS_LENGTH = 1600 # ~10 SMS segments
DEFAULT_WEBHOOK_PORT = 8080
@ -35,17 +43,12 @@ _PHONE_RE = re.compile(r"\+[1-9]\d{6,14}")
def _redact_phone(phone: str) -> str:
"""Redact a phone number for logging: +15551234567 -> +155****4567."""
"""Redact a phone number for logging: +15551234567 -> +1555***4567."""
if not phone:
return "<none>"
if len(phone) <= 8:
return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****"
return phone[:4] + "****" + phone[-4:]
def _parse_comma_list(value: str) -> List[str]:
"""Split a comma-separated string into a list, stripping whitespace."""
return [v.strip() for v in value.split(",") if v.strip()]
return phone[:2] + "***" + phone[-2:] if len(phone) > 4 else "****"
return phone[:5] + "***" + phone[-4:]
def check_sms_requirements() -> bool:
@ -54,32 +57,35 @@ def check_sms_requirements() -> bool:
import aiohttp # noqa: F401
except ImportError:
return False
return bool(os.getenv("TELNYX_API_KEY"))
return bool(os.getenv("TWILIO_ACCOUNT_SID") and os.getenv("TWILIO_AUTH_TOKEN"))
class SmsAdapter(BasePlatformAdapter):
"""
Telnyx SMS <-> Hermes gateway adapter.
Twilio SMS <-> Hermes gateway adapter.
Each inbound phone number gets its own Hermes session (multi-tenant).
Tracks which owned number received each user's message to reply from
the same number.
Replies are always sent from the configured TWILIO_PHONE_NUMBER.
"""
MAX_MESSAGE_LENGTH = MAX_SMS_LENGTH
def __init__(self, config: PlatformConfig):
super().__init__(config, Platform.SMS)
self._api_key: str = os.environ["TELNYX_API_KEY"]
self._account_sid: str = os.environ["TWILIO_ACCOUNT_SID"]
self._auth_token: str = os.environ["TWILIO_AUTH_TOKEN"]
self._from_number: str = os.getenv("TWILIO_PHONE_NUMBER", "")
self._webhook_port: int = int(
os.getenv("SMS_WEBHOOK_PORT", str(DEFAULT_WEBHOOK_PORT))
)
# Set of owned numbers
self._from_numbers: set = set(
_parse_comma_list(os.getenv("TELNYX_FROM_NUMBERS", ""))
)
# Runtime map: user phone -> which owned number to reply from
self._reply_from: Dict[str, str] = {}
self._runner = None
def _basic_auth_header(self) -> str:
"""Build HTTP Basic auth header value for Twilio."""
creds = f"{self._account_sid}:{self._auth_token}"
encoded = base64.b64encode(creds.encode("ascii")).decode("ascii")
return f"Basic {encoded}"
# ------------------------------------------------------------------
# Required abstract methods
# ------------------------------------------------------------------
@ -88,8 +94,12 @@ class SmsAdapter(BasePlatformAdapter):
import aiohttp
from aiohttp import web
if not self._from_number:
logger.error("[sms] TWILIO_PHONE_NUMBER not set — cannot send replies")
return False
app = web.Application()
app.router.add_post("/webhooks/telnyx", self._handle_webhook)
app.router.add_post("/webhooks/twilio", self._handle_webhook)
app.router.add_get("/health", lambda _: web.Response(text="ok"))
self._runner = web.AppRunner(app)
@ -98,11 +108,10 @@ class SmsAdapter(BasePlatformAdapter):
await site.start()
self._running = True
from_display = ", ".join(_redact_phone(n) for n in self._from_numbers) or "(none)"
logger.info(
"[sms] Webhook server listening on port %d, from numbers: %s",
"[sms] Twilio webhook server listening on port %d, from: %s",
self._webhook_port,
from_display,
_redact_phone(self._from_number),
)
return True
@ -122,40 +131,41 @@ class SmsAdapter(BasePlatformAdapter):
) -> SendResult:
import aiohttp
from_number = self._get_reply_from(chat_id, metadata)
formatted = self.format_message(content)
chunks = self.truncate_message(formatted)
last_result = SendResult(success=True)
url = f"{TWILIO_API_BASE}/{self._account_sid}/Messages.json"
headers = {
"Authorization": self._basic_auth_header(),
}
async with aiohttp.ClientSession() as session:
for i, chunk in enumerate(chunks):
payload = {"from": from_number, "to": chat_id, "text": chunk}
headers = {
"Authorization": f"Bearer {self._api_key}",
"Content-Type": "application/json",
}
for chunk in chunks:
form_data = aiohttp.FormData()
form_data.add_field("From", self._from_number)
form_data.add_field("To", chat_id)
form_data.add_field("Body", chunk)
try:
async with session.post(
f"{TELNYX_BASE}/messages",
json=payload,
headers=headers,
) as resp:
async with session.post(url, data=form_data, headers=headers) as resp:
body = await resp.json()
if resp.status >= 400:
error_msg = body.get("message", str(body))
logger.error(
"[sms] send failed %s: %s %s",
"[sms] send failed to %s: %s %s",
_redact_phone(chat_id),
resp.status,
body,
error_msg,
)
return SendResult(
success=False,
error=f"Telnyx {resp.status}: {body}",
error=f"Twilio {resp.status}: {error_msg}",
)
msg_id = body.get("data", {}).get("id", "")
last_result = SendResult(success=True, message_id=msg_id)
msg_sid = body.get("sid", "")
last_result = SendResult(success=True, message_id=msg_sid)
except Exception as e:
logger.error("[sms] send error %s: %s", _redact_phone(chat_id), e)
logger.error("[sms] send error to %s: %s", _redact_phone(chat_id), e)
return SendResult(success=False, error=str(e))
return last_result
@ -168,7 +178,7 @@ class SmsAdapter(BasePlatformAdapter):
# ------------------------------------------------------------------
def format_message(self, content: str) -> str:
"""Strip markdown -- SMS renders it as literal characters."""
"""Strip markdown SMS renders it as literal characters."""
content = re.sub(r"\*\*(.+?)\*\*", r"\1", content, flags=re.DOTALL)
content = re.sub(r"\*(.+?)\*", r"\1", content, flags=re.DOTALL)
content = re.sub(r"__(.+?)__", r"\1", content, flags=re.DOTALL)
@ -180,28 +190,8 @@ class SmsAdapter(BasePlatformAdapter):
content = re.sub(r"\n{3,}", "\n\n", content)
return content.strip()
def truncate_message(
self, content: str, max_length: int = MAX_SMS_LENGTH
) -> List[str]:
"""Split into <=1600-char chunks (10 SMS segments)."""
if len(content) <= max_length:
return [content]
chunks: List[str] = []
while content:
if len(content) <= max_length:
chunks.append(content)
break
split_at = content.rfind("\n", 0, max_length)
if split_at < max_length // 2:
split_at = content.rfind(" ", 0, max_length)
if split_at < 1:
split_at = max_length
chunks.append(content[:split_at].strip())
content = content[split_at:].strip()
return chunks
# ------------------------------------------------------------------
# Telnyx webhook handler
# Twilio webhook handler
# ------------------------------------------------------------------
async def _handle_webhook(self, request) -> "aiohttp.web.Response":
@ -209,32 +199,35 @@ class SmsAdapter(BasePlatformAdapter):
try:
raw = await request.read()
body = json.loads(raw.decode("utf-8"))
# Twilio sends form-encoded data, not JSON
form = urllib.parse.parse_qs(raw.decode("utf-8"))
except Exception as e:
logger.error("[sms] webhook parse error: %s", e)
return web.json_response({"error": "invalid json"}, status=400)
return web.Response(
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
content_type="application/xml",
status=400,
)
# Only handle inbound messages
if body.get("data", {}).get("event_type") != "message.received":
return web.json_response({"received": True})
payload = body["data"]["payload"]
from_number: str = payload.get("from", {}).get("phone_number", "")
to_list = payload.get("to", [])
to_number: str = to_list[0].get("phone_number", "") if to_list else ""
text: str = payload.get("text", "").strip()
# Extract fields (parse_qs returns lists)
from_number = (form.get("From", [""]))[0].strip()
to_number = (form.get("To", [""]))[0].strip()
text = (form.get("Body", [""]))[0].strip()
message_sid = (form.get("MessageSid", [""]))[0].strip()
if not from_number or not text:
return web.json_response({"received": True})
return web.Response(
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
content_type="application/xml",
)
# Ignore messages sent FROM one of our own numbers (echo loop prevention)
if from_number in self._from_numbers:
# Ignore messages from our own number (echo prevention)
if from_number == self._from_number:
logger.debug("[sms] ignoring echo from own number %s", _redact_phone(from_number))
return web.json_response({"received": True})
# Remember which owned number received this user's message
if to_number and to_number in self._from_numbers:
self._reply_from[from_number] = to_number
return web.Response(
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
content_type="application/xml",
)
logger.info(
"[sms] inbound from %s -> %s: %s",
@ -254,29 +247,15 @@ class SmsAdapter(BasePlatformAdapter):
text=text,
message_type=MessageType.TEXT,
source=source,
raw_message=body,
message_id=payload.get("id"),
raw_message=form,
message_id=message_sid,
)
# Non-blocking: Telnyx expects a fast 200
# Non-blocking: Twilio expects a fast response
asyncio.create_task(self.handle_message(event))
return web.json_response({"received": True})
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _get_reply_from(
self, user_phone: str, metadata: Optional[Dict] = None
) -> str:
"""Determine which owned number to send from."""
if metadata and "from_number" in metadata:
return metadata["from_number"]
if user_phone in self._reply_from:
return self._reply_from[user_phone]
if self._from_numbers:
return next(iter(self._from_numbers))
raise RuntimeError(
"No FROM number configured (TELNYX_FROM_NUMBERS) and no prior "
"reply_from mapping for this user"
# Return empty TwiML — we send replies via the REST API, not inline TwiML
return web.Response(
text='<?xml version="1.0" encoding="UTF-8"?><Response></Response>',
content_type="application/xml",
)

View file

@ -848,7 +848,8 @@ class GatewayRunner:
os.getenv(v)
for v in ("TELEGRAM_ALLOWED_USERS", "DISCORD_ALLOWED_USERS",
"WHATSAPP_ALLOWED_USERS", "SLACK_ALLOWED_USERS",
"SMS_ALLOWED_USERS", "GATEWAY_ALLOWED_USERS")
"SMS_ALLOWED_USERS",
"GATEWAY_ALLOWED_USERS")
)
_allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes")
if not _any_allowlist and not _allow_all:
@ -983,6 +984,16 @@ class GatewayRunner:
):
self._schedule_update_notification_watch()
# Drain any recovered process watchers (from crash recovery checkpoint)
try:
from tools.process_registry import process_registry
while process_registry.pending_watchers:
watcher = process_registry.pending_watchers.pop(0)
asyncio.create_task(self._run_process_watcher(watcher))
logger.info("Resumed watcher for recovered process %s", watcher.get("session_id"))
except Exception as e:
logger.error("Recovered watcher setup error: %s", e)
# Start background session expiry watcher for proactive memory flushing
asyncio.create_task(self._session_expiry_watcher())
@ -1135,10 +1146,31 @@ class GatewayRunner:
elif platform == Platform.SMS:
from gateway.platforms.sms import SmsAdapter, check_sms_requirements
if not check_sms_requirements():
logger.warning("SMS: aiohttp not installed or TELNYX_API_KEY not set. Run: pip install 'hermes-agent[sms]'")
logger.warning("SMS: aiohttp not installed or TWILIO_ACCOUNT_SID/TWILIO_AUTH_TOKEN not set")
return None
return SmsAdapter(config)
elif platform == Platform.DINGTALK:
from gateway.platforms.dingtalk import DingTalkAdapter, check_dingtalk_requirements
if not check_dingtalk_requirements():
logger.warning("DingTalk: dingtalk-stream not installed or DINGTALK_CLIENT_ID/SECRET not set")
return None
return DingTalkAdapter(config)
elif platform == Platform.MATTERMOST:
from gateway.platforms.mattermost import MattermostAdapter, check_mattermost_requirements
if not check_mattermost_requirements():
logger.warning("Mattermost: MATTERMOST_TOKEN or MATTERMOST_URL not set, or aiohttp missing")
return None
return MattermostAdapter(config)
elif platform == Platform.MATRIX:
from gateway.platforms.matrix import MatrixAdapter, check_matrix_requirements
if not check_matrix_requirements():
logger.warning("Matrix: matrix-nio not installed or credentials not set. Run: pip install 'matrix-nio[e2e]'")
return None
return MatrixAdapter(config)
return None
def _is_user_authorized(self, source: SessionSource) -> bool:
@ -1170,6 +1202,9 @@ class GatewayRunner:
Platform.SIGNAL: "SIGNAL_ALLOWED_USERS",
Platform.EMAIL: "EMAIL_ALLOWED_USERS",
Platform.SMS: "SMS_ALLOWED_USERS",
Platform.MATTERMOST: "MATTERMOST_ALLOWED_USERS",
Platform.MATRIX: "MATRIX_ALLOWED_USERS",
Platform.DINGTALK: "DINGTALK_ALLOWED_USERS",
}
platform_allow_all_map = {
Platform.TELEGRAM: "TELEGRAM_ALLOW_ALL_USERS",
@ -1179,6 +1214,9 @@ class GatewayRunner:
Platform.SIGNAL: "SIGNAL_ALLOW_ALL_USERS",
Platform.EMAIL: "EMAIL_ALLOW_ALL_USERS",
Platform.SMS: "SMS_ALLOW_ALL_USERS",
Platform.MATTERMOST: "MATTERMOST_ALLOW_ALL_USERS",
Platform.MATRIX: "MATRIX_ALLOW_ALL_USERS",
Platform.DINGTALK: "DINGTALK_ALLOW_ALL_USERS",
}
# Per-platform allow-all flag (e.g., DISCORD_ALLOW_ALL_USERS=true)
@ -1430,8 +1468,19 @@ class GatewayRunner:
return f"Quick command error: {e}"
else:
return f"Quick command '/{command}' has no command defined."
elif qcmd.get("type") == "alias":
target = qcmd.get("target", "").strip()
if target:
target = target if target.startswith("/") else f"/{target}"
target_command = target.lstrip("/")
user_args = event.get_command_args().strip()
event.text = f"{target} {user_args}".strip()
command = target_command
# Fall through to normal command dispatch below
else:
return f"Quick command '/{command}' has no target defined."
else:
return f"Quick command '/{command}' has unsupported type (only 'exec' is supported)."
return f"Quick command '/{command}' has unsupported type (supported: 'exec', 'alias')."
# Skill slash commands: /skill-name loads the skill and sends to agent
if command:
@ -1442,7 +1491,7 @@ class GatewayRunner:
if cmd_key in skill_cmds:
user_instruction = event.get_command_args().strip()
msg = build_skill_invocation_message(
cmd_key, user_instruction, task_id=session_key
cmd_key, user_instruction, task_id=_quick_key
)
if msg:
event.text = msg
@ -1503,8 +1552,9 @@ class GatewayRunner:
# Read privacy.redact_pii from config (re-read per message)
_redact_pii = False
try:
import yaml as _pii_yaml
with open(_config_path, encoding="utf-8") as _pf:
_pcfg = yaml.safe_load(_pf) or {}
_pcfg = _pii_yaml.safe_load(_pf) or {}
_redact_pii = bool((_pcfg.get("privacy") or {}).get("redact_pii", False))
except Exception:
pass
@ -2050,8 +2100,15 @@ class GatewayRunner:
session_entry.session_key,
input_tokens=agent_result.get("input_tokens", 0),
output_tokens=agent_result.get("output_tokens", 0),
cache_read_tokens=agent_result.get("cache_read_tokens", 0),
cache_write_tokens=agent_result.get("cache_write_tokens", 0),
last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
model=agent_result.get("model"),
estimated_cost_usd=agent_result.get("estimated_cost_usd"),
cost_status=agent_result.get("cost_status"),
cost_source=agent_result.get("cost_source"),
provider=agent_result.get("provider"),
base_url=agent_result.get("base_url"),
)
# Auto voice reply: send TTS audio before the text response
@ -2121,7 +2178,14 @@ class GatewayRunner:
# Reset the session
new_entry = self.session_store.reset_session(session_key)
# Emit session:end hook (session is ending)
await self.hooks.emit("session:end", {
"platform": source.platform.value if source.platform else "",
"user_id": source.user_id,
"session_key": session_key,
})
# Emit session:reset hook
await self.hooks.emit("session:reset", {
"platform": source.platform.value if source.platform else "",
@ -3027,6 +3091,7 @@ class GatewayRunner:
Platform.SIGNAL: "hermes-signal",
Platform.HOMEASSISTANT: "hermes-homeassistant",
Platform.EMAIL: "hermes-email",
Platform.DINGTALK: "hermes-dingtalk",
}
platform_toolsets_config = {}
try:
@ -3048,6 +3113,7 @@ class GatewayRunner:
Platform.SIGNAL: "signal",
Platform.HOMEASSISTANT: "homeassistant",
Platform.EMAIL: "email",
Platform.DINGTALK: "dingtalk",
}.get(source.platform, "telegram")
config_toolsets = platform_toolsets_config.get(platform_config_key)
@ -4045,6 +4111,7 @@ class GatewayRunner:
Platform.SIGNAL: "hermes-signal",
Platform.HOMEASSISTANT: "hermes-homeassistant",
Platform.EMAIL: "hermes-email",
Platform.DINGTALK: "hermes-dingtalk",
}
# Try to load platform_toolsets from config
@ -4069,6 +4136,7 @@ class GatewayRunner:
Platform.SIGNAL: "signal",
Platform.HOMEASSISTANT: "homeassistant",
Platform.EMAIL: "email",
Platform.DINGTALK: "dingtalk",
}.get(source.platform, "telegram")
# Use config override if present (list of toolsets), otherwise hardcoded default

View file

@ -343,7 +343,11 @@ class SessionEntry:
# Token tracking
input_tokens: int = 0
output_tokens: int = 0
cache_read_tokens: int = 0
cache_write_tokens: int = 0
total_tokens: int = 0
estimated_cost_usd: float = 0.0
cost_status: str = "unknown"
# Last API-reported prompt tokens (for accurate compression pre-check)
last_prompt_tokens: int = 0
@ -363,8 +367,12 @@ class SessionEntry:
"chat_type": self.chat_type,
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
"cache_read_tokens": self.cache_read_tokens,
"cache_write_tokens": self.cache_write_tokens,
"total_tokens": self.total_tokens,
"last_prompt_tokens": self.last_prompt_tokens,
"estimated_cost_usd": self.estimated_cost_usd,
"cost_status": self.cost_status,
}
if self.origin:
result["origin"] = self.origin.to_dict()
@ -394,8 +402,12 @@ class SessionEntry:
chat_type=data.get("chat_type", "dm"),
input_tokens=data.get("input_tokens", 0),
output_tokens=data.get("output_tokens", 0),
cache_read_tokens=data.get("cache_read_tokens", 0),
cache_write_tokens=data.get("cache_write_tokens", 0),
total_tokens=data.get("total_tokens", 0),
last_prompt_tokens=data.get("last_prompt_tokens", 0),
estimated_cost_usd=data.get("estimated_cost_usd", 0.0),
cost_status=data.get("cost_status", "unknown"),
)
@ -696,8 +708,15 @@ class SessionStore:
session_key: str,
input_tokens: int = 0,
output_tokens: int = 0,
cache_read_tokens: int = 0,
cache_write_tokens: int = 0,
last_prompt_tokens: int = None,
model: str = None,
estimated_cost_usd: Optional[float] = None,
cost_status: Optional[str] = None,
cost_source: Optional[str] = None,
provider: Optional[str] = None,
base_url: Optional[str] = None,
) -> None:
"""Update a session's metadata after an interaction."""
self._ensure_loaded()
@ -707,15 +726,35 @@ class SessionStore:
entry.updated_at = datetime.now()
entry.input_tokens += input_tokens
entry.output_tokens += output_tokens
entry.cache_read_tokens += cache_read_tokens
entry.cache_write_tokens += cache_write_tokens
if last_prompt_tokens is not None:
entry.last_prompt_tokens = last_prompt_tokens
entry.total_tokens = entry.input_tokens + entry.output_tokens
if estimated_cost_usd is not None:
entry.estimated_cost_usd += estimated_cost_usd
if cost_status:
entry.cost_status = cost_status
entry.total_tokens = (
entry.input_tokens
+ entry.output_tokens
+ entry.cache_read_tokens
+ entry.cache_write_tokens
)
self._save()
if self._db:
try:
self._db.update_token_counts(
entry.session_id, input_tokens, output_tokens,
entry.session_id,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_read_tokens=cache_read_tokens,
cache_write_tokens=cache_write_tokens,
estimated_cost_usd=estimated_cost_usd,
cost_status=cost_status,
cost_source=cost_source,
billing_provider=provider,
billing_base_url=base_url,
model=model,
)
except Exception as e:

View file

@ -34,8 +34,11 @@ _EXTRA_ENV_KEYS = frozenset({
"DISCORD_HOME_CHANNEL", "TELEGRAM_HOME_CHANNEL",
"SIGNAL_ACCOUNT", "SIGNAL_HTTP_URL",
"SIGNAL_ALLOWED_USERS", "SIGNAL_GROUP_ALLOWED_USERS",
"DINGTALK_CLIENT_ID", "DINGTALK_CLIENT_SECRET",
"TERMINAL_ENV", "TERMINAL_SSH_KEY", "TERMINAL_SSH_PORT",
"WHATSAPP_MODE", "WHATSAPP_ENABLED",
"MATTERMOST_HOME_CHANNEL", "MATTERMOST_REPLY_MODE",
"MATRIX_PASSWORD", "MATRIX_ENCRYPTION", "MATRIX_HOME_ROOM",
})
import yaml
@ -354,6 +357,11 @@ DEFAULT_CONFIG = {
"tirith_path": "tirith",
"tirith_timeout": 5,
"tirith_fail_open": True,
"website_blocklist": {
"enabled": False,
"domains": [],
"shared_files": [],
},
},
# Config schema version - bump this when adding new required fields
@ -371,6 +379,7 @@ ENV_VARS_BY_VERSION: Dict[int, List[str]] = {
4: ["VOICE_TOOLS_OPENAI_KEY", "ELEVENLABS_API_KEY"],
5: ["WHATSAPP_ENABLED", "WHATSAPP_MODE", "WHATSAPP_ALLOWED_USERS",
"SLACK_BOT_TOKEN", "SLACK_APP_TOKEN", "SLACK_ALLOWED_USERS"],
10: ["TAVILY_API_KEY"],
}
# Required environment variables with metadata for migration prompts.
@ -542,6 +551,14 @@ OPTIONAL_ENV_VARS = {
},
# ── Tool API keys ──
"PARALLEL_API_KEY": {
"description": "Parallel API key for AI-native web search and extract",
"prompt": "Parallel API key",
"url": "https://parallel.ai/",
"tools": ["web_search", "web_extract"],
"password": True,
"category": "tool",
},
"FIRECRAWL_API_KEY": {
"description": "Firecrawl API key for web search and scraping",
"prompt": "Firecrawl API key",
@ -558,6 +575,14 @@ OPTIONAL_ENV_VARS = {
"category": "tool",
"advanced": True,
},
"TAVILY_API_KEY": {
"description": "Tavily API key for AI-native web search, extract, and crawl",
"prompt": "Tavily API key",
"url": "https://app.tavily.com/home",
"tools": ["web_search", "web_extract", "web_crawl"],
"password": True,
"category": "tool",
},
"BROWSERBASE_API_KEY": {
"description": "Browserbase API key for cloud browser (optional — local browser works without this)",
"prompt": "Browserbase API key",
@ -686,6 +711,55 @@ OPTIONAL_ENV_VARS = {
"password": True,
"category": "messaging",
},
"MATTERMOST_URL": {
"description": "Mattermost server URL (e.g. https://mm.example.com)",
"prompt": "Mattermost server URL",
"url": "https://mattermost.com/deploy/",
"password": False,
"category": "messaging",
},
"MATTERMOST_TOKEN": {
"description": "Mattermost bot token or personal access token",
"prompt": "Mattermost bot token",
"url": None,
"password": True,
"category": "messaging",
},
"MATTERMOST_ALLOWED_USERS": {
"description": "Comma-separated Mattermost user IDs allowed to use the bot",
"prompt": "Allowed Mattermost user IDs (comma-separated)",
"url": None,
"password": False,
"category": "messaging",
},
"MATRIX_HOMESERVER": {
"description": "Matrix homeserver URL (e.g. https://matrix.example.org)",
"prompt": "Matrix homeserver URL",
"url": "https://matrix.org/ecosystem/servers/",
"password": False,
"category": "messaging",
},
"MATRIX_ACCESS_TOKEN": {
"description": "Matrix access token (preferred over password login)",
"prompt": "Matrix access token",
"url": None,
"password": True,
"category": "messaging",
},
"MATRIX_USER_ID": {
"description": "Matrix user ID (e.g. @hermes:example.org)",
"prompt": "Matrix user ID (@user:server)",
"url": None,
"password": False,
"category": "messaging",
},
"MATRIX_ALLOWED_USERS": {
"description": "Comma-separated Matrix user IDs allowed to use the bot (@user:server format)",
"prompt": "Allowed Matrix user IDs (comma-separated)",
"url": None,
"password": False,
"category": "messaging",
},
"GATEWAY_ALLOW_ALL_USERS": {
"description": "Allow all users to interact with messaging bots (true/false). Default: false.",
"prompt": "Allow all users (true/false)",
@ -1449,7 +1523,9 @@ def show_config():
keys = [
("OPENROUTER_API_KEY", "OpenRouter"),
("VOICE_TOOLS_OPENAI_KEY", "OpenAI (STT/TTS)"),
("PARALLEL_API_KEY", "Parallel"),
("FIRECRAWL_API_KEY", "Firecrawl"),
("TAVILY_API_KEY", "Tavily"),
("BROWSERBASE_API_KEY", "Browserbase"),
("BROWSER_USE_API_KEY", "Browser Use"),
("FAL_KEY", "FAL"),
@ -1598,7 +1674,8 @@ def set_config_value(key: str, value: str):
# Check if it's an API key (goes to .env)
api_keys = [
'OPENROUTER_API_KEY', 'OPENAI_API_KEY', 'ANTHROPIC_API_KEY', 'VOICE_TOOLS_OPENAI_KEY',
'FIRECRAWL_API_KEY', 'FIRECRAWL_API_URL', 'BROWSERBASE_API_KEY', 'BROWSERBASE_PROJECT_ID', 'BROWSER_USE_API_KEY',
'PARALLEL_API_KEY', 'FIRECRAWL_API_KEY', 'FIRECRAWL_API_URL', 'TAVILY_API_KEY',
'BROWSERBASE_API_KEY', 'BROWSERBASE_PROJECT_ID', 'BROWSER_USE_API_KEY',
'FAL_KEY', 'TELEGRAM_BOT_TOKEN', 'DISCORD_BOT_TOKEN',
'TERMINAL_SSH_HOST', 'TERMINAL_SSH_USER', 'TERMINAL_SSH_KEY',
'SUDO_PASSWORD', 'SLACK_BOT_TOKEN', 'SLACK_APP_TOKEN',

View file

@ -1001,6 +1001,64 @@ _PLATFORMS = [
"help": "Paste your member ID from step 7 above."},
],
},
{
"key": "matrix",
"label": "Matrix",
"emoji": "🔐",
"token_var": "MATRIX_ACCESS_TOKEN",
"setup_instructions": [
"1. Works with any Matrix homeserver (self-hosted Synapse/Conduit/Dendrite or matrix.org)",
"2. Create a bot user on your homeserver, or use your own account",
"3. Get an access token: Element → Settings → Help & About → Access Token",
" Or via API: curl -X POST https://your-server/_matrix/client/v3/login \\",
" -d '{\"type\":\"m.login.password\",\"user\":\"@bot:server\",\"password\":\"...\"}'",
"4. Alternatively, provide user ID + password and Hermes will log in directly",
"5. For E2EE: set MATRIX_ENCRYPTION=true (requires pip install 'matrix-nio[e2e]')",
"6. To find your user ID: it's @username:your-server (shown in Element profile)",
],
"vars": [
{"name": "MATRIX_HOMESERVER", "prompt": "Homeserver URL (e.g. https://matrix.example.org)", "password": False,
"help": "Your Matrix homeserver URL. Works with any self-hosted instance."},
{"name": "MATRIX_ACCESS_TOKEN", "prompt": "Access token (leave empty to use password login instead)", "password": True,
"help": "Paste your access token, or leave empty and provide user ID + password below."},
{"name": "MATRIX_USER_ID", "prompt": "User ID (@bot:server — required for password login)", "password": False,
"help": "Full Matrix user ID, e.g. @hermes:matrix.example.org"},
{"name": "MATRIX_ALLOWED_USERS", "prompt": "Allowed user IDs (comma-separated, e.g. @you:server)", "password": False,
"is_allowlist": True,
"help": "Matrix user IDs who can interact with the bot."},
{"name": "MATRIX_HOME_ROOM", "prompt": "Home room ID (for cron/notification delivery, or empty to set later with /set-home)", "password": False,
"help": "Room ID (e.g. !abc123:server) for delivering cron results and notifications."},
],
},
{
"key": "mattermost",
"label": "Mattermost",
"emoji": "💬",
"token_var": "MATTERMOST_TOKEN",
"setup_instructions": [
"1. In Mattermost: Integrations → Bot Accounts → Add Bot Account",
" (System Console → Integrations → Bot Accounts must be enabled)",
"2. Give it a username (e.g. hermes) and copy the bot token",
"3. Works with any self-hosted Mattermost instance — enter your server URL",
"4. To find your user ID: click your avatar (top-left) → Profile",
" Your user ID is displayed there — click it to copy.",
" ⚠ This is NOT your username — it's a 26-character alphanumeric ID.",
"5. To get a channel ID: click the channel name → View Info → copy the ID",
],
"vars": [
{"name": "MATTERMOST_URL", "prompt": "Server URL (e.g. https://mm.example.com)", "password": False,
"help": "Your Mattermost server URL. Works with any self-hosted instance."},
{"name": "MATTERMOST_TOKEN", "prompt": "Bot token", "password": True,
"help": "Paste the bot token from step 2 above."},
{"name": "MATTERMOST_ALLOWED_USERS", "prompt": "Allowed user IDs (comma-separated)", "password": False,
"is_allowlist": True,
"help": "Your Mattermost user ID from step 4 above."},
{"name": "MATTERMOST_HOME_CHANNEL", "prompt": "Home channel ID (for cron/notification delivery, or empty to set later with /set-home)", "password": False,
"help": "Channel ID where Hermes delivers cron results and notifications."},
{"name": "MATTERMOST_REPLY_MODE", "prompt": "Reply mode — 'off' for flat messages, 'thread' for threaded replies (default: off)", "password": False,
"help": "off = flat channel messages, thread = replies nest under your message."},
],
},
{
"key": "whatsapp",
"label": "WhatsApp",
@ -1013,30 +1071,6 @@ _PLATFORMS = [
"emoji": "📡",
"token_var": "SIGNAL_HTTP_URL",
},
{
"key": "sms",
"label": "SMS (Telnyx)",
"emoji": "📱",
"token_var": "TELNYX_API_KEY",
"setup_instructions": [
"1. Create a Telnyx account at https://portal.telnyx.com/",
"2. Buy a phone number with SMS capability",
"3. Create an API key: API Keys → Create API Key",
"4. Set up a Messaging Profile and assign your number to it",
"5. Configure the webhook URL: https://your-server/webhooks/telnyx",
],
"vars": [
{"name": "TELNYX_API_KEY", "prompt": "Telnyx API key", "password": True,
"help": "Paste the API key from step 3 above."},
{"name": "TELNYX_FROM_NUMBERS", "prompt": "From numbers (comma-separated E.164, e.g. +15551234567)", "password": False,
"help": "The Telnyx phone number(s) Hermes will send SMS from."},
{"name": "SMS_ALLOWED_USERS", "prompt": "Allowed phone numbers (comma-separated E.164)", "password": False,
"is_allowlist": True,
"help": "Only messages from these phone numbers will be processed."},
{"name": "SMS_HOME_CHANNEL", "prompt": "Home channel phone (for cron/notification delivery, or empty)", "password": False,
"help": "A phone number where cron job outputs are delivered."},
],
},
{
"key": "email",
"label": "Email",
@ -1063,6 +1097,51 @@ _PLATFORMS = [
"help": "Only emails from these addresses will be processed."},
],
},
{
"key": "sms",
"label": "SMS (Twilio)",
"emoji": "📱",
"token_var": "TWILIO_ACCOUNT_SID",
"setup_instructions": [
"1. Create a Twilio account at https://www.twilio.com/",
"2. Get your Account SID and Auth Token from the Twilio Console dashboard",
"3. Buy or configure a phone number capable of sending SMS",
"4. Set up your webhook URL for inbound SMS:",
" Twilio Console → Phone Numbers → Active Numbers → your number",
" → Messaging → A MESSAGE COMES IN → Webhook → https://your-server:8080/webhooks/twilio",
],
"vars": [
{"name": "TWILIO_ACCOUNT_SID", "prompt": "Twilio Account SID", "password": False,
"help": "Found on the Twilio Console dashboard."},
{"name": "TWILIO_AUTH_TOKEN", "prompt": "Twilio Auth Token", "password": True,
"help": "Found on the Twilio Console dashboard (click to reveal)."},
{"name": "TWILIO_PHONE_NUMBER", "prompt": "Twilio phone number (E.164 format, e.g. +15551234567)", "password": False,
"help": "The Twilio phone number to send SMS from."},
{"name": "SMS_ALLOWED_USERS", "prompt": "Allowed phone numbers (comma-separated, E.164 format)", "password": False,
"is_allowlist": True,
"help": "Only messages from these phone numbers will be processed."},
{"name": "SMS_HOME_CHANNEL", "prompt": "Home channel phone number (for cron/notification delivery, or empty)", "password": False,
"help": "Phone number to deliver cron job results and notifications to."},
],
},
{
"key": "dingtalk",
"label": "DingTalk",
"emoji": "💬",
"token_var": "DINGTALK_CLIENT_ID",
"setup_instructions": [
"1. Go to https://open-dev.dingtalk.com → Create Application",
"2. Under 'Credentials', copy the AppKey (Client ID) and AppSecret (Client Secret)",
"3. Enable 'Stream Mode' under the bot settings",
"4. Add the bot to a group chat or message it directly",
],
"vars": [
{"name": "DINGTALK_CLIENT_ID", "prompt": "AppKey (Client ID)", "password": False,
"help": "The AppKey from your DingTalk application credentials."},
{"name": "DINGTALK_CLIENT_SECRET", "prompt": "AppSecret (Client Secret)", "password": True,
"help": "The AppSecret from your DingTalk application credentials."},
],
},
]
@ -1097,6 +1176,16 @@ def _platform_status(platform: dict) -> str:
if any([val, pwd, imap, smtp]):
return "partially configured"
return "not configured"
if platform.get("key") == "matrix":
homeserver = get_env_value("MATRIX_HOMESERVER")
password = get_env_value("MATRIX_PASSWORD")
if (val or password) and homeserver:
e2ee = get_env_value("MATRIX_ENCRYPTION")
suffix = " + E2EE" if e2ee and e2ee.lower() in ("true", "1", "yes") else ""
return f"configured{suffix}"
if val or password or homeserver:
return "partially configured"
return "not configured"
if val:
return "configured"
return "not configured"

View file

@ -784,6 +784,7 @@ def cmd_model(args):
"opencode-go": "OpenCode Go",
"ai-gateway": "AI Gateway",
"kilocode": "Kilo Code",
"alibaba": "Alibaba Cloud (DashScope)",
"custom": "Custom endpoint",
}
active_label = provider_labels.get(active, active)
@ -807,6 +808,7 @@ def cmd_model(args):
("opencode-zen", "OpenCode Zen (35+ curated models, pay-as-you-go)"),
("opencode-go", "OpenCode Go (open models, $10/month subscription)"),
("ai-gateway", "AI Gateway (Vercel — 200+ models, pay-per-use)"),
("alibaba", "Alibaba Cloud / DashScope (Qwen models, Anthropic-compatible)"),
]
# Add user-defined custom providers from config.yaml
@ -875,7 +877,7 @@ def cmd_model(args):
_model_flow_anthropic(config, current_model)
elif selected_provider == "kimi-coding":
_model_flow_kimi(config, current_model)
elif selected_provider in ("zai", "minimax", "minimax-cn", "kilocode", "opencode-zen", "opencode-go", "ai-gateway"):
elif selected_provider in ("zai", "minimax", "minimax-cn", "kilocode", "opencode-zen", "opencode-go", "ai-gateway", "alibaba"):
_model_flow_api_key_provider(config, selected_provider, current_model)
@ -1994,20 +1996,32 @@ def _update_via_zip(args):
print(f"✗ ZIP update failed: {e}")
sys.exit(1)
# Reinstall Python dependencies
# Reinstall Python dependencies (try .[all] first for optional extras,
# fall back to . if extras fail — mirrors the install script behavior)
print("→ Updating Python dependencies...")
import subprocess
uv_bin = shutil.which("uv")
if uv_bin:
subprocess.run(
[uv_bin, "pip", "install", "-e", ".", "--quiet"],
cwd=PROJECT_ROOT, check=True,
env={**os.environ, "VIRTUAL_ENV": str(PROJECT_ROOT / "venv")}
)
uv_env = {**os.environ, "VIRTUAL_ENV": str(PROJECT_ROOT / "venv")}
try:
subprocess.run(
[uv_bin, "pip", "install", "-e", ".[all]", "--quiet"],
cwd=PROJECT_ROOT, check=True, env=uv_env,
)
except subprocess.CalledProcessError:
print(" ⚠ Optional extras failed, installing base dependencies...")
subprocess.run(
[uv_bin, "pip", "install", "-e", ".", "--quiet"],
cwd=PROJECT_ROOT, check=True, env=uv_env,
)
else:
venv_pip = PROJECT_ROOT / "venv" / ("Scripts" if sys.platform == "win32" else "bin") / "pip"
if venv_pip.exists():
subprocess.run([str(venv_pip), "install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True)
pip_cmd = [str(venv_pip)] if venv_pip.exists() else ["pip"]
try:
subprocess.run(pip_cmd + ["install", "-e", ".[all]", "--quiet"], cwd=PROJECT_ROOT, check=True)
except subprocess.CalledProcessError:
print(" ⚠ Optional extras failed, installing base dependencies...")
subprocess.run(pip_cmd + ["install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True)
# Sync skills
try:
@ -2255,21 +2269,31 @@ def cmd_update(args):
_invalidate_update_cache()
# Reinstall Python dependencies (prefer uv for speed, fall back to pip)
# Reinstall Python dependencies (try .[all] first for optional extras,
# fall back to . if extras fail — mirrors the install script behavior)
print("→ Updating Python dependencies...")
uv_bin = shutil.which("uv")
if uv_bin:
subprocess.run(
[uv_bin, "pip", "install", "-e", ".", "--quiet"],
cwd=PROJECT_ROOT, check=True,
env={**os.environ, "VIRTUAL_ENV": str(PROJECT_ROOT / "venv")}
)
uv_env = {**os.environ, "VIRTUAL_ENV": str(PROJECT_ROOT / "venv")}
try:
subprocess.run(
[uv_bin, "pip", "install", "-e", ".[all]", "--quiet"],
cwd=PROJECT_ROOT, check=True, env=uv_env,
)
except subprocess.CalledProcessError:
print(" ⚠ Optional extras failed, installing base dependencies...")
subprocess.run(
[uv_bin, "pip", "install", "-e", ".", "--quiet"],
cwd=PROJECT_ROOT, check=True, env=uv_env,
)
else:
venv_pip = PROJECT_ROOT / "venv" / ("Scripts" if sys.platform == "win32" else "bin") / "pip"
if venv_pip.exists():
subprocess.run([str(venv_pip), "install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True)
else:
subprocess.run(["pip", "install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True)
pip_cmd = [str(venv_pip)] if venv_pip.exists() else ["pip"]
try:
subprocess.run(pip_cmd + ["install", "-e", ".[all]", "--quiet"], cwd=PROJECT_ROOT, check=True)
except subprocess.CalledProcessError:
print(" ⚠ Optional extras failed, installing base dependencies...")
subprocess.run(pip_cmd + ["install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True)
# Check for Node.js deps
if (PROJECT_ROOT / "package.json").exists():

View file

@ -473,7 +473,7 @@ def provider_model_ids(provider: Optional[str]) -> list[str]:
from hermes_cli.auth import fetch_nous_models, resolve_nous_runtime_credentials
creds = resolve_nous_runtime_credentials()
if creds:
live = fetch_nous_models(creds.get("api_key", ""), creds.get("base_url", ""))
live = fetch_nous_models(api_key=creds.get("api_key", ""), inference_base_url=creds.get("base_url", ""))
if live:
return live
except Exception:

View file

@ -444,11 +444,11 @@ def _print_setup_summary(config: dict, hermes_home):
else:
tool_status.append(("Mixture of Agents", False, "OPENROUTER_API_KEY"))
# Firecrawl (web tools)
if get_env_value("FIRECRAWL_API_KEY") or get_env_value("FIRECRAWL_API_URL"):
# Web tools (Parallel, Firecrawl, or Tavily)
if get_env_value("PARALLEL_API_KEY") or get_env_value("FIRECRAWL_API_KEY") or get_env_value("FIRECRAWL_API_URL") or get_env_value("TAVILY_API_KEY"):
tool_status.append(("Web Search & Extract", True, None))
else:
tool_status.append(("Web Search & Extract", False, "FIRECRAWL_API_KEY"))
tool_status.append(("Web Search & Extract", False, "PARALLEL_API_KEY, FIRECRAWL_API_KEY, or TAVILY_API_KEY"))
# Browser tools (local Chromium or Browserbase cloud)
import shutil
@ -738,6 +738,7 @@ def setup_model_provider(config: dict):
"Kilo Code (Kilo Gateway API)",
"Anthropic (Claude models — API key or Claude Code subscription)",
"AI Gateway (Vercel — 200+ models, pay-per-use)",
"Alibaba Cloud / DashScope (Qwen models via Anthropic-compatible API)",
"OpenCode Zen (35+ curated models, pay-as-you-go)",
"OpenCode Go (open models, $10/month subscription)",
]
@ -1313,7 +1314,39 @@ def setup_model_provider(config: dict):
_update_config_for_provider("ai-gateway", pconfig.inference_base_url, default_model="anthropic/claude-opus-4.6")
_set_model_provider(config, "ai-gateway", pconfig.inference_base_url)
elif provider_idx == 11: # OpenCode Zen
elif provider_idx == 11: # Alibaba Cloud / DashScope
selected_provider = "alibaba"
print()
print_header("Alibaba Cloud / DashScope API Key")
pconfig = PROVIDER_REGISTRY["alibaba"]
print_info(f"Provider: {pconfig.name}")
print_info("Get your API key at: https://modelstudio.console.alibabacloud.com/")
print()
existing_key = get_env_value("DASHSCOPE_API_KEY")
if existing_key:
print_info(f"Current: {existing_key[:8]}... (configured)")
if prompt_yes_no("Update API key?", False):
new_key = prompt(" DashScope API key", password=True)
if new_key:
save_env_value("DASHSCOPE_API_KEY", new_key)
print_success("DashScope API key updated")
else:
new_key = prompt(" DashScope API key", password=True)
if new_key:
save_env_value("DASHSCOPE_API_KEY", new_key)
print_success("DashScope API key saved")
else:
print_warning("Skipped - agent won't work without an API key")
# Clear custom endpoint vars if switching
if existing_custom:
save_env_value("OPENAI_BASE_URL", "")
save_env_value("OPENAI_API_KEY", "")
_update_config_for_provider("alibaba", pconfig.inference_base_url, default_model="qwen3.5-plus")
_set_model_provider(config, "alibaba", pconfig.inference_base_url)
elif provider_idx == 12: # OpenCode Zen
selected_provider = "opencode-zen"
print()
print_header("OpenCode Zen API Key")
@ -1346,7 +1379,7 @@ def setup_model_provider(config: dict):
_set_model_provider(config, "opencode-zen", pconfig.inference_base_url)
selected_base_url = pconfig.inference_base_url
elif provider_idx == 12: # OpenCode Go
elif provider_idx == 13: # OpenCode Go
selected_provider = "opencode-go"
print()
print_header("OpenCode Go API Key")
@ -1379,7 +1412,7 @@ def setup_model_provider(config: dict):
_set_model_provider(config, "opencode-go", pconfig.inference_base_url)
selected_base_url = pconfig.inference_base_url
# else: provider_idx == 13 (Keep current) — only shown when a provider already exists
# else: provider_idx == 14 (Keep current) — only shown when a provider already exists
# Normalize "keep current" to an explicit provider so downstream logic
# doesn't fall back to the generic OpenRouter/static-model path.
if selected_provider is None:
@ -2486,6 +2519,119 @@ def setup_gateway(config: dict):
" Set SLACK_ALLOW_ALL_USERS=true or GATEWAY_ALLOW_ALL_USERS=true only if you intentionally want open workspace access."
)
# ── Matrix ──
existing_matrix = get_env_value("MATRIX_ACCESS_TOKEN") or get_env_value("MATRIX_PASSWORD")
if existing_matrix:
print_info("Matrix: already configured")
if prompt_yes_no("Reconfigure Matrix?", False):
existing_matrix = None
if not existing_matrix and prompt_yes_no("Set up Matrix?", False):
print_info("Works with any Matrix homeserver (Synapse, Conduit, Dendrite, or matrix.org).")
print_info(" 1. Create a bot user on your homeserver, or use your own account")
print_info(" 2. Get an access token from Element, or provide user ID + password")
print()
homeserver = prompt("Homeserver URL (e.g. https://matrix.example.org)")
if homeserver:
save_env_value("MATRIX_HOMESERVER", homeserver.rstrip("/"))
print()
print_info("Auth: provide an access token (recommended), or user ID + password.")
token = prompt("Access token (leave empty for password login)", password=True)
if token:
save_env_value("MATRIX_ACCESS_TOKEN", token)
user_id = prompt("User ID (@bot:server — optional, will be auto-detected)")
if user_id:
save_env_value("MATRIX_USER_ID", user_id)
print_success("Matrix access token saved")
else:
user_id = prompt("User ID (@bot:server)")
if user_id:
save_env_value("MATRIX_USER_ID", user_id)
password = prompt("Password", password=True)
if password:
save_env_value("MATRIX_PASSWORD", password)
print_success("Matrix credentials saved")
if token or get_env_value("MATRIX_PASSWORD"):
# E2EE
print()
if prompt_yes_no("Enable end-to-end encryption (E2EE)?", False):
save_env_value("MATRIX_ENCRYPTION", "true")
print_success("E2EE enabled")
print_info(" Requires: pip install 'matrix-nio[e2e]'")
# Allowed users
print()
print_info("🔒 Security: Restrict who can use your bot")
print_info(" Matrix user IDs look like @username:server")
print()
allowed_users = prompt(
"Allowed user IDs (comma-separated, leave empty for open access)"
)
if allowed_users:
save_env_value("MATRIX_ALLOWED_USERS", allowed_users.replace(" ", ""))
print_success("Matrix allowlist configured")
else:
print_info(
"⚠️ No allowlist set - anyone who can message the bot can use it!"
)
# Home room
print()
print_info("📬 Home Room: where Hermes delivers cron job results and notifications.")
print_info(" Room IDs look like !abc123:server (shown in Element room settings)")
print_info(" You can also set this later by typing /set-home in a Matrix room.")
home_room = prompt("Home room ID (leave empty to set later with /set-home)")
if home_room:
save_env_value("MATRIX_HOME_ROOM", home_room)
# ── Mattermost ──
existing_mattermost = get_env_value("MATTERMOST_TOKEN")
if existing_mattermost:
print_info("Mattermost: already configured")
if prompt_yes_no("Reconfigure Mattermost?", False):
existing_mattermost = None
if not existing_mattermost and prompt_yes_no("Set up Mattermost?", False):
print_info("Works with any self-hosted Mattermost instance.")
print_info(" 1. In Mattermost: Integrations → Bot Accounts → Add Bot Account")
print_info(" 2. Copy the bot token")
print()
mm_url = prompt("Mattermost server URL (e.g. https://mm.example.com)")
if mm_url:
save_env_value("MATTERMOST_URL", mm_url.rstrip("/"))
token = prompt("Bot token", password=True)
if token:
save_env_value("MATTERMOST_TOKEN", token)
print_success("Mattermost token saved")
# Allowed users
print()
print_info("🔒 Security: Restrict who can use your bot")
print_info(" To find your user ID: click your avatar → Profile")
print_info(" or use the API: GET /api/v4/users/me")
print()
allowed_users = prompt(
"Allowed user IDs (comma-separated, leave empty for open access)"
)
if allowed_users:
save_env_value("MATTERMOST_ALLOWED_USERS", allowed_users.replace(" ", ""))
print_success("Mattermost allowlist configured")
else:
print_info(
"⚠️ No allowlist set - anyone who can message the bot can use it!"
)
# Home channel
print()
print_info("📬 Home Channel: where Hermes delivers cron job results and notifications.")
print_info(" To get a channel ID: click channel name → View Info → copy the ID")
print_info(" You can also set this later by typing /set-home in a Mattermost channel.")
home_channel = prompt("Home channel ID (leave empty to set later with /set-home)")
if home_channel:
save_env_value("MATTERMOST_HOME_CHANNEL", home_channel)
# ── WhatsApp ──
existing_whatsapp = get_env_value("WHATSAPP_ENABLED")
if not existing_whatsapp and prompt_yes_no("Set up WhatsApp?", False):
@ -2503,6 +2649,9 @@ def setup_gateway(config: dict):
get_env_value("TELEGRAM_BOT_TOKEN")
or get_env_value("DISCORD_BOT_TOKEN")
or get_env_value("SLACK_BOT_TOKEN")
or get_env_value("MATTERMOST_TOKEN")
or get_env_value("MATRIX_ACCESS_TOKEN")
or get_env_value("MATRIX_PASSWORD")
or get_env_value("WHATSAPP_ENABLED")
)
if any_messaging:

View file

@ -120,6 +120,7 @@ def show_status(args):
"MiniMax": "MINIMAX_API_KEY",
"MiniMax-CN": "MINIMAX_CN_API_KEY",
"Firecrawl": "FIRECRAWL_API_KEY",
"Tavily": "TAVILY_API_KEY",
"Browserbase": "BROWSERBASE_API_KEY", # Optional — local browser works without this
"FAL": "FAL_KEY",
"Tinker": "TINKER_API_KEY",
@ -252,7 +253,7 @@ def show_status(args):
"Signal": ("SIGNAL_HTTP_URL", "SIGNAL_HOME_CHANNEL"),
"Slack": ("SLACK_BOT_TOKEN", None),
"Email": ("EMAIL_ADDRESS", "EMAIL_HOME_ADDRESS"),
"SMS": ("TELNYX_API_KEY", "SMS_HOME_CHANNEL"),
"SMS": ("TWILIO_ACCOUNT_SID", "SMS_HOME_CHANNEL"),
}
for name, (token_var, home_var) in platforms.items():

View file

@ -110,6 +110,7 @@ PLATFORMS = {
"whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"},
"signal": {"label": "📡 Signal", "default_toolset": "hermes-signal"},
"email": {"label": "📧 Email", "default_toolset": "hermes-email"},
"dingtalk": {"label": "💬 DingTalk", "default_toolset": "hermes-dingtalk"},
}
@ -150,19 +151,37 @@ TOOL_CATEGORIES = {
"web": {
"name": "Web Search & Extract",
"setup_title": "Select Search Provider",
"setup_note": "A free DuckDuckGo search skill is also included — skip this if you don't need Firecrawl.",
"setup_note": "A free DuckDuckGo search skill is also included — skip this if you don't need a premium provider.",
"icon": "🔍",
"providers": [
{
"name": "Firecrawl Cloud",
"tag": "Recommended - hosted service",
"tag": "Hosted service - search, extract, and crawl",
"web_backend": "firecrawl",
"env_vars": [
{"key": "FIRECRAWL_API_KEY", "prompt": "Firecrawl API key", "url": "https://firecrawl.dev"},
],
},
{
"name": "Parallel",
"tag": "AI-native search and extract",
"web_backend": "parallel",
"env_vars": [
{"key": "PARALLEL_API_KEY", "prompt": "Parallel API key", "url": "https://parallel.ai"},
],
},
{
"name": "Tavily",
"tag": "AI-native search, extract, and crawl",
"web_backend": "tavily",
"env_vars": [
{"key": "TAVILY_API_KEY", "prompt": "Tavily API key", "url": "https://app.tavily.com/home"},
],
},
{
"name": "Firecrawl Self-Hosted",
"tag": "Free - run your own instance",
"web_backend": "firecrawl",
"env_vars": [
{"key": "FIRECRAWL_API_URL", "prompt": "Your Firecrawl instance URL (e.g., http://localhost:3002)"},
],
@ -617,6 +636,9 @@ def _is_provider_active(provider: dict, config: dict) -> bool:
if "browser_provider" in provider:
current = config.get("browser", {}).get("cloud_provider")
return provider["browser_provider"] == current
if provider.get("web_backend"):
current = config.get("web", {}).get("backend")
return current == provider["web_backend"]
return False
@ -649,6 +671,11 @@ def _configure_provider(provider: dict, config: dict):
else:
config.get("browser", {}).pop("cloud_provider", None)
# Set web search backend in config if applicable
if provider.get("web_backend"):
config.setdefault("web", {})["backend"] = provider["web_backend"]
_print_success(f" Web backend set to: {provider['web_backend']}")
if not env_vars:
_print_success(f" {provider['name']} - no configuration needed!")
return
@ -832,6 +859,11 @@ def _reconfigure_provider(provider: dict, config: dict):
config.get("browser", {}).pop("cloud_provider", None)
_print_success(f" Browser set to local mode")
# Set web search backend in config if applicable
if provider.get("web_backend"):
config.setdefault("web", {})["backend"] = provider["web_backend"]
_print_success(f" Web backend set to: {provider['web_backend']}")
if not env_vars:
_print_success(f" {provider['name']} - no configuration needed!")
return
@ -984,12 +1016,19 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
if len(platform_keys) > 1:
platform_choices.append("Configure all platforms (global)")
platform_choices.append("Reconfigure an existing tool's provider or API key")
# Show MCP option if any MCP servers are configured
_has_mcp = bool(config.get("mcp_servers"))
if _has_mcp:
platform_choices.append("Configure MCP server tools")
platform_choices.append("Done")
# Index offsets for the extra options after per-platform entries
_global_idx = len(platform_keys) if len(platform_keys) > 1 else -1
_reconfig_idx = len(platform_keys) + (1 if len(platform_keys) > 1 else 0)
_done_idx = _reconfig_idx + 1
_mcp_idx = (_reconfig_idx + 1) if _has_mcp else -1
_done_idx = _reconfig_idx + (2 if _has_mcp else 1)
while True:
idx = _prompt_choice("Select an option:", platform_choices, default=0)
@ -1004,6 +1043,12 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
print()
continue
# "Configure MCP tools" selected
if idx == _mcp_idx:
_configure_mcp_tools_interactive(config)
print()
continue
# "Configure all platforms (global)" selected
if idx == _global_idx:
# Use the union of all platforms' current tools as the starting state
@ -1090,6 +1135,137 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
print()
# ─── MCP Tools Interactive Configuration ─────────────────────────────────────
def _configure_mcp_tools_interactive(config: dict):
"""Probe MCP servers for available tools and let user toggle them on/off.
Connects to each configured MCP server, discovers tools, then shows
a per-server curses checklist. Writes changes back as ``tools.exclude``
entries in config.yaml.
"""
from hermes_cli.curses_ui import curses_checklist
mcp_servers = config.get("mcp_servers") or {}
if not mcp_servers:
_print_info("No MCP servers configured.")
return
# Count enabled servers
enabled_names = [
k for k, v in mcp_servers.items()
if v.get("enabled", True) not in (False, "false", "0", "no", "off")
]
if not enabled_names:
_print_info("All MCP servers are disabled.")
return
print()
print(color(" Discovering tools from MCP servers...", Colors.YELLOW))
print(color(f" Connecting to {len(enabled_names)} server(s): {', '.join(enabled_names)}", Colors.DIM))
try:
from tools.mcp_tool import probe_mcp_server_tools
server_tools = probe_mcp_server_tools()
except Exception as exc:
_print_error(f"Failed to probe MCP servers: {exc}")
return
if not server_tools:
_print_warning("Could not discover tools from any MCP server.")
_print_info("Check that server commands/URLs are correct and dependencies are installed.")
return
# Report discovery results
failed = [n for n in enabled_names if n not in server_tools]
if failed:
for name in failed:
_print_warning(f" Could not connect to '{name}'")
total_tools = sum(len(tools) for tools in server_tools.values())
print(color(f" Found {total_tools} tool(s) across {len(server_tools)} server(s)", Colors.GREEN))
print()
any_changes = False
for server_name, tools in server_tools.items():
if not tools:
_print_info(f" {server_name}: no tools found")
continue
srv_cfg = mcp_servers.get(server_name, {})
tools_cfg = srv_cfg.get("tools") or {}
include_list = tools_cfg.get("include") or []
exclude_list = tools_cfg.get("exclude") or []
# Build checklist labels
labels = []
for tool_name, description in tools:
desc_short = description[:70] + "..." if len(description) > 70 else description
if desc_short:
labels.append(f"{tool_name} ({desc_short})")
else:
labels.append(tool_name)
# Determine which tools are currently enabled
pre_selected: Set[int] = set()
tool_names = [t[0] for t in tools]
for i, tool_name in enumerate(tool_names):
if include_list:
# Include mode: only included tools are selected
if tool_name in include_list:
pre_selected.add(i)
elif exclude_list:
# Exclude mode: everything except excluded
if tool_name not in exclude_list:
pre_selected.add(i)
else:
# No filter: all enabled
pre_selected.add(i)
chosen = curses_checklist(
f"MCP Server: {server_name} ({len(tools)} tools)",
labels,
pre_selected,
cancel_returns=pre_selected,
)
if chosen == pre_selected:
_print_info(f" {server_name}: no changes")
continue
# Compute new exclude list based on unchecked tools
new_exclude = [tool_names[i] for i in range(len(tool_names)) if i not in chosen]
# Update config
srv_cfg = mcp_servers.setdefault(server_name, {})
tools_cfg = srv_cfg.setdefault("tools", {})
if new_exclude:
tools_cfg["exclude"] = new_exclude
# Remove include if present — we're switching to exclude mode
tools_cfg.pop("include", None)
else:
# All tools enabled — clear filters
tools_cfg.pop("exclude", None)
tools_cfg.pop("include", None)
enabled_count = len(chosen)
disabled_count = len(tools) - enabled_count
_print_success(
f" {server_name}: {enabled_count} enabled, {disabled_count} disabled"
)
any_changes = True
if any_changes:
save_config(config)
print()
print(color(" ✓ MCP tool configuration saved", Colors.GREEN))
else:
print(color(" No changes to MCP tools", Colors.DIM))
# ─── Non-interactive disable/enable ──────────────────────────────────────────

View file

@ -18,6 +18,7 @@ import json
import os
import re
import sqlite3
import threading
import time
from pathlib import Path
from typing import Dict, Any, List, Optional
@ -25,7 +26,7 @@ from typing import Dict, Any, List, Optional
DEFAULT_DB_PATH = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "state.db"
SCHEMA_VERSION = 4
SCHEMA_VERSION = 5
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS schema_version (
@ -47,6 +48,17 @@ CREATE TABLE IF NOT EXISTS sessions (
tool_call_count INTEGER DEFAULT 0,
input_tokens INTEGER DEFAULT 0,
output_tokens INTEGER DEFAULT 0,
cache_read_tokens INTEGER DEFAULT 0,
cache_write_tokens INTEGER DEFAULT 0,
reasoning_tokens INTEGER DEFAULT 0,
billing_provider TEXT,
billing_base_url TEXT,
billing_mode TEXT,
estimated_cost_usd REAL,
actual_cost_usd REAL,
cost_status TEXT,
cost_source TEXT,
pricing_version TEXT,
title TEXT,
FOREIGN KEY (parent_session_id) REFERENCES sessions(id)
);
@ -104,6 +116,7 @@ class SessionDB:
self.db_path = db_path or DEFAULT_DB_PATH
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self._lock = threading.Lock()
self._conn = sqlite3.connect(
str(self.db_path),
check_same_thread=False,
@ -152,6 +165,26 @@ class SessionDB:
except sqlite3.OperationalError:
pass # Index already exists
cursor.execute("UPDATE schema_version SET version = 4")
if current_version < 5:
new_columns = [
("cache_read_tokens", "INTEGER DEFAULT 0"),
("cache_write_tokens", "INTEGER DEFAULT 0"),
("reasoning_tokens", "INTEGER DEFAULT 0"),
("billing_provider", "TEXT"),
("billing_base_url", "TEXT"),
("billing_mode", "TEXT"),
("estimated_cost_usd", "REAL"),
("actual_cost_usd", "REAL"),
("cost_status", "TEXT"),
("cost_source", "TEXT"),
("pricing_version", "TEXT"),
]
for name, column_type in new_columns:
try:
cursor.execute(f"ALTER TABLE sessions ADD COLUMN {name} {column_type}")
except sqlite3.OperationalError:
pass
cursor.execute("UPDATE schema_version SET version = 5")
# Unique title index — always ensure it exists (safe to run after migrations
# since the title column is guaranteed to exist at this point)
@ -173,9 +206,10 @@ class SessionDB:
def close(self):
"""Close the database connection."""
if self._conn:
self._conn.close()
self._conn = None
with self._lock:
if self._conn:
self._conn.close()
self._conn = None
# =========================================================================
# Session lifecycle
@ -192,61 +226,111 @@ class SessionDB:
parent_session_id: str = None,
) -> str:
"""Create a new session record. Returns the session_id."""
self._conn.execute(
"""INSERT INTO sessions (id, source, user_id, model, model_config,
system_prompt, parent_session_id, started_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
(
session_id,
source,
user_id,
model,
json.dumps(model_config) if model_config else None,
system_prompt,
parent_session_id,
time.time(),
),
)
self._conn.commit()
with self._lock:
self._conn.execute(
"""INSERT INTO sessions (id, source, user_id, model, model_config,
system_prompt, parent_session_id, started_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
(
session_id,
source,
user_id,
model,
json.dumps(model_config) if model_config else None,
system_prompt,
parent_session_id,
time.time(),
),
)
self._conn.commit()
return session_id
def end_session(self, session_id: str, end_reason: str) -> None:
"""Mark a session as ended."""
self._conn.execute(
"UPDATE sessions SET ended_at = ?, end_reason = ? WHERE id = ?",
(time.time(), end_reason, session_id),
)
self._conn.commit()
with self._lock:
self._conn.execute(
"UPDATE sessions SET ended_at = ?, end_reason = ? WHERE id = ?",
(time.time(), end_reason, session_id),
)
self._conn.commit()
def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
"""Store the full assembled system prompt snapshot."""
self._conn.execute(
"UPDATE sessions SET system_prompt = ? WHERE id = ?",
(system_prompt, session_id),
)
self._conn.commit()
with self._lock:
self._conn.execute(
"UPDATE sessions SET system_prompt = ? WHERE id = ?",
(system_prompt, session_id),
)
self._conn.commit()
def update_token_counts(
self, session_id: str, input_tokens: int = 0, output_tokens: int = 0,
self,
session_id: str,
input_tokens: int = 0,
output_tokens: int = 0,
model: str = None,
cache_read_tokens: int = 0,
cache_write_tokens: int = 0,
reasoning_tokens: int = 0,
estimated_cost_usd: Optional[float] = None,
actual_cost_usd: Optional[float] = None,
cost_status: Optional[str] = None,
cost_source: Optional[str] = None,
pricing_version: Optional[str] = None,
billing_provider: Optional[str] = None,
billing_base_url: Optional[str] = None,
billing_mode: Optional[str] = None,
) -> None:
"""Increment token counters and backfill model if not already set."""
self._conn.execute(
"""UPDATE sessions SET
input_tokens = input_tokens + ?,
output_tokens = output_tokens + ?,
model = COALESCE(model, ?)
WHERE id = ?""",
(input_tokens, output_tokens, model, session_id),
)
self._conn.commit()
with self._lock:
self._conn.execute(
"""UPDATE sessions SET
input_tokens = input_tokens + ?,
output_tokens = output_tokens + ?,
cache_read_tokens = cache_read_tokens + ?,
cache_write_tokens = cache_write_tokens + ?,
reasoning_tokens = reasoning_tokens + ?,
estimated_cost_usd = COALESCE(estimated_cost_usd, 0) + COALESCE(?, 0),
actual_cost_usd = CASE
WHEN ? IS NULL THEN actual_cost_usd
ELSE COALESCE(actual_cost_usd, 0) + ?
END,
cost_status = COALESCE(?, cost_status),
cost_source = COALESCE(?, cost_source),
pricing_version = COALESCE(?, pricing_version),
billing_provider = COALESCE(billing_provider, ?),
billing_base_url = COALESCE(billing_base_url, ?),
billing_mode = COALESCE(billing_mode, ?),
model = COALESCE(model, ?)
WHERE id = ?""",
(
input_tokens,
output_tokens,
cache_read_tokens,
cache_write_tokens,
reasoning_tokens,
estimated_cost_usd,
actual_cost_usd,
actual_cost_usd,
cost_status,
cost_source,
pricing_version,
billing_provider,
billing_base_url,
billing_mode,
model,
session_id,
),
)
self._conn.commit()
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
"""Get a session by ID."""
cursor = self._conn.execute(
"SELECT * FROM sessions WHERE id = ?", (session_id,)
)
row = cursor.fetchone()
with self._lock:
cursor = self._conn.execute(
"SELECT * FROM sessions WHERE id = ?", (session_id,)
)
row = cursor.fetchone()
return dict(row) if row else None
def resolve_session_id(self, session_id_or_prefix: str) -> Optional[str]:
@ -331,38 +415,42 @@ class SessionDB:
Empty/whitespace-only strings are normalized to None (clearing the title).
"""
title = self.sanitize_title(title)
if title:
# Check uniqueness (allow the same session to keep its own title)
with self._lock:
if title:
# Check uniqueness (allow the same session to keep its own title)
cursor = self._conn.execute(
"SELECT id FROM sessions WHERE title = ? AND id != ?",
(title, session_id),
)
conflict = cursor.fetchone()
if conflict:
raise ValueError(
f"Title '{title}' is already in use by session {conflict['id']}"
)
cursor = self._conn.execute(
"SELECT id FROM sessions WHERE title = ? AND id != ?",
"UPDATE sessions SET title = ? WHERE id = ?",
(title, session_id),
)
conflict = cursor.fetchone()
if conflict:
raise ValueError(
f"Title '{title}' is already in use by session {conflict['id']}"
)
cursor = self._conn.execute(
"UPDATE sessions SET title = ? WHERE id = ?",
(title, session_id),
)
self._conn.commit()
return cursor.rowcount > 0
self._conn.commit()
rowcount = cursor.rowcount
return rowcount > 0
def get_session_title(self, session_id: str) -> Optional[str]:
"""Get the title for a session, or None."""
cursor = self._conn.execute(
"SELECT title FROM sessions WHERE id = ?", (session_id,)
)
row = cursor.fetchone()
with self._lock:
cursor = self._conn.execute(
"SELECT title FROM sessions WHERE id = ?", (session_id,)
)
row = cursor.fetchone()
return row["title"] if row else None
def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]:
"""Look up a session by exact title. Returns session dict or None."""
cursor = self._conn.execute(
"SELECT * FROM sessions WHERE title = ?", (title,)
)
row = cursor.fetchone()
with self._lock:
cursor = self._conn.execute(
"SELECT * FROM sessions WHERE title = ?", (title,)
)
row = cursor.fetchone()
return dict(row) if row else None
def resolve_session_by_title(self, title: str) -> Optional[str]:
@ -379,12 +467,13 @@ class SessionDB:
# Also search for numbered variants: "title #2", "title #3", etc.
# Escape SQL LIKE wildcards (%, _) in the title to prevent false matches
escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
cursor = self._conn.execute(
"SELECT id, title, started_at FROM sessions "
"WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC",
(f"{escaped} #%",),
)
numbered = cursor.fetchall()
with self._lock:
cursor = self._conn.execute(
"SELECT id, title, started_at FROM sessions "
"WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC",
(f"{escaped} #%",),
)
numbered = cursor.fetchall()
if numbered:
# Return the most recent numbered variant
@ -409,11 +498,12 @@ class SessionDB:
# Find all existing numbered variants
# Escape SQL LIKE wildcards (%, _) in the base to prevent false matches
escaped = base.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
cursor = self._conn.execute(
"SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'",
(base, f"{escaped} #%"),
)
existing = [row["title"] for row in cursor.fetchall()]
with self._lock:
cursor = self._conn.execute(
"SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'",
(base, f"{escaped} #%"),
)
existing = [row["title"] for row in cursor.fetchall()]
if not existing:
return base # No conflict, use the base name as-is
@ -461,9 +551,11 @@ class SessionDB:
LIMIT ? OFFSET ?
"""
params = (source, limit, offset) if source else (limit, offset)
cursor = self._conn.execute(query, params)
with self._lock:
cursor = self._conn.execute(query, params)
rows = cursor.fetchall()
sessions = []
for row in cursor.fetchall():
for row in rows:
s = dict(row)
# Build the preview from the raw substring
raw = s.pop("_preview_raw", "").strip()
@ -497,52 +589,54 @@ class SessionDB:
Also increments the session's message_count (and tool_call_count
if role is 'tool' or tool_calls is present).
"""
cursor = self._conn.execute(
"""INSERT INTO messages (session_id, role, content, tool_call_id,
tool_calls, tool_name, timestamp, token_count, finish_reason)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
session_id,
role,
content,
tool_call_id,
json.dumps(tool_calls) if tool_calls else None,
tool_name,
time.time(),
token_count,
finish_reason,
),
)
msg_id = cursor.lastrowid
# Update counters
# Count actual tool calls from the tool_calls list (not from tool responses).
# A single assistant message can contain multiple parallel tool calls.
num_tool_calls = 0
if tool_calls is not None:
num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1
if num_tool_calls > 0:
self._conn.execute(
"""UPDATE sessions SET message_count = message_count + 1,
tool_call_count = tool_call_count + ? WHERE id = ?""",
(num_tool_calls, session_id),
)
else:
self._conn.execute(
"UPDATE sessions SET message_count = message_count + 1 WHERE id = ?",
(session_id,),
with self._lock:
cursor = self._conn.execute(
"""INSERT INTO messages (session_id, role, content, tool_call_id,
tool_calls, tool_name, timestamp, token_count, finish_reason)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
session_id,
role,
content,
tool_call_id,
json.dumps(tool_calls) if tool_calls else None,
tool_name,
time.time(),
token_count,
finish_reason,
),
)
msg_id = cursor.lastrowid
self._conn.commit()
# Update counters
# Count actual tool calls from the tool_calls list (not from tool responses).
# A single assistant message can contain multiple parallel tool calls.
num_tool_calls = 0
if tool_calls is not None:
num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1
if num_tool_calls > 0:
self._conn.execute(
"""UPDATE sessions SET message_count = message_count + 1,
tool_call_count = tool_call_count + ? WHERE id = ?""",
(num_tool_calls, session_id),
)
else:
self._conn.execute(
"UPDATE sessions SET message_count = message_count + 1 WHERE id = ?",
(session_id,),
)
self._conn.commit()
return msg_id
def get_messages(self, session_id: str) -> List[Dict[str, Any]]:
"""Load all messages for a session, ordered by timestamp."""
cursor = self._conn.execute(
"SELECT * FROM messages WHERE session_id = ? ORDER BY timestamp, id",
(session_id,),
)
rows = cursor.fetchall()
with self._lock:
cursor = self._conn.execute(
"SELECT * FROM messages WHERE session_id = ? ORDER BY timestamp, id",
(session_id,),
)
rows = cursor.fetchall()
result = []
for row in rows:
msg = dict(row)
@ -559,13 +653,15 @@ class SessionDB:
Load messages in the OpenAI conversation format (role + content dicts).
Used by the gateway to restore conversation history.
"""
cursor = self._conn.execute(
"SELECT role, content, tool_call_id, tool_calls, tool_name "
"FROM messages WHERE session_id = ? ORDER BY timestamp, id",
(session_id,),
)
with self._lock:
cursor = self._conn.execute(
"SELECT role, content, tool_call_id, tool_calls, tool_name "
"FROM messages WHERE session_id = ? ORDER BY timestamp, id",
(session_id,),
)
rows = cursor.fetchall()
messages = []
for row in cursor.fetchall():
for row in rows:
msg = {"role": row["role"], "content": row["content"]}
if row["tool_call_id"]:
msg["tool_call_id"] = row["tool_call_id"]
@ -675,31 +771,33 @@ class SessionDB:
LIMIT ? OFFSET ?
"""
try:
cursor = self._conn.execute(sql, params)
except sqlite3.OperationalError:
# FTS5 query syntax error despite sanitization — return empty
return []
matches = [dict(row) for row in cursor.fetchall()]
# Add surrounding context (1 message before + after each match)
for match in matches:
with self._lock:
try:
ctx_cursor = self._conn.execute(
"""SELECT role, content FROM messages
WHERE session_id = ? AND id >= ? - 1 AND id <= ? + 1
ORDER BY id""",
(match["session_id"], match["id"], match["id"]),
)
context_msgs = [
{"role": r["role"], "content": (r["content"] or "")[:200]}
for r in ctx_cursor.fetchall()
]
match["context"] = context_msgs
except Exception:
match["context"] = []
cursor = self._conn.execute(sql, params)
except sqlite3.OperationalError:
# FTS5 query syntax error despite sanitization — return empty
return []
matches = [dict(row) for row in cursor.fetchall()]
# Remove full content from result (snippet is enough, saves tokens)
# Add surrounding context (1 message before + after each match)
for match in matches:
try:
ctx_cursor = self._conn.execute(
"""SELECT role, content FROM messages
WHERE session_id = ? AND id >= ? - 1 AND id <= ? + 1
ORDER BY id""",
(match["session_id"], match["id"], match["id"]),
)
context_msgs = [
{"role": r["role"], "content": (r["content"] or "")[:200]}
for r in ctx_cursor.fetchall()
]
match["context"] = context_msgs
except Exception:
match["context"] = []
# Remove full content from result (snippet is enough, saves tokens)
for match in matches:
match.pop("content", None)
return matches
@ -711,17 +809,18 @@ class SessionDB:
offset: int = 0,
) -> List[Dict[str, Any]]:
"""List sessions, optionally filtered by source."""
if source:
cursor = self._conn.execute(
"SELECT * FROM sessions WHERE source = ? ORDER BY started_at DESC LIMIT ? OFFSET ?",
(source, limit, offset),
)
else:
cursor = self._conn.execute(
"SELECT * FROM sessions ORDER BY started_at DESC LIMIT ? OFFSET ?",
(limit, offset),
)
return [dict(row) for row in cursor.fetchall()]
with self._lock:
if source:
cursor = self._conn.execute(
"SELECT * FROM sessions WHERE source = ? ORDER BY started_at DESC LIMIT ? OFFSET ?",
(source, limit, offset),
)
else:
cursor = self._conn.execute(
"SELECT * FROM sessions ORDER BY started_at DESC LIMIT ? OFFSET ?",
(limit, offset),
)
return [dict(row) for row in cursor.fetchall()]
# =========================================================================
# Utility
@ -773,26 +872,28 @@ class SessionDB:
def clear_messages(self, session_id: str) -> None:
"""Delete all messages for a session and reset its counters."""
self._conn.execute(
"DELETE FROM messages WHERE session_id = ?", (session_id,)
)
self._conn.execute(
"UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?",
(session_id,),
)
self._conn.commit()
with self._lock:
self._conn.execute(
"DELETE FROM messages WHERE session_id = ?", (session_id,)
)
self._conn.execute(
"UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?",
(session_id,),
)
self._conn.commit()
def delete_session(self, session_id: str) -> bool:
"""Delete a session and all its messages. Returns True if found."""
cursor = self._conn.execute(
"SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,)
)
if cursor.fetchone()[0] == 0:
return False
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
self._conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
self._conn.commit()
return True
with self._lock:
cursor = self._conn.execute(
"SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,)
)
if cursor.fetchone()[0] == 0:
return False
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
self._conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
self._conn.commit()
return True
def prune_sessions(self, older_than_days: int = 90, source: str = None) -> int:
"""
@ -802,22 +903,23 @@ class SessionDB:
import time as _time
cutoff = _time.time() - (older_than_days * 86400)
if source:
cursor = self._conn.execute(
"""SELECT id FROM sessions
WHERE started_at < ? AND ended_at IS NOT NULL AND source = ?""",
(cutoff, source),
)
else:
cursor = self._conn.execute(
"SELECT id FROM sessions WHERE started_at < ? AND ended_at IS NOT NULL",
(cutoff,),
)
session_ids = [row["id"] for row in cursor.fetchall()]
with self._lock:
if source:
cursor = self._conn.execute(
"""SELECT id FROM sessions
WHERE started_at < ? AND ended_at IS NOT NULL AND source = ?""",
(cutoff, source),
)
else:
cursor = self._conn.execute(
"SELECT id FROM sessions WHERE started_at < ? AND ended_at IS NOT NULL",
(cutoff,),
)
session_ids = [row["id"] for row in cursor.fetchall()]
for sid in session_ids:
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,))
self._conn.execute("DELETE FROM sessions WHERE id = ?", (sid,))
for sid in session_ids:
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,))
self._conn.execute("DELETE FROM sessions WHERE id = ?", (sid,))
self._conn.commit()
self._conn.commit()
return len(session_ids)

View file

@ -69,6 +69,8 @@ class HonchoClientConfig:
workspace_id: str = "hermes"
api_key: str | None = None
environment: str = "production"
# Optional base URL for self-hosted Honcho (overrides environment mapping)
base_url: str | None = None
# Identity
peer_name: str | None = None
ai_peer: str = "hermes"
@ -361,13 +363,34 @@ def get_honcho_client(config: HonchoClientConfig | None = None) -> Honcho:
"Install it with: pip install honcho-ai"
)
logger.info("Initializing Honcho client (host: %s, workspace: %s)", config.host, config.workspace_id)
# Allow config.yaml honcho.base_url to override the SDK's environment
# mapping, enabling remote self-hosted Honcho deployments without
# requiring the server to live on localhost.
resolved_base_url = config.base_url
if not resolved_base_url:
try:
from hermes_cli.config import load_config
hermes_cfg = load_config()
honcho_cfg = hermes_cfg.get("honcho", {})
if isinstance(honcho_cfg, dict):
resolved_base_url = honcho_cfg.get("base_url", "").strip() or None
except Exception:
pass
_honcho_client = Honcho(
workspace_id=config.workspace_id,
api_key=config.api_key,
environment=config.environment,
)
if resolved_base_url:
logger.info("Initializing Honcho client (base_url: %s, workspace: %s)", resolved_base_url, config.workspace_id)
else:
logger.info("Initializing Honcho client (host: %s, workspace: %s)", config.host, config.workspace_id)
kwargs: dict = {
"workspace_id": config.workspace_id,
"api_key": config.api_key,
"environment": config.environment,
}
if resolved_base_url:
kwargs["base_url"] = resolved_base_url
_honcho_client = Honcho(**kwargs)
return _honcho_client

View file

@ -27,6 +27,7 @@ dependencies = [
"prompt_toolkit",
# Tools
"firecrawl-py",
"parallel-web>=0.4.2",
"fal-client",
# Text-to-speech (Edge TTS is free, no API key needed)
"edge-tts",
@ -46,6 +47,7 @@ dev = ["pytest", "pytest-asyncio", "pytest-xdist", "mcp>=1.2.0"]
messaging = ["python-telegram-bot>=20.0", "discord.py[voice]>=2.0", "aiohttp>=3.9.0", "slack-bolt>=1.18.0", "slack-sdk>=3.27.0"]
cron = ["croniter"]
slack = ["slack-bolt>=1.18.0", "slack-sdk>=3.27.0"]
matrix = ["matrix-nio[e2e]>=0.24.0"]
cli = ["simple-term-menu"]
tts-premium = ["elevenlabs"]
voice = ["sounddevice>=0.4.6", "numpy>=1.24.0"]
@ -79,9 +81,9 @@ all = [
"hermes-agent[honcho]",
"hermes-agent[mcp]",
"hermes-agent[homeassistant]",
"hermes-agent[sms]",
"hermes-agent[acp]",
"hermes-agent[voice]",
"hermes-agent[sms]",
]
[project.scripts]

View file

@ -18,6 +18,7 @@ PyJWT[crypto]
# Web tools
firecrawl-py
parallel-web>=0.4.2
# Image generation
fal-client

View file

@ -86,6 +86,7 @@ from agent.model_metadata import (
from agent.context_compressor import ContextCompressor
from agent.prompt_caching import apply_anthropic_cache_control
from agent.prompt_builder import build_skills_system_prompt, build_context_files_prompt
from agent.usage_pricing import estimate_usage_cost, normalize_usage
from agent.display import (
KawaiiSpinner, build_tool_preview as _build_tool_preview,
get_cute_tool_message as _get_cute_tool_message_impl,
@ -391,6 +392,15 @@ class AIAgent:
else:
self.api_mode = "chat_completions"
# Pre-warm OpenRouter model metadata cache in a background thread.
# fetch_model_metadata() is cached for 1 hour; this avoids a blocking
# HTTP request on the first API response when pricing is estimated.
if self.provider == "openrouter" or "openrouter" in self.base_url.lower():
threading.Thread(
target=lambda: fetch_model_metadata(),
daemon=True,
).start()
self.tool_progress_callback = tool_progress_callback
self.thinking_callback = thinking_callback
self.reasoning_callback = reasoning_callback
@ -407,6 +417,7 @@ class AIAgent:
# Subagent delegation state
self._delegate_depth = 0 # 0 = top-level agent, incremented for children
self._active_children = [] # Running child AIAgents (for interrupt propagation)
self._active_children_lock = threading.Lock()
# Store OpenRouter provider preferences
self.providers_allowed = providers_allowed
@ -456,8 +467,8 @@ class AIAgent:
and Path(getattr(handler, "baseFilename", "")).resolve() == resolved_error_log_path
for handler in root_logger.handlers
)
from agent.redact import RedactingFormatter
if not has_errors_log_handler:
from agent.redact import RedactingFormatter
error_log_dir.mkdir(parents=True, exist_ok=True)
error_file_handler = RotatingFileHandler(
error_log_path, maxBytes=2 * 1024 * 1024, backupCount=2,
@ -849,6 +860,14 @@ class AIAgent:
self.session_completion_tokens = 0
self.session_total_tokens = 0
self.session_api_calls = 0
self.session_input_tokens = 0
self.session_output_tokens = 0
self.session_cache_read_tokens = 0
self.session_cache_write_tokens = 0
self.session_reasoning_tokens = 0
self.session_estimated_cost_usd = 0.0
self.session_cost_status = "unknown"
self.session_cost_source = "none"
if not self.quiet_mode:
if compression_enabled:
@ -1526,7 +1545,9 @@ class AIAgent:
# Signal all tools to abort any in-flight operations immediately
_set_interrupt(True)
# Propagate interrupt to any running child agents (subagent delegation)
for child in self._active_children:
with self._active_children_lock:
children_copy = list(self._active_children)
for child in children_copy:
try:
child.interrupt(message)
except Exception as e:
@ -1936,7 +1957,124 @@ class AIAgent:
prompt_parts.append(PLATFORM_HINTS[platform_key])
return "\n\n".join(prompt_parts)
# =========================================================================
# Pre/post-call guardrails (inspired by PR #1321 — @alireza78a)
# =========================================================================
@staticmethod
def _get_tool_call_id_static(tc) -> str:
"""Extract call ID from a tool_call entry (dict or object)."""
if isinstance(tc, dict):
return tc.get("id", "") or ""
return getattr(tc, "id", "") or ""
@staticmethod
def _sanitize_api_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Fix orphaned tool_call / tool_result pairs before every LLM call.
Runs unconditionally not gated on whether the context compressor
is present so orphans from session loading or manual message
manipulation are always caught.
"""
surviving_call_ids: set = set()
for msg in messages:
if msg.get("role") == "assistant":
for tc in msg.get("tool_calls") or []:
cid = AIAgent._get_tool_call_id_static(tc)
if cid:
surviving_call_ids.add(cid)
result_call_ids: set = set()
for msg in messages:
if msg.get("role") == "tool":
cid = msg.get("tool_call_id")
if cid:
result_call_ids.add(cid)
# 1. Drop tool results with no matching assistant call
orphaned_results = result_call_ids - surviving_call_ids
if orphaned_results:
messages = [
m for m in messages
if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results)
]
logger.debug(
"Pre-call sanitizer: removed %d orphaned tool result(s)",
len(orphaned_results),
)
# 2. Inject stub results for calls whose result was dropped
missing_results = surviving_call_ids - result_call_ids
if missing_results:
patched: List[Dict[str, Any]] = []
for msg in messages:
patched.append(msg)
if msg.get("role") == "assistant":
for tc in msg.get("tool_calls") or []:
cid = AIAgent._get_tool_call_id_static(tc)
if cid in missing_results:
patched.append({
"role": "tool",
"content": "[Result unavailable — see context summary above]",
"tool_call_id": cid,
})
messages = patched
logger.debug(
"Pre-call sanitizer: added %d stub tool result(s)",
len(missing_results),
)
return messages
@staticmethod
def _cap_delegate_task_calls(tool_calls: list) -> list:
"""Truncate excess delegate_task calls to MAX_CONCURRENT_CHILDREN.
The delegate_tool caps the task list inside a single call, but the
model can emit multiple separate delegate_task tool_calls in one
turn. This truncates the excess, preserving all non-delegate calls.
Returns the original list if no truncation was needed.
"""
from tools.delegate_tool import MAX_CONCURRENT_CHILDREN
delegate_count = sum(1 for tc in tool_calls if tc.function.name == "delegate_task")
if delegate_count <= MAX_CONCURRENT_CHILDREN:
return tool_calls
kept_delegates = 0
truncated = []
for tc in tool_calls:
if tc.function.name == "delegate_task":
if kept_delegates < MAX_CONCURRENT_CHILDREN:
truncated.append(tc)
kept_delegates += 1
else:
truncated.append(tc)
logger.warning(
"Truncated %d excess delegate_task call(s) to enforce "
"MAX_CONCURRENT_CHILDREN=%d limit",
delegate_count - MAX_CONCURRENT_CHILDREN, MAX_CONCURRENT_CHILDREN,
)
return truncated
@staticmethod
def _deduplicate_tool_calls(tool_calls: list) -> list:
"""Remove duplicate (tool_name, arguments) pairs within a single turn.
Only the first occurrence of each unique pair is kept.
Returns the original list if no duplicates were found.
"""
seen: set = set()
unique: list = []
for tc in tool_calls:
key = (tc.function.name, tc.function.arguments)
if key not in seen:
seen.add(key)
unique.append(tc)
else:
logger.warning("Removed duplicate tool call: %s", tc.function.name)
return unique if len(unique) < len(tool_calls) else tool_calls
def _repair_tool_call(self, tool_name: str) -> str | None:
"""Attempt to repair a mismatched tool name before aborting.
@ -4863,6 +5001,7 @@ class AIAgent:
codex_ack_continuations = 0
length_continue_retries = 0
truncated_response_prefix = ""
compression_attempts = 0
# Clear any stale interrupt state at start
self.clear_interrupt()
@ -4970,11 +5109,10 @@ class AIAgent:
api_messages = apply_anthropic_cache_control(api_messages, cache_ttl=self._cache_ttl)
# Safety net: strip orphaned tool results / add stubs for missing
# results before sending to the API. The compressor handles this
# during compression, but orphans can also sneak in from session
# loading or manual message manipulation.
if hasattr(self, 'context_compressor') and self.context_compressor:
api_messages = self.context_compressor._sanitize_tool_pairs(api_messages)
# results before sending to the API. Runs unconditionally — not
# gated on context_compressor — so orphans from session loading or
# manual message manipulation are always caught.
api_messages = self._sanitize_api_messages(api_messages)
# Calculate approximate request size for logging
total_chars = sum(len(str(msg)) for msg in api_messages)
@ -5008,7 +5146,6 @@ class AIAgent:
api_start_time = time.time()
retry_count = 0
max_retries = 3
compression_attempts = 0
max_compression_attempts = 3
codex_auth_retry_attempted = False
anthropic_auth_retry_attempted = False
@ -5111,6 +5248,13 @@ class AIAgent:
# This is often rate limiting or provider returning malformed response
retry_count += 1
# Eager fallback: empty/malformed responses are a common
# rate-limit symptom. Switch to fallback immediately
# rather than retrying with extended backoff.
if not self._fallback_activated and self._try_activate_fallback():
retry_count = 0
continue
# Check for error field in response (some providers include this)
error_msg = "Unknown"
provider_name = "Unknown"
@ -5269,26 +5413,14 @@ class AIAgent:
# Track actual token usage from response for context management
if hasattr(response, 'usage') and response.usage:
if self.api_mode in ("codex_responses", "anthropic_messages"):
prompt_tokens = getattr(response.usage, 'input_tokens', 0) or 0
if self.api_mode == "anthropic_messages":
# Anthropic splits input into cache_read + cache_creation
# + non-cached input_tokens. Without adding the cached
# portions, the context bar shows only the tiny non-cached
# portion (e.g. 3 tokens) instead of the real total (~18K).
# Other providers (OpenAI/Codex) already include cached
# tokens in their input_tokens/prompt_tokens field.
prompt_tokens += getattr(response.usage, 'cache_read_input_tokens', 0) or 0
prompt_tokens += getattr(response.usage, 'cache_creation_input_tokens', 0) or 0
completion_tokens = getattr(response.usage, 'output_tokens', 0) or 0
total_tokens = (
getattr(response.usage, 'total_tokens', None)
or (prompt_tokens + completion_tokens)
)
else:
prompt_tokens = getattr(response.usage, 'prompt_tokens', 0) or 0
completion_tokens = getattr(response.usage, 'completion_tokens', 0) or 0
total_tokens = getattr(response.usage, 'total_tokens', 0) or 0
canonical_usage = normalize_usage(
response.usage,
provider=self.provider,
api_mode=self.api_mode,
)
prompt_tokens = canonical_usage.prompt_tokens
completion_tokens = canonical_usage.output_tokens
total_tokens = canonical_usage.total_tokens
usage_dict = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
@ -5307,6 +5439,22 @@ class AIAgent:
self.session_completion_tokens += completion_tokens
self.session_total_tokens += total_tokens
self.session_api_calls += 1
self.session_input_tokens += canonical_usage.input_tokens
self.session_output_tokens += canonical_usage.output_tokens
self.session_cache_read_tokens += canonical_usage.cache_read_tokens
self.session_cache_write_tokens += canonical_usage.cache_write_tokens
self.session_reasoning_tokens += canonical_usage.reasoning_tokens
cost_result = estimate_usage_cost(
self.model,
canonical_usage,
provider=self.provider,
base_url=self.base_url,
)
if cost_result.amount_usd is not None:
self.session_estimated_cost_usd += float(cost_result.amount_usd)
self.session_cost_status = cost_result.status
self.session_cost_source = cost_result.source
# Persist token counts to session DB for /insights.
# Gateway sessions persist via session_store.update_session()
@ -5317,8 +5465,19 @@ class AIAgent:
try:
self._session_db.update_token_counts(
self.session_id,
input_tokens=prompt_tokens,
output_tokens=completion_tokens,
input_tokens=canonical_usage.input_tokens,
output_tokens=canonical_usage.output_tokens,
cache_read_tokens=canonical_usage.cache_read_tokens,
cache_write_tokens=canonical_usage.cache_write_tokens,
reasoning_tokens=canonical_usage.reasoning_tokens,
estimated_cost_usd=float(cost_result.amount_usd)
if cost_result.amount_usd is not None else None,
cost_status=cost_result.status,
cost_source=cost_result.source,
billing_provider=self.provider,
billing_base_url=self.base_url,
billing_mode="subscription_included"
if cost_result.status == "included" else None,
model=self.model,
)
except Exception:
@ -5449,6 +5608,24 @@ class AIAgent:
# A 413 is a payload-size error — the correct response is to
# compress history and retry, not abort immediately.
status_code = getattr(api_error, "status_code", None)
# Eager fallback for rate-limit errors (429 or quota exhaustion).
# When a fallback model is configured, switch immediately instead
# of burning through retries with exponential backoff -- the
# primary provider won't recover within the retry window.
is_rate_limited = (
status_code == 429
or "rate limit" in error_msg
or "too many requests" in error_msg
or "rate_limit" in error_msg
or "usage limit" in error_msg
or "quota" in error_msg
)
if is_rate_limited and not self._fallback_activated:
if self._try_activate_fallback():
retry_count = 0
continue
is_payload_too_large = (
status_code == 413
or 'request entity too large' in error_msg
@ -5935,24 +6112,45 @@ class AIAgent:
# Don't add anything to messages, just retry the API call
continue
else:
# Instead of returning partial, inject a helpful message and let model recover
self._vprint(f"{self.log_prefix}⚠️ Injecting recovery message for invalid JSON...")
# Instead of returning partial, inject tool error results so the model can recover.
# Using tool results (not user messages) preserves role alternation.
self._vprint(f"{self.log_prefix}⚠️ Injecting recovery tool results for invalid JSON...")
self._invalid_json_retries = 0 # Reset for next attempt
# Add a user message explaining the issue
recovery_msg = (
f"Your tool call to '{tool_name}' had invalid JSON arguments. "
f"Error: {error_msg}. "
f"For tools with no required parameters, use an empty object: {{}}. "
f"Please either retry the tool call with valid JSON, or respond without using that tool."
)
recovery_dict = {"role": "user", "content": recovery_msg}
messages.append(recovery_dict)
# Append the assistant message with its (broken) tool_calls
recovery_assistant = self._build_assistant_message(assistant_message, finish_reason)
messages.append(recovery_assistant)
# Respond with tool error results for each tool call
invalid_names = {name for name, _ in invalid_json_args}
for tc in assistant_message.tool_calls:
if tc.function.name in invalid_names:
err = next(e for n, e in invalid_json_args if n == tc.function.name)
tool_result = (
f"Error: Invalid JSON arguments. {err}. "
f"For tools with no required parameters, use an empty object: {{}}. "
f"Please retry with valid JSON."
)
else:
tool_result = "Skipped: other tool call in this response had invalid JSON."
messages.append({
"role": "tool",
"tool_call_id": tc.id,
"content": tool_result,
})
continue
# Reset retry counter on successful JSON validation
self._invalid_json_retries = 0
# ── Post-call guardrails ──────────────────────────
assistant_message.tool_calls = self._cap_delegate_task_calls(
assistant_message.tool_calls
)
assistant_message.tool_calls = self._deduplicate_tool_calls(
assistant_message.tool_calls
)
assistant_msg = self._build_assistant_message(assistant_message, finish_reason)
# If this turn has both content AND tool_calls, capture the content
@ -6133,6 +6331,8 @@ class AIAgent:
if truncated_response_prefix:
final_response = truncated_response_prefix + final_response
truncated_response_prefix = ""
length_continue_retries = 0
# Strip <think> blocks from user-facing response (keep raw in messages for trajectory)
final_response = self._strip_think_blocks(final_response).strip()
@ -6184,10 +6384,11 @@ class AIAgent:
if not pending_handled:
# Error happened before tool processing (e.g. response parsing).
# Use a user-role message so the model can see what went wrong
# without confusing the API with a fabricated assistant turn.
# Choose role to avoid consecutive same-role messages.
last_role = messages[-1].get("role") if messages else None
err_role = "assistant" if last_role == "user" else "user"
sys_err_msg = {
"role": "user",
"role": err_role,
"content": f"[System error during processing: {error_msg}]",
}
messages.append(sys_err_msg)
@ -6239,6 +6440,21 @@ class AIAgent:
"partial": False, # True only when stopped due to invalid tool calls
"interrupted": interrupted,
"response_previewed": getattr(self, "_response_was_previewed", False),
"model": self.model,
"provider": self.provider,
"base_url": self.base_url,
"input_tokens": self.session_input_tokens,
"output_tokens": self.session_output_tokens,
"cache_read_tokens": self.session_cache_read_tokens,
"cache_write_tokens": self.session_cache_write_tokens,
"reasoning_tokens": self.session_reasoning_tokens,
"prompt_tokens": self.session_prompt_tokens,
"completion_tokens": self.session_completion_tokens,
"total_tokens": self.session_total_tokens,
"last_prompt_tokens": getattr(self.context_compressor, "last_prompt_tokens", 0) or 0,
"estimated_cost_usd": self.session_estimated_cost_usd,
"cost_status": self.session_cost_status,
"cost_source": self.session_cost_source,
}
self._response_was_previewed = False

View file

@ -0,0 +1,19 @@
# inference.sh
Run 150+ AI applications in the cloud via the [inference.sh](https://inference.sh) platform.
**One API key for everything** — access image generation, video creation, LLMs, search, 3D, and more through a single account. No need to manage separate API keys for each provider.
## Available Skills
- **cli**: Use the inference.sh CLI (`infsh`) via the terminal tool
## What's Included
- **Image Generation**: FLUX, Reve, Seedream, Grok Imagine, Gemini
- **Video Generation**: Veo, Wan, Seedance, OmniHuman, HunyuanVideo
- **LLMs**: Claude, Gemini, Kimi, GLM-4 (via OpenRouter)
- **Search**: Tavily, Exa
- **3D**: Rodin
- **Social**: Twitter/X automation
- **Audio**: TTS, voice cloning

View file

@ -0,0 +1,155 @@
---
name: inference-sh-cli
description: "Run 150+ AI apps via inference.sh CLI (infsh) — image generation, video creation, LLMs, search, 3D, social automation. Uses the terminal tool. Triggers: inference.sh, infsh, ai apps, flux, veo, image generation, video generation, seedream, seedance, tavily"
version: 1.0.0
author: okaris
license: MIT
metadata:
hermes:
tags: [AI, image-generation, video, LLM, search, inference, FLUX, Veo, Claude]
related_skills: []
---
# inference.sh CLI
Run 150+ AI apps in the cloud with a simple CLI. No GPU required.
All commands use the **terminal tool** to run `infsh` commands.
## When to Use
- User asks to generate images (FLUX, Reve, Seedream, Grok, Gemini image)
- User asks to generate video (Veo, Wan, Seedance, OmniHuman)
- User asks about inference.sh or infsh
- User wants to run AI apps without managing individual provider APIs
- User asks for AI-powered search (Tavily, Exa)
- User needs avatar/lipsync generation
## Prerequisites
The `infsh` CLI must be installed and authenticated. Check with:
```bash
infsh me
```
If not installed:
```bash
curl -fsSL https://cli.inference.sh | sh
infsh login
```
See `references/authentication.md` for full setup details.
## Workflow
### 1. Always Search First
Never guess app names — always search to find the correct app ID:
```bash
infsh app list --search flux
infsh app list --search video
infsh app list --search image
```
### 2. Run an App
Use the exact app ID from the search results. Always use `--json` for machine-readable output:
```bash
infsh app run <app-id> --input '{"prompt": "your prompt here"}' --json
```
### 3. Parse the Output
The JSON output contains URLs to generated media. Present these to the user with `MEDIA:<url>` for inline display.
## Common Commands
### Image Generation
```bash
# Search for image apps
infsh app list --search image
# FLUX Dev with LoRA
infsh app run falai/flux-dev-lora --input '{"prompt": "sunset over mountains", "num_images": 1}' --json
# Gemini image generation
infsh app run google/gemini-2-5-flash-image --input '{"prompt": "futuristic city", "num_images": 1}' --json
# Seedream (ByteDance)
infsh app run bytedance/seedream-5-lite --input '{"prompt": "nature scene"}' --json
# Grok Imagine (xAI)
infsh app run xai/grok-imagine-image --input '{"prompt": "abstract art"}' --json
```
### Video Generation
```bash
# Search for video apps
infsh app list --search video
# Veo 3.1 (Google)
infsh app run google/veo-3-1-fast --input '{"prompt": "drone shot of coastline"}' --json
# Seedance (ByteDance)
infsh app run bytedance/seedance-1-5-pro --input '{"prompt": "dancing figure", "resolution": "1080p"}' --json
# Wan 2.5
infsh app run falai/wan-2-5 --input '{"prompt": "person walking through city"}' --json
```
### Local File Uploads
The CLI automatically uploads local files when you provide a path:
```bash
# Upscale a local image
infsh app run falai/topaz-image-upscaler --input '{"image": "/path/to/photo.jpg", "upscale_factor": 2}' --json
# Image-to-video from local file
infsh app run falai/wan-2-5-i2v --input '{"image": "/path/to/image.png", "prompt": "make it move"}' --json
# Avatar with audio
infsh app run bytedance/omnihuman-1-5 --input '{"audio": "/path/to/audio.mp3", "image": "/path/to/face.jpg"}' --json
```
### Search & Research
```bash
infsh app list --search search
infsh app run tavily/tavily-search --input '{"query": "latest AI news"}' --json
infsh app run exa/exa-search --input '{"query": "machine learning papers"}' --json
```
### Other Categories
```bash
# 3D generation
infsh app list --search 3d
# Audio / TTS
infsh app list --search tts
# Twitter/X automation
infsh app list --search twitter
```
## Pitfalls
1. **Never guess app IDs** — always run `infsh app list --search <term>` first. App IDs change and new apps are added frequently.
2. **Always use `--json`** — raw output is hard to parse. The `--json` flag gives structured output with URLs.
3. **Check authentication** — if commands fail with auth errors, run `infsh login` or verify `INFSH_API_KEY` is set.
4. **Long-running apps** — video generation can take 30-120 seconds. The terminal tool timeout should be sufficient, but warn the user it may take a moment.
5. **Input format** — the `--input` flag takes a JSON string. Make sure to properly escape quotes.
## Reference Docs
- `references/authentication.md` — Setup, login, API keys
- `references/app-discovery.md` — Searching and browsing the app catalog
- `references/running-apps.md` — Running apps, input formats, output handling
- `references/cli-reference.md` — Complete CLI command reference

View file

@ -0,0 +1,112 @@
# Discovering Apps
## List All Apps
```bash
infsh app list
```
## Pagination
```bash
infsh app list --page 2
```
## Filter by Category
```bash
infsh app list --category image
infsh app list --category video
infsh app list --category audio
infsh app list --category text
infsh app list --category other
```
## Search
```bash
infsh app search "flux"
infsh app search "video generation"
infsh app search "tts" -l
infsh app search "image" --category image
```
Or use the flag form:
```bash
infsh app list --search "flux"
infsh app list --search "video generation"
infsh app list --search "tts"
```
## Featured Apps
```bash
infsh app list --featured
```
## Newest First
```bash
infsh app list --new
```
## Detailed View
```bash
infsh app list -l
```
Shows table with app name, category, description, and featured status.
## Save to File
```bash
infsh app list --save apps.json
```
## Your Apps
List apps you've deployed:
```bash
infsh app my
infsh app my -l # detailed
```
## Get App Details
```bash
infsh app get falai/flux-dev-lora
infsh app get falai/flux-dev-lora --json
```
Shows full app info including input/output schema.
## Popular Apps by Category
### Image Generation
- `falai/flux-dev-lora` - FLUX.2 Dev (high quality)
- `falai/flux-2-klein-lora` - FLUX.2 Klein (fastest)
- `infsh/sdxl` - Stable Diffusion XL
- `google/gemini-3-pro-image-preview` - Gemini 3 Pro
- `xai/grok-imagine-image` - Grok image generation
### Video Generation
- `google/veo-3-1-fast` - Veo 3.1 Fast
- `google/veo-3` - Veo 3
- `bytedance/seedance-1-5-pro` - Seedance 1.5 Pro
- `infsh/ltx-video-2` - LTX Video 2 (with audio)
- `bytedance/omnihuman-1-5` - OmniHuman avatar
### Audio
- `infsh/dia-tts` - Conversational TTS
- `infsh/kokoro-tts` - Kokoro TTS
- `infsh/fast-whisper-large-v3` - Fast transcription
- `infsh/diffrythm` - Music generation
## Documentation
- [Browsing the Grid](https://inference.sh/docs/apps/browsing-grid) - Visual app browsing
- [Apps Overview](https://inference.sh/docs/apps/overview) - Understanding apps
- [Running Apps](https://inference.sh/docs/apps/running) - How to run apps

View file

@ -0,0 +1,59 @@
# Authentication & Setup
## Install the CLI
```bash
curl -fsSL https://cli.inference.sh | sh
```
## Login
```bash
infsh login
```
This opens a browser for authentication. After login, credentials are stored locally.
## Check Authentication
```bash
infsh me
```
Shows your user info if authenticated.
## Environment Variable
For CI/CD or scripts, set your API key:
```bash
export INFSH_API_KEY=your-api-key
```
The environment variable overrides the config file.
## Update CLI
```bash
infsh update
```
Or reinstall:
```bash
curl -fsSL https://cli.inference.sh | sh
```
## Troubleshooting
| Error | Solution |
|-------|----------|
| "not authenticated" | Run `infsh login` |
| "command not found" | Reinstall CLI or add to PATH |
| "API key invalid" | Check `INFSH_API_KEY` or re-login |
## Documentation
- [CLI Setup](https://inference.sh/docs/extend/cli-setup) - Complete CLI installation guide
- [API Authentication](https://inference.sh/docs/api/authentication) - API key management
- [Secrets](https://inference.sh/docs/secrets/overview) - Managing credentials

View file

@ -0,0 +1,104 @@
# CLI Reference
## Installation
```bash
curl -fsSL https://cli.inference.sh | sh
```
## Global Commands
| Command | Description |
|---------|-------------|
| `infsh help` | Show help |
| `infsh version` | Show CLI version |
| `infsh update` | Update CLI to latest |
| `infsh login` | Authenticate |
| `infsh me` | Show current user |
## App Commands
### Discovery
| Command | Description |
|---------|-------------|
| `infsh app list` | List available apps |
| `infsh app list --category <cat>` | Filter by category (image, video, audio, text, other) |
| `infsh app search <query>` | Search apps |
| `infsh app list --search <query>` | Search apps (flag form) |
| `infsh app list --featured` | Show featured apps |
| `infsh app list --new` | Sort by newest |
| `infsh app list --page <n>` | Pagination |
| `infsh app list -l` | Detailed table view |
| `infsh app list --save <file>` | Save to JSON file |
| `infsh app my` | List your deployed apps |
| `infsh app get <app>` | Get app details |
| `infsh app get <app> --json` | Get app details as JSON |
### Execution
| Command | Description |
|---------|-------------|
| `infsh app run <app> --input <file>` | Run app with input file |
| `infsh app run <app> --input '<json>'` | Run with inline JSON |
| `infsh app run <app> --input <file> --no-wait` | Run without waiting for completion |
| `infsh app sample <app>` | Show sample input |
| `infsh app sample <app> --save <file>` | Save sample to file |
## Task Commands
| Command | Description |
|---------|-------------|
| `infsh task get <task-id>` | Get task status and result |
| `infsh task get <task-id> --json` | Get task as JSON |
| `infsh task get <task-id> --save <file>` | Save task result to file |
### Development
| Command | Description |
|---------|-------------|
| `infsh app init` | Create new app (interactive) |
| `infsh app init <name>` | Create new app with name |
| `infsh app test --input <file>` | Test app locally |
| `infsh app deploy` | Deploy app |
| `infsh app deploy --dry-run` | Validate without deploying |
| `infsh app pull <id>` | Pull app source |
| `infsh app pull --all` | Pull all your apps |
## Environment Variables
| Variable | Description |
|----------|-------------|
| `INFSH_API_KEY` | API key (overrides config) |
## Shell Completions
```bash
# Bash
infsh completion bash > /etc/bash_completion.d/infsh
# Zsh
infsh completion zsh > "${fpath[1]}/_infsh"
# Fish
infsh completion fish > ~/.config/fish/completions/infsh.fish
```
## App Name Format
Apps use the format `namespace/app-name`:
- `falai/flux-dev-lora` - fal.ai's FLUX 2 Dev
- `google/veo-3` - Google's Veo 3
- `infsh/sdxl` - inference.sh's SDXL
- `bytedance/seedance-1-5-pro` - ByteDance's Seedance
- `xai/grok-imagine-image` - xAI's Grok
Version pinning: `namespace/app-name@version`
## Documentation
- [CLI Setup](https://inference.sh/docs/extend/cli-setup) - Complete CLI installation guide
- [Running Apps](https://inference.sh/docs/apps/running) - How to run apps via CLI
- [Creating an App](https://inference.sh/docs/extend/creating-app) - Build your own apps
- [Deploying](https://inference.sh/docs/extend/deploying) - Deploy apps to the cloud

View file

@ -0,0 +1,171 @@
# Running Apps
## Basic Run
```bash
infsh app run user/app-name --input input.json
```
## Inline JSON
```bash
infsh app run falai/flux-dev-lora --input '{"prompt": "a sunset over mountains"}'
```
## Version Pinning
```bash
infsh app run user/app-name@1.0.0 --input input.json
```
## Local File Uploads
The CLI automatically uploads local files when you provide a file path instead of a URL. Any field that accepts a URL also accepts a local path:
```bash
# Upscale a local image
infsh app run falai/topaz-image-upscaler --input '{"image": "/path/to/photo.jpg", "upscale_factor": 2}'
# Image-to-video from local file
infsh app run falai/wan-2-5-i2v --input '{"image": "./my-image.png", "prompt": "make it move"}'
# Avatar with local audio and image
infsh app run bytedance/omnihuman-1-5 --input '{"audio": "/path/to/speech.mp3", "image": "/path/to/face.jpg"}'
# Post tweet with local media
infsh app run x/post-create --input '{"text": "Check this out!", "media": "./screenshot.png"}'
```
Supported paths:
- Absolute paths: `/home/user/images/photo.jpg`
- Relative paths: `./image.png`, `../data/video.mp4`
- Home directory: `~/Pictures/photo.jpg`
## Generate Sample Input
Before running, generate a sample input file:
```bash
infsh app sample falai/flux-dev-lora
```
Save to file:
```bash
infsh app sample falai/flux-dev-lora --save input.json
```
Then edit `input.json` and run:
```bash
infsh app run falai/flux-dev-lora --input input.json
```
## Workflow Example
### Image Generation with FLUX
```bash
# 1. Get app details
infsh app get falai/flux-dev-lora
# 2. Generate sample input
infsh app sample falai/flux-dev-lora --save input.json
# 3. Edit input.json
# {
# "prompt": "a cat astronaut floating in space",
# "num_images": 1,
# "image_size": "landscape_16_9"
# }
# 4. Run
infsh app run falai/flux-dev-lora --input input.json
```
### Video Generation with Veo
```bash
# 1. Generate sample
infsh app sample google/veo-3-1-fast --save input.json
# 2. Edit prompt
# {
# "prompt": "A drone shot flying over a forest at sunset"
# }
# 3. Run
infsh app run google/veo-3-1-fast --input input.json
```
### Text-to-Speech
```bash
# Quick inline run
infsh app run falai/kokoro-tts --input '{"text": "Hello, this is a test."}'
```
## Task Tracking
When you run an app, the CLI shows the task ID:
```
Running falai/flux-dev-lora
Task ID: abc123def456
```
For long-running tasks, you can check status anytime:
```bash
# Check task status
infsh task get abc123def456
# Get result as JSON
infsh task get abc123def456 --json
# Save result to file
infsh task get abc123def456 --save result.json
```
### Run Without Waiting
For very long tasks, run in background:
```bash
# Submit and return immediately
infsh app run google/veo-3 --input input.json --no-wait
# Check later
infsh task get <task-id>
```
## Output
The CLI returns the app output directly. For file outputs (images, videos, audio), you'll receive URLs to download.
Example output:
```json
{
"images": [
{
"url": "https://cloud.inference.sh/...",
"content_type": "image/png"
}
]
}
```
## Error Handling
| Error | Cause | Solution |
|-------|-------|----------|
| "invalid input" | Schema mismatch | Check `infsh app get` for required fields |
| "app not found" | Wrong app name | Check `infsh app list --search` |
| "quota exceeded" | Out of credits | Check account balance |
## Documentation
- [Running Apps](https://inference.sh/docs/apps/running) - Complete running apps guide
- [Streaming Results](https://inference.sh/docs/api/sdk/streaming) - Real-time progress updates
- [Setup Parameters](https://inference.sh/docs/apps/setup-parameters) - Configuring app inputs

View file

@ -113,11 +113,13 @@ class TestDefaultContextLengths:
def test_gpt4_models_128k_or_1m(self):
# gpt-4.1 and gpt-4.1-mini have 1M context; other gpt-4* have 128k
for key, value in DEFAULT_CONTEXT_LENGTHS.items():
if "gpt-4" in key:
if "gpt-4.1" in key:
assert value == 1047576, f"{key} should be 1047576 (1M)"
else:
assert value == 128000, f"{key} should be 128000"
if "gpt-4" in key and "gpt-4.1" not in key:
assert value == 128000, f"{key} should be 128000"
def test_gpt41_models_1m(self):
for key, value in DEFAULT_CONTEXT_LENGTHS.items():
if "gpt-4.1" in key:
assert value == 1047576, f"{key} should be 1047576"
def test_gemini_models_1m(self):
for key, value in DEFAULT_CONTEXT_LENGTHS.items():

View file

@ -0,0 +1,101 @@
from types import SimpleNamespace
from agent.usage_pricing import (
CanonicalUsage,
estimate_usage_cost,
get_pricing_entry,
normalize_usage,
)
def test_normalize_usage_anthropic_keeps_cache_buckets_separate():
usage = SimpleNamespace(
input_tokens=1000,
output_tokens=500,
cache_read_input_tokens=2000,
cache_creation_input_tokens=400,
)
normalized = normalize_usage(usage, provider="anthropic", api_mode="anthropic_messages")
assert normalized.input_tokens == 1000
assert normalized.output_tokens == 500
assert normalized.cache_read_tokens == 2000
assert normalized.cache_write_tokens == 400
assert normalized.prompt_tokens == 3400
def test_normalize_usage_openai_subtracts_cached_prompt_tokens():
usage = SimpleNamespace(
prompt_tokens=3000,
completion_tokens=700,
prompt_tokens_details=SimpleNamespace(cached_tokens=1800),
)
normalized = normalize_usage(usage, provider="openai", api_mode="chat_completions")
assert normalized.input_tokens == 1200
assert normalized.cache_read_tokens == 1800
assert normalized.output_tokens == 700
def test_openrouter_models_api_pricing_is_converted_from_per_token_to_per_million(monkeypatch):
monkeypatch.setattr(
"agent.usage_pricing.fetch_model_metadata",
lambda: {
"anthropic/claude-opus-4.6": {
"pricing": {
"prompt": "0.000005",
"completion": "0.000025",
"input_cache_read": "0.0000005",
"input_cache_write": "0.00000625",
}
}
},
)
entry = get_pricing_entry(
"anthropic/claude-opus-4.6",
provider="openrouter",
base_url="https://openrouter.ai/api/v1",
)
assert float(entry.input_cost_per_million) == 5.0
assert float(entry.output_cost_per_million) == 25.0
assert float(entry.cache_read_cost_per_million) == 0.5
assert float(entry.cache_write_cost_per_million) == 6.25
def test_estimate_usage_cost_marks_subscription_routes_included():
result = estimate_usage_cost(
"gpt-5.3-codex",
CanonicalUsage(input_tokens=1000, output_tokens=500),
provider="openai-codex",
base_url="https://chatgpt.com/backend-api/codex",
)
assert result.status == "included"
assert float(result.amount_usd) == 0.0
def test_estimate_usage_cost_refuses_cache_pricing_without_official_cache_rate(monkeypatch):
monkeypatch.setattr(
"agent.usage_pricing.fetch_model_metadata",
lambda: {
"google/gemini-2.5-pro": {
"pricing": {
"prompt": "0.00000125",
"completion": "0.00001",
}
}
},
)
result = estimate_usage_cost(
"google/gemini-2.5-pro",
CanonicalUsage(input_tokens=1000, output_tokens=500, cache_read_tokens=100),
provider="openrouter",
base_url="https://openrouter.ai/api/v1",
)
assert result.status == "unknown"

View file

@ -50,13 +50,16 @@ def _build_runner(monkeypatch, tmp_path, mode: str) -> GatewayRunner:
return runner
def _watcher_dict(session_id="proc_test"):
return {
def _watcher_dict(session_id="proc_test", thread_id=""):
d = {
"session_id": session_id,
"check_interval": 0,
"platform": "telegram",
"chat_id": "123",
}
if thread_id:
d["thread_id"] = thread_id
return d
# ---------------------------------------------------------------------------
@ -196,3 +199,47 @@ async def test_run_process_watcher_respects_notification_mode(
if expected_fragment is not None:
sent_message = adapter.send.await_args.args[1]
assert expected_fragment in sent_message
@pytest.mark.asyncio
async def test_thread_id_passed_to_send(monkeypatch, tmp_path):
"""thread_id from watcher dict is forwarded as metadata to adapter.send()."""
import tools.process_registry as pr_module
sessions = [SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)]
monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions))
async def _instant_sleep(*_a, **_kw):
pass
monkeypatch.setattr(asyncio, "sleep", _instant_sleep)
runner = _build_runner(monkeypatch, tmp_path, "all")
adapter = runner.adapters[Platform.TELEGRAM]
await runner._run_process_watcher(_watcher_dict(thread_id="42"))
assert adapter.send.await_count == 1
_, kwargs = adapter.send.call_args
assert kwargs["metadata"] == {"thread_id": "42"}
@pytest.mark.asyncio
async def test_no_thread_id_sends_no_metadata(monkeypatch, tmp_path):
"""When thread_id is empty, metadata should be None (general topic)."""
import tools.process_registry as pr_module
sessions = [SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)]
monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions))
async def _instant_sleep(*_a, **_kw):
pass
monkeypatch.setattr(asyncio, "sleep", _instant_sleep)
runner = _build_runner(monkeypatch, tmp_path, "all")
adapter = runner.adapters[Platform.TELEGRAM]
await runner._run_process_watcher(_watcher_dict())
assert adapter.send.await_count == 1
_, kwargs = adapter.send.call_args
assert kwargs["metadata"] is None

View file

@ -0,0 +1,274 @@
"""Tests for DingTalk platform adapter."""
import asyncio
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock
import pytest
from gateway.config import Platform, PlatformConfig
# ---------------------------------------------------------------------------
# Requirements check
# ---------------------------------------------------------------------------
class TestDingTalkRequirements:
def test_returns_false_when_sdk_missing(self, monkeypatch):
with patch.dict("sys.modules", {"dingtalk_stream": None}):
monkeypatch.setattr(
"gateway.platforms.dingtalk.DINGTALK_STREAM_AVAILABLE", False
)
from gateway.platforms.dingtalk import check_dingtalk_requirements
assert check_dingtalk_requirements() is False
def test_returns_false_when_env_vars_missing(self, monkeypatch):
monkeypatch.setattr(
"gateway.platforms.dingtalk.DINGTALK_STREAM_AVAILABLE", True
)
monkeypatch.setattr("gateway.platforms.dingtalk.HTTPX_AVAILABLE", True)
monkeypatch.delenv("DINGTALK_CLIENT_ID", raising=False)
monkeypatch.delenv("DINGTALK_CLIENT_SECRET", raising=False)
from gateway.platforms.dingtalk import check_dingtalk_requirements
assert check_dingtalk_requirements() is False
def test_returns_true_when_all_available(self, monkeypatch):
monkeypatch.setattr(
"gateway.platforms.dingtalk.DINGTALK_STREAM_AVAILABLE", True
)
monkeypatch.setattr("gateway.platforms.dingtalk.HTTPX_AVAILABLE", True)
monkeypatch.setenv("DINGTALK_CLIENT_ID", "test-id")
monkeypatch.setenv("DINGTALK_CLIENT_SECRET", "test-secret")
from gateway.platforms.dingtalk import check_dingtalk_requirements
assert check_dingtalk_requirements() is True
# ---------------------------------------------------------------------------
# Adapter construction
# ---------------------------------------------------------------------------
class TestDingTalkAdapterInit:
def test_reads_config_from_extra(self):
from gateway.platforms.dingtalk import DingTalkAdapter
config = PlatformConfig(
enabled=True,
extra={"client_id": "cfg-id", "client_secret": "cfg-secret"},
)
adapter = DingTalkAdapter(config)
assert adapter._client_id == "cfg-id"
assert adapter._client_secret == "cfg-secret"
assert adapter.name == "Dingtalk" # base class uses .title()
def test_falls_back_to_env_vars(self, monkeypatch):
monkeypatch.setenv("DINGTALK_CLIENT_ID", "env-id")
monkeypatch.setenv("DINGTALK_CLIENT_SECRET", "env-secret")
from gateway.platforms.dingtalk import DingTalkAdapter
config = PlatformConfig(enabled=True)
adapter = DingTalkAdapter(config)
assert adapter._client_id == "env-id"
assert adapter._client_secret == "env-secret"
# ---------------------------------------------------------------------------
# Message text extraction
# ---------------------------------------------------------------------------
class TestExtractText:
def test_extracts_dict_text(self):
from gateway.platforms.dingtalk import DingTalkAdapter
msg = MagicMock()
msg.text = {"content": " hello world "}
msg.rich_text = None
assert DingTalkAdapter._extract_text(msg) == "hello world"
def test_extracts_string_text(self):
from gateway.platforms.dingtalk import DingTalkAdapter
msg = MagicMock()
msg.text = "plain text"
msg.rich_text = None
assert DingTalkAdapter._extract_text(msg) == "plain text"
def test_falls_back_to_rich_text(self):
from gateway.platforms.dingtalk import DingTalkAdapter
msg = MagicMock()
msg.text = ""
msg.rich_text = [{"text": "part1"}, {"text": "part2"}, {"image": "url"}]
assert DingTalkAdapter._extract_text(msg) == "part1 part2"
def test_returns_empty_for_no_content(self):
from gateway.platforms.dingtalk import DingTalkAdapter
msg = MagicMock()
msg.text = ""
msg.rich_text = None
assert DingTalkAdapter._extract_text(msg) == ""
# ---------------------------------------------------------------------------
# Deduplication
# ---------------------------------------------------------------------------
class TestDeduplication:
def test_first_message_not_duplicate(self):
from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
assert adapter._is_duplicate("msg-1") is False
def test_second_same_message_is_duplicate(self):
from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
adapter._is_duplicate("msg-1")
assert adapter._is_duplicate("msg-1") is True
def test_different_messages_not_duplicate(self):
from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
adapter._is_duplicate("msg-1")
assert adapter._is_duplicate("msg-2") is False
def test_cache_cleanup_on_overflow(self):
from gateway.platforms.dingtalk import DingTalkAdapter, DEDUP_MAX_SIZE
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
# Fill beyond max
for i in range(DEDUP_MAX_SIZE + 10):
adapter._is_duplicate(f"msg-{i}")
# Cache should have been pruned
assert len(adapter._seen_messages) <= DEDUP_MAX_SIZE + 10
# ---------------------------------------------------------------------------
# Send
# ---------------------------------------------------------------------------
class TestSend:
@pytest.mark.asyncio
async def test_send_posts_to_webhook(self):
from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.text = "OK"
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=mock_response)
adapter._http_client = mock_client
result = await adapter.send(
"chat-123", "Hello!",
metadata={"session_webhook": "https://dingtalk.example/webhook"}
)
assert result.success is True
mock_client.post.assert_called_once()
call_args = mock_client.post.call_args
assert call_args[0][0] == "https://dingtalk.example/webhook"
payload = call_args[1]["json"]
assert payload["msgtype"] == "markdown"
assert payload["markdown"]["title"] == "Hermes"
assert payload["markdown"]["text"] == "Hello!"
@pytest.mark.asyncio
async def test_send_fails_without_webhook(self):
from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
adapter._http_client = AsyncMock()
result = await adapter.send("chat-123", "Hello!")
assert result.success is False
assert "session_webhook" in result.error
@pytest.mark.asyncio
async def test_send_uses_cached_webhook(self):
from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
mock_response = MagicMock()
mock_response.status_code = 200
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=mock_response)
adapter._http_client = mock_client
adapter._session_webhooks["chat-123"] = "https://cached.example/webhook"
result = await adapter.send("chat-123", "Hello!")
assert result.success is True
assert mock_client.post.call_args[0][0] == "https://cached.example/webhook"
@pytest.mark.asyncio
async def test_send_handles_http_error(self):
from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
mock_response = MagicMock()
mock_response.status_code = 400
mock_response.text = "Bad Request"
mock_client = AsyncMock()
mock_client.post = AsyncMock(return_value=mock_response)
adapter._http_client = mock_client
result = await adapter.send(
"chat-123", "Hello!",
metadata={"session_webhook": "https://example/webhook"}
)
assert result.success is False
assert "400" in result.error
# ---------------------------------------------------------------------------
# Connect / disconnect
# ---------------------------------------------------------------------------
class TestConnect:
@pytest.mark.asyncio
async def test_connect_fails_without_sdk(self, monkeypatch):
monkeypatch.setattr(
"gateway.platforms.dingtalk.DINGTALK_STREAM_AVAILABLE", False
)
from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
result = await adapter.connect()
assert result is False
@pytest.mark.asyncio
async def test_connect_fails_without_credentials(self):
from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
adapter._client_id = ""
adapter._client_secret = ""
result = await adapter.connect()
assert result is False
@pytest.mark.asyncio
async def test_disconnect_cleans_up(self):
from gateway.platforms.dingtalk import DingTalkAdapter
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
adapter._session_webhooks["a"] = "http://x"
adapter._seen_messages["b"] = 1.0
adapter._http_client = AsyncMock()
adapter._stream_task = None
await adapter.disconnect()
assert len(adapter._session_webhooks) == 0
assert len(adapter._seen_messages) == 0
assert adapter._http_client is None
# ---------------------------------------------------------------------------
# Platform enum
# ---------------------------------------------------------------------------
class TestPlatformEnum:
def test_dingtalk_in_platform_enum(self):
assert Platform.DINGTALK.value == "dingtalk"

View file

@ -0,0 +1,448 @@
"""Tests for Matrix platform adapter."""
import json
import re
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from gateway.config import Platform, PlatformConfig
# ---------------------------------------------------------------------------
# Platform & Config
# ---------------------------------------------------------------------------
class TestMatrixPlatformEnum:
def test_matrix_enum_exists(self):
assert Platform.MATRIX.value == "matrix"
def test_matrix_in_platform_list(self):
platforms = [p.value for p in Platform]
assert "matrix" in platforms
class TestMatrixConfigLoading:
def test_apply_env_overrides_with_access_token(self, monkeypatch):
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123")
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
assert Platform.MATRIX in config.platforms
mc = config.platforms[Platform.MATRIX]
assert mc.enabled is True
assert mc.token == "syt_abc123"
assert mc.extra.get("homeserver") == "https://matrix.example.org"
def test_apply_env_overrides_with_password(self, monkeypatch):
monkeypatch.delenv("MATRIX_ACCESS_TOKEN", raising=False)
monkeypatch.setenv("MATRIX_PASSWORD", "secret123")
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
monkeypatch.setenv("MATRIX_USER_ID", "@bot:example.org")
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
assert Platform.MATRIX in config.platforms
mc = config.platforms[Platform.MATRIX]
assert mc.enabled is True
assert mc.extra.get("password") == "secret123"
assert mc.extra.get("user_id") == "@bot:example.org"
def test_matrix_not_loaded_without_creds(self, monkeypatch):
monkeypatch.delenv("MATRIX_ACCESS_TOKEN", raising=False)
monkeypatch.delenv("MATRIX_PASSWORD", raising=False)
monkeypatch.delenv("MATRIX_HOMESERVER", raising=False)
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
assert Platform.MATRIX not in config.platforms
def test_matrix_encryption_flag(self, monkeypatch):
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123")
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
monkeypatch.setenv("MATRIX_ENCRYPTION", "true")
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
mc = config.platforms[Platform.MATRIX]
assert mc.extra.get("encryption") is True
def test_matrix_encryption_default_off(self, monkeypatch):
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123")
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
monkeypatch.delenv("MATRIX_ENCRYPTION", raising=False)
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
mc = config.platforms[Platform.MATRIX]
assert mc.extra.get("encryption") is False
def test_matrix_home_room(self, monkeypatch):
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123")
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
monkeypatch.setenv("MATRIX_HOME_ROOM", "!room123:example.org")
monkeypatch.setenv("MATRIX_HOME_ROOM_NAME", "Bot Room")
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
home = config.get_home_channel(Platform.MATRIX)
assert home is not None
assert home.chat_id == "!room123:example.org"
assert home.name == "Bot Room"
def test_matrix_user_id_stored_in_extra(self, monkeypatch):
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123")
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
monkeypatch.setenv("MATRIX_USER_ID", "@hermes:example.org")
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
mc = config.platforms[Platform.MATRIX]
assert mc.extra.get("user_id") == "@hermes:example.org"
# ---------------------------------------------------------------------------
# Adapter helpers
# ---------------------------------------------------------------------------
def _make_adapter():
"""Create a MatrixAdapter with mocked config."""
from gateway.platforms.matrix import MatrixAdapter
config = PlatformConfig(
enabled=True,
token="syt_test_token",
extra={
"homeserver": "https://matrix.example.org",
"user_id": "@bot:example.org",
},
)
adapter = MatrixAdapter(config)
return adapter
# ---------------------------------------------------------------------------
# mxc:// URL conversion
# ---------------------------------------------------------------------------
class TestMatrixMxcToHttp:
def setup_method(self):
self.adapter = _make_adapter()
def test_basic_mxc_conversion(self):
"""mxc://server/media_id should become an authenticated HTTP URL."""
mxc = "mxc://matrix.org/abc123"
result = self.adapter._mxc_to_http(mxc)
assert result == "https://matrix.example.org/_matrix/client/v1/media/download/matrix.org/abc123"
def test_mxc_with_different_server(self):
"""mxc:// from a different server should still use our homeserver."""
mxc = "mxc://other.server/media456"
result = self.adapter._mxc_to_http(mxc)
assert result.startswith("https://matrix.example.org/")
assert "other.server/media456" in result
def test_non_mxc_url_passthrough(self):
"""Non-mxc URLs should be returned unchanged."""
url = "https://example.com/image.png"
assert self.adapter._mxc_to_http(url) == url
def test_mxc_uses_client_v1_endpoint(self):
"""Should use /_matrix/client/v1/media/download/ not the deprecated path."""
mxc = "mxc://example.com/test123"
result = self.adapter._mxc_to_http(mxc)
assert "/_matrix/client/v1/media/download/" in result
assert "/_matrix/media/v3/download/" not in result
# ---------------------------------------------------------------------------
# DM detection
# ---------------------------------------------------------------------------
class TestMatrixDmDetection:
def setup_method(self):
self.adapter = _make_adapter()
def test_room_in_m_direct_is_dm(self):
"""A room listed in m.direct should be detected as DM."""
self.adapter._joined_rooms = {"!dm_room:ex.org", "!group_room:ex.org"}
self.adapter._dm_rooms = {
"!dm_room:ex.org": True,
"!group_room:ex.org": False,
}
assert self.adapter._dm_rooms.get("!dm_room:ex.org") is True
assert self.adapter._dm_rooms.get("!group_room:ex.org") is False
def test_unknown_room_not_in_cache(self):
"""Unknown rooms should not be in the DM cache."""
self.adapter._dm_rooms = {}
assert self.adapter._dm_rooms.get("!unknown:ex.org") is None
@pytest.mark.asyncio
async def test_refresh_dm_cache_with_m_direct(self):
"""_refresh_dm_cache should populate _dm_rooms from m.direct data."""
self.adapter._joined_rooms = {"!room_a:ex.org", "!room_b:ex.org", "!room_c:ex.org"}
mock_client = MagicMock()
mock_resp = MagicMock()
mock_resp.content = {
"@alice:ex.org": ["!room_a:ex.org"],
"@bob:ex.org": ["!room_b:ex.org"],
}
mock_client.get_account_data = AsyncMock(return_value=mock_resp)
self.adapter._client = mock_client
await self.adapter._refresh_dm_cache()
assert self.adapter._dm_rooms["!room_a:ex.org"] is True
assert self.adapter._dm_rooms["!room_b:ex.org"] is True
assert self.adapter._dm_rooms["!room_c:ex.org"] is False
# ---------------------------------------------------------------------------
# Reply fallback stripping
# ---------------------------------------------------------------------------
class TestMatrixReplyFallbackStripping:
"""Test that Matrix reply fallback lines ('> ' prefix) are stripped."""
def setup_method(self):
self.adapter = _make_adapter()
self.adapter._user_id = "@bot:example.org"
self.adapter._startup_ts = 0.0
self.adapter._dm_rooms = {}
self.adapter._message_handler = AsyncMock()
def _strip_fallback(self, body: str, has_reply: bool = True) -> str:
"""Simulate the reply fallback stripping logic from _on_room_message."""
reply_to = "some_event_id" if has_reply else None
if reply_to and body.startswith("> "):
lines = body.split("\n")
stripped = []
past_fallback = False
for line in lines:
if not past_fallback:
if line.startswith("> ") or line == ">":
continue
if line == "":
past_fallback = True
continue
past_fallback = True
stripped.append(line)
body = "\n".join(stripped) if stripped else body
return body
def test_simple_reply_fallback(self):
body = "> <@alice:ex.org> Original message\n\nActual reply"
result = self._strip_fallback(body)
assert result == "Actual reply"
def test_multiline_reply_fallback(self):
body = "> <@alice:ex.org> Line 1\n> Line 2\n\nMy response"
result = self._strip_fallback(body)
assert result == "My response"
def test_no_reply_fallback_preserved(self):
body = "Just a normal message"
result = self._strip_fallback(body, has_reply=False)
assert result == "Just a normal message"
def test_quote_without_reply_preserved(self):
"""'> ' lines without a reply_to context should be preserved."""
body = "> This is a blockquote"
result = self._strip_fallback(body, has_reply=False)
assert result == "> This is a blockquote"
def test_empty_fallback_separator(self):
"""The blank line between fallback and actual content should be stripped."""
body = "> <@alice:ex.org> hi\n>\n\nResponse"
result = self._strip_fallback(body)
assert result == "Response"
def test_multiline_response_after_fallback(self):
body = "> <@alice:ex.org> Original\n\nLine 1\nLine 2\nLine 3"
result = self._strip_fallback(body)
assert result == "Line 1\nLine 2\nLine 3"
# ---------------------------------------------------------------------------
# Thread detection
# ---------------------------------------------------------------------------
class TestMatrixThreadDetection:
def test_thread_id_from_m_relates_to(self):
"""m.relates_to with rel_type=m.thread should extract the event_id."""
relates_to = {
"rel_type": "m.thread",
"event_id": "$thread_root_event",
"is_falling_back": True,
"m.in_reply_to": {"event_id": "$some_event"},
}
# Simulate the extraction logic from _on_room_message
thread_id = None
if relates_to.get("rel_type") == "m.thread":
thread_id = relates_to.get("event_id")
assert thread_id == "$thread_root_event"
def test_no_thread_for_reply(self):
"""m.in_reply_to without m.thread should not set thread_id."""
relates_to = {
"m.in_reply_to": {"event_id": "$reply_event"},
}
thread_id = None
if relates_to.get("rel_type") == "m.thread":
thread_id = relates_to.get("event_id")
assert thread_id is None
def test_no_thread_for_edit(self):
"""m.replace relation should not set thread_id."""
relates_to = {
"rel_type": "m.replace",
"event_id": "$edited_event",
}
thread_id = None
if relates_to.get("rel_type") == "m.thread":
thread_id = relates_to.get("event_id")
assert thread_id is None
def test_empty_relates_to(self):
"""Empty m.relates_to should not set thread_id."""
relates_to = {}
thread_id = None
if relates_to.get("rel_type") == "m.thread":
thread_id = relates_to.get("event_id")
assert thread_id is None
# ---------------------------------------------------------------------------
# Format message
# ---------------------------------------------------------------------------
class TestMatrixFormatMessage:
def setup_method(self):
self.adapter = _make_adapter()
def test_image_markdown_stripped(self):
"""![alt](url) should be converted to just the URL."""
result = self.adapter.format_message("![cat](https://img.example.com/cat.png)")
assert result == "https://img.example.com/cat.png"
def test_regular_markdown_preserved(self):
"""Standard markdown should be preserved (Matrix supports it)."""
content = "**bold** and *italic* and `code`"
assert self.adapter.format_message(content) == content
def test_plain_text_unchanged(self):
content = "Hello, world!"
assert self.adapter.format_message(content) == content
def test_multiple_images_stripped(self):
content = "![a](http://a.com/1.png) and ![b](http://b.com/2.png)"
result = self.adapter.format_message(content)
assert "![" not in result
assert "http://a.com/1.png" in result
assert "http://b.com/2.png" in result
# ---------------------------------------------------------------------------
# Markdown to HTML conversion
# ---------------------------------------------------------------------------
class TestMatrixMarkdownToHtml:
def setup_method(self):
self.adapter = _make_adapter()
def test_bold_conversion(self):
"""**bold** should produce <strong> tags."""
result = self.adapter._markdown_to_html("**bold**")
assert "<strong>" in result or "<b>" in result
assert "bold" in result
def test_italic_conversion(self):
"""*italic* should produce <em> tags."""
result = self.adapter._markdown_to_html("*italic*")
assert "<em>" in result or "<i>" in result
def test_inline_code(self):
"""`code` should produce <code> tags."""
result = self.adapter._markdown_to_html("`code`")
assert "<code>" in result
def test_plain_text_returns_html(self):
"""Plain text should still be returned (possibly with <br> or <p>)."""
result = self.adapter._markdown_to_html("Hello world")
assert "Hello world" in result
# ---------------------------------------------------------------------------
# Helper: display name extraction
# ---------------------------------------------------------------------------
class TestMatrixDisplayName:
def setup_method(self):
self.adapter = _make_adapter()
def test_get_display_name_from_room_users(self):
"""Should get display name from room's users dict."""
mock_room = MagicMock()
mock_user = MagicMock()
mock_user.display_name = "Alice"
mock_room.users = {"@alice:ex.org": mock_user}
name = self.adapter._get_display_name(mock_room, "@alice:ex.org")
assert name == "Alice"
def test_get_display_name_fallback_to_localpart(self):
"""Should extract localpart from @user:server format."""
mock_room = MagicMock()
mock_room.users = {}
name = self.adapter._get_display_name(mock_room, "@bob:example.org")
assert name == "bob"
def test_get_display_name_no_room(self):
"""Should handle None room gracefully."""
name = self.adapter._get_display_name(None, "@charlie:ex.org")
assert name == "charlie"
# ---------------------------------------------------------------------------
# Requirements check
# ---------------------------------------------------------------------------
class TestMatrixRequirements:
def test_check_requirements_with_token(self, monkeypatch):
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_test")
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
from gateway.platforms.matrix import check_matrix_requirements
try:
import nio # noqa: F401
assert check_matrix_requirements() is True
except ImportError:
assert check_matrix_requirements() is False
def test_check_requirements_without_creds(self, monkeypatch):
monkeypatch.delenv("MATRIX_ACCESS_TOKEN", raising=False)
monkeypatch.delenv("MATRIX_PASSWORD", raising=False)
monkeypatch.delenv("MATRIX_HOMESERVER", raising=False)
from gateway.platforms.matrix import check_matrix_requirements
assert check_matrix_requirements() is False
def test_check_requirements_without_homeserver(self, monkeypatch):
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_test")
monkeypatch.delenv("MATRIX_HOMESERVER", raising=False)
from gateway.platforms.matrix import check_matrix_requirements
assert check_matrix_requirements() is False

View file

@ -0,0 +1,574 @@
"""Tests for Mattermost platform adapter."""
import json
import time
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from gateway.config import Platform, PlatformConfig
# ---------------------------------------------------------------------------
# Platform & Config
# ---------------------------------------------------------------------------
class TestMattermostPlatformEnum:
def test_mattermost_enum_exists(self):
assert Platform.MATTERMOST.value == "mattermost"
def test_mattermost_in_platform_list(self):
platforms = [p.value for p in Platform]
assert "mattermost" in platforms
class TestMattermostConfigLoading:
def test_apply_env_overrides_mattermost(self, monkeypatch):
monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123")
monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com")
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
assert Platform.MATTERMOST in config.platforms
mc = config.platforms[Platform.MATTERMOST]
assert mc.enabled is True
assert mc.token == "mm-tok-abc123"
assert mc.extra.get("url") == "https://mm.example.com"
def test_mattermost_not_loaded_without_token(self, monkeypatch):
monkeypatch.delenv("MATTERMOST_TOKEN", raising=False)
monkeypatch.delenv("MATTERMOST_URL", raising=False)
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
assert Platform.MATTERMOST not in config.platforms
def test_connected_platforms_includes_mattermost(self, monkeypatch):
monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123")
monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com")
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
connected = config.get_connected_platforms()
assert Platform.MATTERMOST in connected
def test_mattermost_home_channel(self, monkeypatch):
monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123")
monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com")
monkeypatch.setenv("MATTERMOST_HOME_CHANNEL", "ch_abc123")
monkeypatch.setenv("MATTERMOST_HOME_CHANNEL_NAME", "General")
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
home = config.get_home_channel(Platform.MATTERMOST)
assert home is not None
assert home.chat_id == "ch_abc123"
assert home.name == "General"
def test_mattermost_url_warning_without_url(self, monkeypatch):
"""MATTERMOST_TOKEN set but MATTERMOST_URL missing should still load."""
monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123")
monkeypatch.delenv("MATTERMOST_URL", raising=False)
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
assert Platform.MATTERMOST in config.platforms
assert config.platforms[Platform.MATTERMOST].extra.get("url") == ""
# ---------------------------------------------------------------------------
# Adapter format / truncate
# ---------------------------------------------------------------------------
def _make_adapter():
"""Create a MattermostAdapter with mocked config."""
from gateway.platforms.mattermost import MattermostAdapter
config = PlatformConfig(
enabled=True,
token="test-token",
extra={"url": "https://mm.example.com"},
)
adapter = MattermostAdapter(config)
return adapter
class TestMattermostFormatMessage:
def setup_method(self):
self.adapter = _make_adapter()
def test_image_markdown_to_url(self):
"""![alt](url) should be converted to just the URL."""
result = self.adapter.format_message("![cat](https://img.example.com/cat.png)")
assert result == "https://img.example.com/cat.png"
def test_image_markdown_strips_alt_text(self):
result = self.adapter.format_message("Here: ![my image](https://x.com/a.jpg) done")
assert "![" not in result
assert "https://x.com/a.jpg" in result
def test_regular_markdown_preserved(self):
"""Regular markdown (bold, italic, code) should be kept as-is."""
content = "**bold** and *italic* and `code`"
assert self.adapter.format_message(content) == content
def test_regular_links_preserved(self):
"""Non-image links should be preserved."""
content = "[click](https://example.com)"
assert self.adapter.format_message(content) == content
def test_plain_text_unchanged(self):
content = "Hello, world!"
assert self.adapter.format_message(content) == content
def test_multiple_images(self):
content = "![a](http://a.com/1.png) text ![b](http://b.com/2.png)"
result = self.adapter.format_message(content)
assert "![" not in result
assert "http://a.com/1.png" in result
assert "http://b.com/2.png" in result
class TestMattermostTruncateMessage:
def setup_method(self):
self.adapter = _make_adapter()
def test_short_message_single_chunk(self):
msg = "Hello, world!"
chunks = self.adapter.truncate_message(msg, 4000)
assert len(chunks) == 1
assert chunks[0] == msg
def test_long_message_splits(self):
msg = "a " * 2500 # 5000 chars
chunks = self.adapter.truncate_message(msg, 4000)
assert len(chunks) >= 2
for chunk in chunks:
assert len(chunk) <= 4000
def test_custom_max_length(self):
msg = "Hello " * 20
chunks = self.adapter.truncate_message(msg, max_length=50)
assert all(len(c) <= 50 for c in chunks)
def test_exactly_at_limit(self):
msg = "x" * 4000
chunks = self.adapter.truncate_message(msg, 4000)
assert len(chunks) == 1
# ---------------------------------------------------------------------------
# Send
# ---------------------------------------------------------------------------
class TestMattermostSend:
def setup_method(self):
self.adapter = _make_adapter()
self.adapter._session = MagicMock()
@pytest.mark.asyncio
async def test_send_calls_api_post(self):
"""send() should POST to /api/v4/posts with channel_id and message."""
mock_resp = AsyncMock()
mock_resp.status = 200
mock_resp.json = AsyncMock(return_value={"id": "post123"})
mock_resp.text = AsyncMock(return_value="")
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
self.adapter._session.post = MagicMock(return_value=mock_resp)
result = await self.adapter.send("channel_1", "Hello!")
assert result.success is True
assert result.message_id == "post123"
# Verify post was called with correct URL
call_args = self.adapter._session.post.call_args
assert "/api/v4/posts" in call_args[0][0]
# Verify payload
payload = call_args[1]["json"]
assert payload["channel_id"] == "channel_1"
assert payload["message"] == "Hello!"
@pytest.mark.asyncio
async def test_send_empty_content_succeeds(self):
"""Empty content should return success without calling the API."""
result = await self.adapter.send("channel_1", "")
assert result.success is True
@pytest.mark.asyncio
async def test_send_with_thread_reply(self):
"""When reply_mode is 'thread', reply_to should become root_id."""
self.adapter._reply_mode = "thread"
mock_resp = AsyncMock()
mock_resp.status = 200
mock_resp.json = AsyncMock(return_value={"id": "post456"})
mock_resp.text = AsyncMock(return_value="")
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
self.adapter._session.post = MagicMock(return_value=mock_resp)
result = await self.adapter.send("channel_1", "Reply!", reply_to="root_post")
assert result.success is True
payload = self.adapter._session.post.call_args[1]["json"]
assert payload["root_id"] == "root_post"
@pytest.mark.asyncio
async def test_send_without_thread_no_root_id(self):
"""When reply_mode is 'off', reply_to should NOT set root_id."""
self.adapter._reply_mode = "off"
mock_resp = AsyncMock()
mock_resp.status = 200
mock_resp.json = AsyncMock(return_value={"id": "post789"})
mock_resp.text = AsyncMock(return_value="")
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
self.adapter._session.post = MagicMock(return_value=mock_resp)
result = await self.adapter.send("channel_1", "Reply!", reply_to="root_post")
assert result.success is True
payload = self.adapter._session.post.call_args[1]["json"]
assert "root_id" not in payload
@pytest.mark.asyncio
async def test_send_api_failure(self):
"""When API returns error, send should return failure."""
mock_resp = AsyncMock()
mock_resp.status = 500
mock_resp.json = AsyncMock(return_value={})
mock_resp.text = AsyncMock(return_value="Internal Server Error")
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
self.adapter._session.post = MagicMock(return_value=mock_resp)
result = await self.adapter.send("channel_1", "Hello!")
assert result.success is False
# ---------------------------------------------------------------------------
# WebSocket event parsing
# ---------------------------------------------------------------------------
class TestMattermostWebSocketParsing:
def setup_method(self):
self.adapter = _make_adapter()
self.adapter._bot_user_id = "bot_user_id"
# Mock handle_message to capture the MessageEvent without processing
self.adapter.handle_message = AsyncMock()
@pytest.mark.asyncio
async def test_parse_posted_event(self):
"""'posted' events should extract message from double-encoded post JSON."""
post_data = {
"id": "post_abc",
"user_id": "user_123",
"channel_id": "chan_456",
"message": "Hello from Matrix!",
}
event = {
"event": "posted",
"data": {
"post": json.dumps(post_data), # double-encoded JSON string
"channel_type": "O",
"sender_name": "@alice",
},
}
await self.adapter._handle_ws_event(event)
assert self.adapter.handle_message.called
msg_event = self.adapter.handle_message.call_args[0][0]
assert msg_event.text == "Hello from Matrix!"
assert msg_event.message_id == "post_abc"
@pytest.mark.asyncio
async def test_ignore_own_messages(self):
"""Messages from the bot's own user_id should be ignored."""
post_data = {
"id": "post_self",
"user_id": "bot_user_id", # same as bot
"channel_id": "chan_456",
"message": "Bot echo",
}
event = {
"event": "posted",
"data": {
"post": json.dumps(post_data),
"channel_type": "O",
},
}
await self.adapter._handle_ws_event(event)
assert not self.adapter.handle_message.called
@pytest.mark.asyncio
async def test_ignore_non_posted_events(self):
"""Non-'posted' events should be ignored."""
event = {
"event": "typing",
"data": {"user_id": "user_123"},
}
await self.adapter._handle_ws_event(event)
assert not self.adapter.handle_message.called
@pytest.mark.asyncio
async def test_ignore_system_posts(self):
"""Posts with a 'type' field (system messages) should be ignored."""
post_data = {
"id": "sys_post",
"user_id": "user_123",
"channel_id": "chan_456",
"message": "user joined",
"type": "system_join_channel",
}
event = {
"event": "posted",
"data": {
"post": json.dumps(post_data),
"channel_type": "O",
},
}
await self.adapter._handle_ws_event(event)
assert not self.adapter.handle_message.called
@pytest.mark.asyncio
async def test_channel_type_mapping(self):
"""channel_type 'D' should map to 'dm'."""
post_data = {
"id": "post_dm",
"user_id": "user_123",
"channel_id": "chan_dm",
"message": "DM message",
}
event = {
"event": "posted",
"data": {
"post": json.dumps(post_data),
"channel_type": "D",
"sender_name": "@bob",
},
}
await self.adapter._handle_ws_event(event)
assert self.adapter.handle_message.called
msg_event = self.adapter.handle_message.call_args[0][0]
assert msg_event.source.chat_type == "dm"
@pytest.mark.asyncio
async def test_thread_id_from_root_id(self):
"""Post with root_id should have thread_id set."""
post_data = {
"id": "post_reply",
"user_id": "user_123",
"channel_id": "chan_456",
"message": "Thread reply",
"root_id": "root_post_123",
}
event = {
"event": "posted",
"data": {
"post": json.dumps(post_data),
"channel_type": "O",
"sender_name": "@alice",
},
}
await self.adapter._handle_ws_event(event)
assert self.adapter.handle_message.called
msg_event = self.adapter.handle_message.call_args[0][0]
assert msg_event.source.thread_id == "root_post_123"
@pytest.mark.asyncio
async def test_invalid_post_json_ignored(self):
"""Invalid JSON in data.post should be silently ignored."""
event = {
"event": "posted",
"data": {
"post": "not-valid-json{{{",
"channel_type": "O",
},
}
await self.adapter._handle_ws_event(event)
assert not self.adapter.handle_message.called
# ---------------------------------------------------------------------------
# File upload (send_image)
# ---------------------------------------------------------------------------
class TestMattermostFileUpload:
def setup_method(self):
self.adapter = _make_adapter()
self.adapter._session = MagicMock()
@pytest.mark.asyncio
async def test_send_image_downloads_and_uploads(self):
"""send_image should download the URL, upload via /api/v4/files, then post."""
# Mock the download (GET)
mock_dl_resp = AsyncMock()
mock_dl_resp.status = 200
mock_dl_resp.read = AsyncMock(return_value=b"\x89PNG\x00fake-image-data")
mock_dl_resp.content_type = "image/png"
mock_dl_resp.__aenter__ = AsyncMock(return_value=mock_dl_resp)
mock_dl_resp.__aexit__ = AsyncMock(return_value=False)
# Mock the upload (POST to /files)
mock_upload_resp = AsyncMock()
mock_upload_resp.status = 200
mock_upload_resp.json = AsyncMock(return_value={
"file_infos": [{"id": "file_abc123"}]
})
mock_upload_resp.text = AsyncMock(return_value="")
mock_upload_resp.__aenter__ = AsyncMock(return_value=mock_upload_resp)
mock_upload_resp.__aexit__ = AsyncMock(return_value=False)
# Mock the post (POST to /posts)
mock_post_resp = AsyncMock()
mock_post_resp.status = 200
mock_post_resp.json = AsyncMock(return_value={"id": "post_with_file"})
mock_post_resp.text = AsyncMock(return_value="")
mock_post_resp.__aenter__ = AsyncMock(return_value=mock_post_resp)
mock_post_resp.__aexit__ = AsyncMock(return_value=False)
# Route calls: first GET (download), then POST (upload), then POST (create post)
self.adapter._session.get = MagicMock(return_value=mock_dl_resp)
post_call_count = 0
original_post_returns = [mock_upload_resp, mock_post_resp]
def post_side_effect(*args, **kwargs):
nonlocal post_call_count
resp = original_post_returns[min(post_call_count, len(original_post_returns) - 1)]
post_call_count += 1
return resp
self.adapter._session.post = MagicMock(side_effect=post_side_effect)
result = await self.adapter.send_image(
"channel_1", "https://img.example.com/cat.png", caption="A cat"
)
assert result.success is True
assert result.message_id == "post_with_file"
# ---------------------------------------------------------------------------
# Dedup cache
# ---------------------------------------------------------------------------
class TestMattermostDedup:
def setup_method(self):
self.adapter = _make_adapter()
self.adapter._bot_user_id = "bot_user_id"
# Mock handle_message to capture calls without processing
self.adapter.handle_message = AsyncMock()
@pytest.mark.asyncio
async def test_duplicate_post_ignored(self):
"""The same post_id within the TTL window should be ignored."""
post_data = {
"id": "post_dup",
"user_id": "user_123",
"channel_id": "chan_456",
"message": "Hello!",
}
event = {
"event": "posted",
"data": {
"post": json.dumps(post_data),
"channel_type": "O",
"sender_name": "@alice",
},
}
# First time: should process
await self.adapter._handle_ws_event(event)
assert self.adapter.handle_message.call_count == 1
# Second time (same post_id): should be deduped
await self.adapter._handle_ws_event(event)
assert self.adapter.handle_message.call_count == 1 # still 1
@pytest.mark.asyncio
async def test_different_post_ids_both_processed(self):
"""Different post IDs should both be processed."""
for i, pid in enumerate(["post_a", "post_b"]):
post_data = {
"id": pid,
"user_id": "user_123",
"channel_id": "chan_456",
"message": f"Message {i}",
}
event = {
"event": "posted",
"data": {
"post": json.dumps(post_data),
"channel_type": "O",
"sender_name": "@alice",
},
}
await self.adapter._handle_ws_event(event)
assert self.adapter.handle_message.call_count == 2
def test_prune_seen_clears_expired(self):
"""_prune_seen should remove entries older than _SEEN_TTL."""
now = time.time()
# Fill with enough expired entries to trigger pruning
for i in range(self.adapter._SEEN_MAX + 10):
self.adapter._seen_posts[f"old_{i}"] = now - 600 # 10 min ago
# Add a fresh one
self.adapter._seen_posts["fresh"] = now
self.adapter._prune_seen()
# Old entries should be pruned, fresh one kept
assert "fresh" in self.adapter._seen_posts
assert len(self.adapter._seen_posts) < self.adapter._SEEN_MAX
def test_seen_cache_tracks_post_ids(self):
"""Posts are tracked in _seen_posts dict."""
self.adapter._seen_posts["test_post"] = time.time()
assert "test_post" in self.adapter._seen_posts
# ---------------------------------------------------------------------------
# Requirements check
# ---------------------------------------------------------------------------
class TestMattermostRequirements:
def test_check_requirements_with_token_and_url(self, monkeypatch):
monkeypatch.setenv("MATTERMOST_TOKEN", "test-token")
monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com")
from gateway.platforms.mattermost import check_mattermost_requirements
assert check_mattermost_requirements() is True
def test_check_requirements_without_token(self, monkeypatch):
monkeypatch.delenv("MATTERMOST_TOKEN", raising=False)
monkeypatch.delenv("MATTERMOST_URL", raising=False)
from gateway.platforms.mattermost import check_mattermost_requirements
assert check_mattermost_requirements() is False
def test_check_requirements_without_url(self, monkeypatch):
monkeypatch.setenv("MATTERMOST_TOKEN", "test-token")
monkeypatch.delenv("MATTERMOST_URL", raising=False)
from gateway.platforms.mattermost import check_mattermost_requirements
assert check_mattermost_requirements() is False

View file

@ -703,5 +703,15 @@ class TestLastPromptTokens:
store.update_session("k1", model="openai/gpt-5.4")
store._db.update_token_counts.assert_called_once_with(
"s1", 0, 0, model="openai/gpt-5.4"
"s1",
input_tokens=0,
output_tokens=0,
cache_read_tokens=0,
cache_write_tokens=0,
estimated_cost_usd=None,
cost_status=None,
cost_source=None,
billing_provider=None,
billing_base_url=None,
model="openai/gpt-5.4",
)

View file

@ -1,240 +1,215 @@
"""Tests for SMS (Telnyx) platform adapter."""
import json
"""Tests for SMS (Twilio) platform integration.
Covers config loading, format/truncate, echo prevention,
requirements check, and toolset verification.
"""
import os
from unittest.mock import patch
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from gateway.config import Platform, PlatformConfig
from gateway.config import Platform, PlatformConfig, HomeChannel
# ---------------------------------------------------------------------------
# Platform & Config
# ---------------------------------------------------------------------------
class TestSmsPlatformEnum:
def test_sms_enum_exists(self):
assert Platform.SMS.value == "sms"
def test_sms_in_platform_list(self):
platforms = [p.value for p in Platform]
assert "sms" in platforms
# ── Config loading ──────────────────────────────────────────────────
class TestSmsConfigLoading:
def test_apply_env_overrides_sms(self, monkeypatch):
monkeypatch.setenv("TELNYX_API_KEY", "KEY_test123")
"""Verify _apply_env_overrides wires SMS correctly."""
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
def test_sms_platform_enum_exists(self):
assert Platform.SMS.value == "sms"
assert Platform.SMS in config.platforms
sc = config.platforms[Platform.SMS]
assert sc.enabled is True
assert sc.api_key == "KEY_test123"
def test_env_overrides_create_sms_config(self):
from gateway.config import load_gateway_config
def test_sms_not_loaded_without_key(self, monkeypatch):
monkeypatch.delenv("TELNYX_API_KEY", raising=False)
env = {
"TWILIO_ACCOUNT_SID": "ACtest123",
"TWILIO_AUTH_TOKEN": "token_abc",
"TWILIO_PHONE_NUMBER": "+15551234567",
}
with patch.dict(os.environ, env, clear=False):
config = load_gateway_config()
assert Platform.SMS in config.platforms
pc = config.platforms[Platform.SMS]
assert pc.enabled is True
assert pc.api_key == "token_abc"
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
def test_env_overrides_set_home_channel(self):
from gateway.config import load_gateway_config
assert Platform.SMS not in config.platforms
env = {
"TWILIO_ACCOUNT_SID": "ACtest123",
"TWILIO_AUTH_TOKEN": "token_abc",
"TWILIO_PHONE_NUMBER": "+15551234567",
"SMS_HOME_CHANNEL": "+15559876543",
"SMS_HOME_CHANNEL_NAME": "My Phone",
}
with patch.dict(os.environ, env, clear=False):
config = load_gateway_config()
hc = config.platforms[Platform.SMS].home_channel
assert hc is not None
assert hc.chat_id == "+15559876543"
assert hc.name == "My Phone"
assert hc.platform == Platform.SMS
def test_connected_platforms_includes_sms(self, monkeypatch):
monkeypatch.setenv("TELNYX_API_KEY", "KEY_test123")
def test_sms_in_connected_platforms(self):
from gateway.config import load_gateway_config
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
connected = config.get_connected_platforms()
assert Platform.SMS in connected
def test_sms_home_channel(self, monkeypatch):
monkeypatch.setenv("TELNYX_API_KEY", "KEY_test123")
monkeypatch.setenv("SMS_HOME_CHANNEL", "+15559876543")
monkeypatch.setenv("SMS_HOME_CHANNEL_NAME", "Owner")
from gateway.config import GatewayConfig, _apply_env_overrides
config = GatewayConfig()
_apply_env_overrides(config)
home = config.get_home_channel(Platform.SMS)
assert home is not None
assert home.chat_id == "+15559876543"
assert home.name == "Owner"
env = {
"TWILIO_ACCOUNT_SID": "ACtest123",
"TWILIO_AUTH_TOKEN": "token_abc",
}
with patch.dict(os.environ, env, clear=False):
config = load_gateway_config()
connected = config.get_connected_platforms()
assert Platform.SMS in connected
# ---------------------------------------------------------------------------
# Adapter format / truncate
# ---------------------------------------------------------------------------
# ── Format / truncate ───────────────────────────────────────────────
class TestSmsFormatMessage:
def setup_method(self):
class TestSmsFormatAndTruncate:
"""Test SmsAdapter.format_message strips markdown."""
def _make_adapter(self):
from gateway.platforms.sms import SmsAdapter
config = PlatformConfig(enabled=True, api_key="test_key")
with patch.dict("os.environ", {"TELNYX_API_KEY": "test_key"}):
self.adapter = SmsAdapter(config)
def test_strip_bold(self):
assert self.adapter.format_message("**bold**") == "bold"
env = {
"TWILIO_ACCOUNT_SID": "ACtest",
"TWILIO_AUTH_TOKEN": "tok",
"TWILIO_PHONE_NUMBER": "+15550001111",
}
with patch.dict(os.environ, env):
pc = PlatformConfig(enabled=True, api_key="tok")
adapter = object.__new__(SmsAdapter)
adapter.config = pc
adapter._platform = Platform.SMS
adapter._account_sid = "ACtest"
adapter._auth_token = "tok"
adapter._from_number = "+15550001111"
return adapter
def test_strip_italic(self):
assert self.adapter.format_message("*italic*") == "italic"
def test_strips_bold(self):
adapter = self._make_adapter()
assert adapter.format_message("**hello**") == "hello"
def test_strip_code_block(self):
result = self.adapter.format_message("```python\ncode\n```")
def test_strips_italic(self):
adapter = self._make_adapter()
assert adapter.format_message("*world*") == "world"
def test_strips_code_blocks(self):
adapter = self._make_adapter()
result = adapter.format_message("```python\nprint('hi')\n```")
assert "```" not in result
assert "code" in result
assert "print('hi')" in result
def test_strip_inline_code(self):
assert self.adapter.format_message("`code`") == "code"
def test_strips_inline_code(self):
adapter = self._make_adapter()
assert adapter.format_message("`code`") == "code"
def test_strip_headers(self):
assert self.adapter.format_message("## Header") == "Header"
def test_strips_headers(self):
adapter = self._make_adapter()
assert adapter.format_message("## Title") == "Title"
def test_strip_links(self):
assert self.adapter.format_message("[click](http://example.com)") == "click"
def test_strips_links(self):
adapter = self._make_adapter()
assert adapter.format_message("[click](https://example.com)") == "click"
def test_collapse_newlines(self):
result = self.adapter.format_message("a\n\n\n\nb")
def test_collapses_newlines(self):
adapter = self._make_adapter()
result = adapter.format_message("a\n\n\n\nb")
assert result == "a\n\nb"
class TestSmsTruncateMessage:
def setup_method(self):
# ── Echo prevention ────────────────────────────────────────────────
class TestSmsEchoPrevention:
"""Adapter should ignore messages from its own number."""
def test_own_number_detection(self):
"""The adapter stores _from_number for echo prevention."""
from gateway.platforms.sms import SmsAdapter
config = PlatformConfig(enabled=True, api_key="test_key")
with patch.dict("os.environ", {"TELNYX_API_KEY": "test_key"}):
self.adapter = SmsAdapter(config)
def test_short_message_single_chunk(self):
msg = "Hello, world!"
chunks = self.adapter.truncate_message(msg)
assert len(chunks) == 1
assert chunks[0] == msg
def test_long_message_splits(self):
msg = "a " * 1000 # 2000 chars
chunks = self.adapter.truncate_message(msg)
assert len(chunks) >= 2
for chunk in chunks:
assert len(chunk) <= 1600
def test_custom_max_length(self):
msg = "Hello " * 20
chunks = self.adapter.truncate_message(msg, max_length=50)
assert all(len(c) <= 50 for c in chunks)
env = {
"TWILIO_ACCOUNT_SID": "ACtest",
"TWILIO_AUTH_TOKEN": "tok",
"TWILIO_PHONE_NUMBER": "+15550001111",
}
with patch.dict(os.environ, env):
pc = PlatformConfig(enabled=True, api_key="tok")
adapter = SmsAdapter(pc)
assert adapter._from_number == "+15550001111"
# ---------------------------------------------------------------------------
# Echo loop prevention
# ---------------------------------------------------------------------------
class TestSmsEchoLoop:
def test_own_number_ignored(self):
from gateway.platforms.sms import SmsAdapter
config = PlatformConfig(enabled=True, api_key="test_key")
with patch.dict("os.environ", {
"TELNYX_API_KEY": "test_key",
"TELNYX_FROM_NUMBERS": "+15551234567,+15559876543",
}):
adapter = SmsAdapter(config)
assert "+15551234567" in adapter._from_numbers
assert "+15559876543" in adapter._from_numbers
# ---------------------------------------------------------------------------
# Auth maps
# ---------------------------------------------------------------------------
class TestSmsAuthMaps:
def test_sms_in_allowed_users_map(self):
"""SMS should be in the platform auth maps in run.py."""
# Verify the env var names are consistent
import os
os.environ.setdefault("SMS_ALLOWED_USERS", "+15551234567")
assert os.getenv("SMS_ALLOWED_USERS") == "+15551234567"
def test_sms_allow_all_env_var(self):
"""SMS_ALLOW_ALL_USERS should be recognized."""
import os
os.environ.setdefault("SMS_ALLOW_ALL_USERS", "true")
assert os.getenv("SMS_ALLOW_ALL_USERS") == "true"
# ---------------------------------------------------------------------------
# Requirements check
# ---------------------------------------------------------------------------
# ── Requirements check ─────────────────────────────────────────────
class TestSmsRequirements:
def test_check_sms_requirements_with_key(self, monkeypatch):
monkeypatch.setenv("TELNYX_API_KEY", "KEY_test123")
def test_check_sms_requirements_missing_sid(self):
from gateway.platforms.sms import check_sms_requirements
# aiohttp is available in test environment
assert check_sms_requirements() is True
def test_check_sms_requirements_without_key(self, monkeypatch):
monkeypatch.delenv("TELNYX_API_KEY", raising=False)
env = {"TWILIO_AUTH_TOKEN": "tok"}
with patch.dict(os.environ, env, clear=True):
assert check_sms_requirements() is False
def test_check_sms_requirements_missing_token(self):
from gateway.platforms.sms import check_sms_requirements
assert check_sms_requirements() is False
env = {"TWILIO_ACCOUNT_SID": "ACtest"}
with patch.dict(os.environ, env, clear=True):
assert check_sms_requirements() is False
def test_check_sms_requirements_both_set(self):
from gateway.platforms.sms import check_sms_requirements
env = {
"TWILIO_ACCOUNT_SID": "ACtest",
"TWILIO_AUTH_TOKEN": "tok",
}
with patch.dict(os.environ, env, clear=False):
# Only returns True if aiohttp is also importable
result = check_sms_requirements()
try:
import aiohttp # noqa: F401
assert result is True
except ImportError:
assert result is False
# ---------------------------------------------------------------------------
# Toolset & integration points
# ---------------------------------------------------------------------------
# ── Toolset verification ───────────────────────────────────────────
class TestSmsToolset:
def test_hermes_sms_toolset_exists(self):
from toolsets import get_toolset
ts = get_toolset("hermes-sms")
assert ts is not None
assert "hermes-sms" in ts.get("description", "").lower() or "sms" in ts.get("description", "").lower()
assert "tools" in ts
def test_hermes_gateway_includes_sms(self):
def test_hermes_sms_in_gateway_includes(self):
from toolsets import get_toolset
gw = get_toolset("hermes-gateway")
assert gw is not None
assert "hermes-sms" in gw["includes"]
class TestSmsPlatformHints:
def test_sms_in_platform_hints(self):
def test_sms_platform_hint_exists(self):
from agent.prompt_builder import PLATFORM_HINTS
assert "sms" in PLATFORM_HINTS
assert "SMS" in PLATFORM_HINTS["sms"] or "sms" in PLATFORM_HINTS["sms"].lower()
assert "concise" in PLATFORM_HINTS["sms"].lower()
class TestSmsCronDelivery:
def test_sms_in_cron_platform_map(self):
"""Verify the cron scheduler can resolve 'sms' platform."""
# The platform_map in _deliver_result should include sms
from gateway.config import Platform
def test_sms_in_scheduler_platform_map(self):
"""Verify cron scheduler recognizes 'sms' as a valid platform."""
# Just check the Platform enum has SMS — the scheduler imports it dynamically
assert Platform.SMS.value == "sms"
class TestSmsSendMessageTool:
def test_sms_in_send_message_platform_map(self):
"""The send_message tool should recognize 'sms' as a valid platform."""
# We verify by checking that SMS is in the Platform enum
# and the code path exists
from gateway.config import Platform
"""Verify send_message_tool recognizes 'sms'."""
# The platform_map is built inside _handle_send; verify SMS enum exists
assert hasattr(Platform, "SMS")
class TestSmsChannelDirectory:
def test_sms_in_session_discovery(self):
"""Verify SMS is included in session-based channel discovery."""
import inspect
from gateway.channel_directory import build_channel_directory
source = inspect.getsource(build_channel_directory)
assert '"sms"' in source
class TestSmsStatus:
def test_sms_in_status_platforms(self):
"""Verify SMS appears in the status command platforms dict."""
import inspect
from hermes_cli.status import show_status
source = inspect.getsource(show_status)
assert '"SMS"' in source or "'SMS'" in source
def test_sms_in_cronjob_deliver_description(self):
"""Verify cronjob_tools mentions sms in deliver description."""
from tools.cronjob_tools import CRONJOB_SCHEMA
deliver_desc = CRONJOB_SCHEMA["parameters"]["properties"]["deliver"]["description"]
assert "sms" in deliver_desc.lower()

View file

@ -128,6 +128,13 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch):
session_entry.session_key,
input_tokens=120,
output_tokens=45,
cache_read_tokens=0,
cache_write_tokens=0,
last_prompt_tokens=80,
model="openai/test-model",
estimated_cost_usd=None,
cost_status=None,
cost_source=None,
provider=None,
base_url=None,
)

View file

@ -316,6 +316,38 @@ class TestSanitizeEnvLines:
assert fixes == 0
class TestOptionalEnvVarsRegistry:
"""Verify that key env vars are registered in OPTIONAL_ENV_VARS."""
def test_tavily_api_key_registered(self):
"""TAVILY_API_KEY is listed in OPTIONAL_ENV_VARS."""
from hermes_cli.config import OPTIONAL_ENV_VARS
assert "TAVILY_API_KEY" in OPTIONAL_ENV_VARS
def test_tavily_api_key_is_tool_category(self):
"""TAVILY_API_KEY is in the 'tool' category."""
from hermes_cli.config import OPTIONAL_ENV_VARS
assert OPTIONAL_ENV_VARS["TAVILY_API_KEY"]["category"] == "tool"
def test_tavily_api_key_is_password(self):
"""TAVILY_API_KEY is marked as password."""
from hermes_cli.config import OPTIONAL_ENV_VARS
assert OPTIONAL_ENV_VARS["TAVILY_API_KEY"]["password"] is True
def test_tavily_api_key_has_url(self):
"""TAVILY_API_KEY has a URL."""
from hermes_cli.config import OPTIONAL_ENV_VARS
assert OPTIONAL_ENV_VARS["TAVILY_API_KEY"]["url"] == "https://app.tavily.com/home"
def test_tavily_in_env_vars_by_version(self):
"""TAVILY_API_KEY is listed in ENV_VARS_BY_VERSION."""
from hermes_cli.config import ENV_VARS_BY_VERSION
all_vars = []
for vars_list in ENV_VARS_BY_VERSION.values():
all_vars.extend(vars_list)
assert "TAVILY_API_KEY" in all_vars
class TestAnthropicTokenMigration:
"""Test that config version 8→9 clears ANTHROPIC_TOKEN."""

View file

@ -0,0 +1,291 @@
"""Tests for MCP tools interactive configuration in hermes_cli.tools_config."""
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from hermes_cli.tools_config import _configure_mcp_tools_interactive
# Patch targets: imports happen inside the function body, so patch at source
_PROBE = "tools.mcp_tool.probe_mcp_server_tools"
_CHECKLIST = "hermes_cli.curses_ui.curses_checklist"
_SAVE = "hermes_cli.tools_config.save_config"
def test_no_mcp_servers_prints_info(capsys):
"""Returns immediately when no MCP servers are configured."""
config = {}
_configure_mcp_tools_interactive(config)
captured = capsys.readouterr()
assert "No MCP servers configured" in captured.out
def test_all_servers_disabled_prints_info(capsys):
"""Returns immediately when all configured servers have enabled=false."""
config = {
"mcp_servers": {
"github": {"command": "npx", "enabled": False},
"slack": {"command": "npx", "enabled": "false"},
}
}
_configure_mcp_tools_interactive(config)
captured = capsys.readouterr()
assert "disabled" in captured.out
def test_probe_failure_shows_warning(capsys):
"""Shows warning when probe returns no tools."""
config = {"mcp_servers": {"github": {"command": "npx"}}}
with patch(_PROBE, return_value={}):
_configure_mcp_tools_interactive(config)
captured = capsys.readouterr()
assert "Could not discover" in captured.out
def test_probe_exception_shows_error(capsys):
"""Shows error when probe raises an exception."""
config = {"mcp_servers": {"github": {"command": "npx"}}}
with patch(_PROBE, side_effect=RuntimeError("MCP not installed")):
_configure_mcp_tools_interactive(config)
captured = capsys.readouterr()
assert "Failed to probe" in captured.out
def test_no_changes_when_checklist_cancelled(capsys):
"""No config changes when user cancels (ESC) the checklist."""
config = {
"mcp_servers": {
"github": {"command": "npx", "args": ["-y", "server-github"]},
}
}
tools = [("create_issue", "Create an issue"), ("search_repos", "Search repos")]
with patch(_PROBE, return_value={"github": tools}), \
patch(_CHECKLIST, return_value={0, 1}), \
patch(_SAVE) as mock_save:
_configure_mcp_tools_interactive(config)
mock_save.assert_not_called()
captured = capsys.readouterr()
assert "no changes" in captured.out.lower()
def test_disabling_tool_writes_exclude_list(capsys):
"""Unchecking a tool adds it to the exclude list."""
config = {
"mcp_servers": {
"github": {"command": "npx"},
}
}
tools = [
("create_issue", "Create an issue"),
("delete_repo", "Delete a repo"),
("search_repos", "Search repos"),
]
# User unchecks delete_repo (index 1)
with patch(_PROBE, return_value={"github": tools}), \
patch(_CHECKLIST, return_value={0, 2}), \
patch(_SAVE) as mock_save:
_configure_mcp_tools_interactive(config)
mock_save.assert_called_once()
tools_cfg = config["mcp_servers"]["github"]["tools"]
assert tools_cfg["exclude"] == ["delete_repo"]
assert "include" not in tools_cfg
def test_enabling_all_clears_filters(capsys):
"""Checking all tools clears both include and exclude lists."""
config = {
"mcp_servers": {
"github": {
"command": "npx",
"tools": {"exclude": ["delete_repo"], "include": ["create_issue"]},
},
}
}
tools = [("create_issue", "Create"), ("delete_repo", "Delete")]
# User checks all tools — pre_selected would be {0} (include mode),
# so returning {0, 1} is a change
with patch(_PROBE, return_value={"github": tools}), \
patch(_CHECKLIST, return_value={0, 1}), \
patch(_SAVE) as mock_save:
_configure_mcp_tools_interactive(config)
mock_save.assert_called_once()
tools_cfg = config["mcp_servers"]["github"]["tools"]
assert "exclude" not in tools_cfg
assert "include" not in tools_cfg
def test_pre_selection_respects_existing_exclude(capsys):
"""Tools in exclude list start unchecked."""
config = {
"mcp_servers": {
"github": {
"command": "npx",
"tools": {"exclude": ["delete_repo"]},
},
}
}
tools = [("create_issue", "Create"), ("delete_repo", "Delete"), ("search", "Search")]
captured_pre_selected = {}
def fake_checklist(title, labels, pre_selected, **kwargs):
captured_pre_selected["value"] = set(pre_selected)
return pre_selected # No changes
with patch(_PROBE, return_value={"github": tools}), \
patch(_CHECKLIST, side_effect=fake_checklist), \
patch(_SAVE):
_configure_mcp_tools_interactive(config)
# create_issue (0) and search (2) should be pre-selected, delete_repo (1) should not
assert captured_pre_selected["value"] == {0, 2}
def test_pre_selection_respects_existing_include(capsys):
"""Only tools in include list start checked."""
config = {
"mcp_servers": {
"github": {
"command": "npx",
"tools": {"include": ["search"]},
},
}
}
tools = [("create_issue", "Create"), ("delete_repo", "Delete"), ("search", "Search")]
captured_pre_selected = {}
def fake_checklist(title, labels, pre_selected, **kwargs):
captured_pre_selected["value"] = set(pre_selected)
return pre_selected # No changes
with patch(_PROBE, return_value={"github": tools}), \
patch(_CHECKLIST, side_effect=fake_checklist), \
patch(_SAVE):
_configure_mcp_tools_interactive(config)
# Only search (2) should be pre-selected
assert captured_pre_selected["value"] == {2}
def test_multiple_servers_each_get_checklist(capsys):
"""Each server gets its own checklist."""
config = {
"mcp_servers": {
"github": {"command": "npx"},
"slack": {"url": "https://mcp.example.com"},
}
}
checklist_calls = []
def fake_checklist(title, labels, pre_selected, **kwargs):
checklist_calls.append(title)
return pre_selected # No changes
with patch(
_PROBE,
return_value={
"github": [("create_issue", "Create")],
"slack": [("send_message", "Send")],
},
), patch(_CHECKLIST, side_effect=fake_checklist), \
patch(_SAVE):
_configure_mcp_tools_interactive(config)
assert len(checklist_calls) == 2
assert any("github" in t for t in checklist_calls)
assert any("slack" in t for t in checklist_calls)
def test_failed_server_shows_warning(capsys):
"""Servers that fail to connect show warnings."""
config = {
"mcp_servers": {
"github": {"command": "npx"},
"broken": {"command": "nonexistent"},
}
}
# Only github succeeds
with patch(
_PROBE, return_value={"github": [("create_issue", "Create")]},
), patch(_CHECKLIST, return_value={0}), \
patch(_SAVE):
_configure_mcp_tools_interactive(config)
captured = capsys.readouterr()
assert "broken" in captured.out
def test_description_truncation_in_labels():
"""Long descriptions are truncated in checklist labels."""
config = {
"mcp_servers": {
"github": {"command": "npx"},
}
}
long_desc = "A" * 100
captured_labels = {}
def fake_checklist(title, labels, pre_selected, **kwargs):
captured_labels["value"] = labels
return pre_selected
with patch(
_PROBE, return_value={"github": [("my_tool", long_desc)]},
), patch(_CHECKLIST, side_effect=fake_checklist), \
patch(_SAVE):
_configure_mcp_tools_interactive(config)
label = captured_labels["value"][0]
assert "..." in label
assert len(label) < len(long_desc) + 30 # truncated + tool name + parens
def test_switching_from_include_to_exclude(capsys):
"""When user modifies selection, include list is replaced by exclude list."""
config = {
"mcp_servers": {
"github": {
"command": "npx",
"tools": {"include": ["create_issue"]},
},
}
}
tools = [("create_issue", "Create"), ("search", "Search"), ("delete", "Delete")]
# User selects create_issue and search (deselects delete)
# pre_selected would be {0} (only create_issue from include), so {0, 1} is a change
with patch(_PROBE, return_value={"github": tools}), \
patch(_CHECKLIST, return_value={0, 1}), \
patch(_SAVE):
_configure_mcp_tools_interactive(config)
tools_cfg = config["mcp_servers"]["github"]["tools"]
assert tools_cfg["exclude"] == ["delete"]
assert "include" not in tools_cfg
def test_empty_tools_server_skipped(capsys):
"""Server with no tools shows info message and skips checklist."""
config = {
"mcp_servers": {
"empty": {"command": "npx"},
}
}
checklist_calls = []
def fake_checklist(title, labels, pre_selected, **kwargs):
checklist_calls.append(title)
return pre_selected
with patch(_PROBE, return_value={"empty": []}), \
patch(_CHECKLIST, side_effect=fake_checklist), \
patch(_SAVE):
_configure_mcp_tools_interactive(config)
assert len(checklist_calls) == 0
captured = capsys.readouterr()
assert "no tools found" in captured.out

View file

@ -5,6 +5,13 @@ from hermes_cli.config import load_config, save_config
from hermes_cli.setup import setup_model_provider
def _maybe_keep_current_tts(question, choices):
if question != "Select TTS provider:":
return None
assert choices[-1].startswith("Keep current (")
return len(choices) - 1
def _clear_provider_env(monkeypatch):
for key in (
"NOUS_API_KEY",
@ -25,16 +32,22 @@ def test_nous_oauth_setup_keeps_current_model_when_syncing_disk_provider(
config = load_config()
# Provider selection always comes first. Depending on available vision
# backends, setup may either skip the optional vision step or prompt for
# it before the default-model choice. Provide enough selections for both
# paths while still ending on "keep current model".
prompt_choices = iter([0, 2, 2])
monkeypatch.setattr(
"hermes_cli.setup.prompt_choice",
lambda *args, **kwargs: next(prompt_choices),
)
def fake_prompt_choice(question, choices, default=0):
if question == "Select your inference provider:":
return 0
if question == "Configure vision:":
return len(choices) - 1
if question == "Select default model:":
assert choices[-1] == "Keep current (anthropic/claude-opus-4.6)"
return len(choices) - 1
tts_idx = _maybe_keep_current_tts(question, choices)
if tts_idx is not None:
return tts_idx
raise AssertionError(f"Unexpected prompt_choice call: {question}")
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
def _fake_login_nous(*args, **kwargs):
auth_path = tmp_path / "auth.json"
@ -53,7 +66,6 @@ def test_nous_oauth_setup_keeps_current_model_when_syncing_disk_provider(
"hermes_cli.auth.fetch_nous_models",
lambda *args, **kwargs: ["gemini-3-flash"],
)
monkeypatch.setattr("hermes_cli.setup._setup_tts_provider", lambda config: None)
setup_model_provider(config)
save_config(config)
@ -75,21 +87,29 @@ def test_custom_setup_clears_active_oauth_provider(tmp_path, monkeypatch):
config = load_config()
monkeypatch.setattr("hermes_cli.setup.prompt_choice", lambda *args, **kwargs: 3)
def fake_prompt_choice(question, choices, default=0):
if question == "Select your inference provider:":
return 3
tts_idx = _maybe_keep_current_tts(question, choices)
if tts_idx is not None:
return tts_idx
raise AssertionError(f"Unexpected prompt_choice call: {question}")
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
prompt_values = iter(
[
"https://custom.example/v1",
"custom-api-key",
"custom/model",
"",
]
)
monkeypatch.setattr(
"hermes_cli.setup.prompt",
lambda *args, **kwargs: next(prompt_values),
)
monkeypatch.setattr("hermes_cli.setup._setup_tts_provider", lambda config: None)
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
setup_model_provider(config)
save_config(config)
@ -111,11 +131,17 @@ def test_codex_setup_uses_runtime_access_token_for_live_model_list(tmp_path, mon
config = load_config()
prompt_choices = iter([1, 0])
monkeypatch.setattr(
"hermes_cli.setup.prompt_choice",
lambda *args, **kwargs: next(prompt_choices),
)
def fake_prompt_choice(question, choices, default=0):
if question == "Select your inference provider:":
return 1
if question == "Select default model:":
return 0
tts_idx = _maybe_keep_current_tts(question, choices)
if tts_idx is not None:
return tts_idx
raise AssertionError(f"Unexpected prompt_choice call: {question}")
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
monkeypatch.setattr("hermes_cli.auth._login_openai_codex", lambda *args, **kwargs: None)
@ -137,7 +163,6 @@ def test_codex_setup_uses_runtime_access_token_for_live_model_list(tmp_path, mon
"hermes_cli.codex_models.get_codex_model_ids",
_fake_get_codex_model_ids,
)
monkeypatch.setattr("hermes_cli.setup._setup_tts_provider", lambda config: None)
setup_model_provider(config)
save_config(config)

View file

@ -6,6 +6,13 @@ from hermes_cli.config import load_config, save_config, save_env_value
from hermes_cli.setup import _print_setup_summary, setup_model_provider
def _maybe_keep_current_tts(question, choices):
if question != "Select TTS provider:":
return None
assert choices[-1].startswith("Keep current (")
return len(choices) - 1
def _read_env(home):
env_path = home / ".env"
data = {}
@ -50,19 +57,18 @@ def test_setup_keep_current_custom_from_config_does_not_fall_through(tmp_path, m
}
save_config(config)
calls = {"count": 0}
def fake_prompt_choice(question, choices, default=0):
calls["count"] += 1
if calls["count"] == 1:
if question == "Select your inference provider:":
assert choices[-1] == "Keep current (Custom: https://example.invalid/v1)"
return len(choices) - 1
tts_idx = _maybe_keep_current_tts(question, choices)
if tts_idx is not None:
return tts_idx
raise AssertionError("Model menu should not appear for keep-current custom")
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
monkeypatch.setattr("hermes_cli.setup._setup_tts_provider", lambda config: None)
monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
@ -73,7 +79,6 @@ def test_setup_keep_current_custom_from_config_does_not_fall_through(tmp_path, m
assert reloaded["model"]["provider"] == "custom"
assert reloaded["model"]["default"] == "custom/model"
assert reloaded["model"]["base_url"] == "https://example.invalid/v1"
assert calls["count"] == 1
def test_setup_custom_endpoint_saves_working_v1_base_url(tmp_path, monkeypatch):
@ -87,8 +92,9 @@ def test_setup_custom_endpoint_saves_working_v1_base_url(tmp_path, monkeypatch):
return 3 # Custom endpoint
if question == "Configure vision:":
return len(choices) - 1 # Skip
if question == "Select TTS provider:":
return len(choices) - 1 # Keep current
tts_idx = _maybe_keep_current_tts(question, choices)
if tts_idx is not None:
return tts_idx
raise AssertionError(f"Unexpected prompt_choice call: {question}")
def fake_prompt(message, current=None, **kwargs):
@ -103,7 +109,6 @@ def test_setup_custom_endpoint_saves_working_v1_base_url(tmp_path, monkeypatch):
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
monkeypatch.setattr("hermes_cli.setup.prompt", fake_prompt)
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
monkeypatch.setattr("hermes_cli.setup._setup_tts_provider", lambda config: None)
monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
monkeypatch.setattr("agent.auxiliary_client.get_available_vision_backends", lambda: [])
@ -144,25 +149,23 @@ def test_setup_keep_current_config_provider_uses_provider_specific_model_menu(tm
save_config(config)
captured = {"provider_choices": None, "model_choices": None}
calls = {"count": 0}
def fake_prompt_choice(question, choices, default=0):
calls["count"] += 1
if calls["count"] == 1:
if question == "Select your inference provider:":
captured["provider_choices"] = list(choices)
assert choices[-1] == "Keep current (Anthropic)"
return len(choices) - 1
if calls["count"] == 2:
if question == "Configure vision:":
assert question == "Configure vision:"
assert choices[-1] == "Skip for now"
return len(choices) - 1
if calls["count"] == 3:
if question == "Select default model:":
captured["model_choices"] = list(choices)
return len(choices) - 1 # keep current model
if calls["count"] == 4:
assert question == "Select TTS provider:"
return len(choices) - 1 # Keep current
raise AssertionError("Unexpected extra prompt_choice call")
tts_idx = _maybe_keep_current_tts(question, choices)
if tts_idx is not None:
return tts_idx
raise AssertionError(f"Unexpected prompt_choice call: {question}")
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
@ -179,7 +182,6 @@ def test_setup_keep_current_config_provider_uses_provider_specific_model_menu(tm
assert captured["model_choices"] is not None
assert captured["model_choices"][0] == "claude-opus-4-6"
assert "anthropic/claude-opus-4.6 (recommended)" not in captured["model_choices"]
assert calls["count"] == 4 # provider, vision, model, TTS
def test_setup_keep_current_anthropic_can_configure_openai_vision_default(tmp_path, monkeypatch):
@ -193,15 +195,24 @@ def test_setup_keep_current_anthropic_can_configure_openai_vision_default(tmp_pa
}
save_config(config)
picks = iter([
10, # keep current provider (shifted +1 by kilocode insertion)
1, # configure vision with OpenAI
5, # use default gpt-4o-mini vision model
4, # keep current Anthropic model
4, # TTS: Keep current
])
def fake_prompt_choice(question, choices, default=0):
if question == "Select your inference provider:":
assert choices[-1] == "Keep current (Anthropic)"
return len(choices) - 1
if question == "Configure vision:":
return 1
if question == "Select vision model:":
assert choices[-1] == "Use default (gpt-4o-mini)"
return len(choices) - 1
if question == "Select default model:":
assert choices[-1] == "Keep current (claude-opus-4-6)"
return len(choices) - 1
tts_idx = _maybe_keep_current_tts(question, choices)
if tts_idx is not None:
return tts_idx
raise AssertionError(f"Unexpected prompt_choice call: {question}")
monkeypatch.setattr("hermes_cli.setup.prompt_choice", lambda *args, **kwargs: next(picks))
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
monkeypatch.setattr(
"hermes_cli.setup.prompt",
lambda message, *args, **kwargs: "sk-openai" if "OpenAI API key" in message else "",
@ -237,8 +248,17 @@ def test_setup_switch_custom_to_codex_clears_custom_endpoint_and_updates_config(
}
save_config(config)
picks = iter([1, 0, 4]) # provider, model; 4 = TTS Keep current
monkeypatch.setattr("hermes_cli.setup.prompt_choice", lambda *args, **kwargs: next(picks))
def fake_prompt_choice(question, choices, default=0):
if question == "Select your inference provider:":
return 1
if question == "Select default model:":
return 0
tts_idx = _maybe_keep_current_tts(question, choices)
if tts_idx is not None:
return tts_idx
raise AssertionError(f"Unexpected prompt_choice call: {question}")
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)

View file

@ -0,0 +1,14 @@
from types import SimpleNamespace
from hermes_cli.status import show_status
def test_show_status_includes_tavily_key(monkeypatch, capsys, tmp_path):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
monkeypatch.setenv("TAVILY_API_KEY", "tvly-1234567890abcdef")
show_status(SimpleNamespace(all=False, deep=False))
output = capsys.readouterr().out
assert "Tavily" in output
assert "tvly...cdef" in output

View file

@ -4,6 +4,7 @@ from types import SimpleNamespace
import pytest
from hermes_cli import config as hermes_config
from hermes_cli import main as hermes_main
@ -235,3 +236,82 @@ def test_stash_local_changes_if_needed_raises_when_stash_ref_missing(monkeypatch
with pytest.raises(CalledProcessError):
hermes_main._stash_local_changes_if_needed(["git"], Path(tmp_path))
# ---------------------------------------------------------------------------
# Update uses .[all] with fallback to .
# ---------------------------------------------------------------------------
def _setup_update_mocks(monkeypatch, tmp_path):
"""Common setup for cmd_update tests."""
(tmp_path / ".git").mkdir()
monkeypatch.setattr(hermes_main, "PROJECT_ROOT", tmp_path)
monkeypatch.setattr(hermes_main, "_stash_local_changes_if_needed", lambda *a, **kw: None)
monkeypatch.setattr(hermes_main, "_restore_stashed_changes", lambda *a, **kw: True)
monkeypatch.setattr(hermes_config, "get_missing_env_vars", lambda required_only=True: [])
monkeypatch.setattr(hermes_config, "get_missing_config_fields", lambda: [])
monkeypatch.setattr(hermes_config, "check_config_version", lambda: (5, 5))
monkeypatch.setattr(hermes_config, "migrate_config", lambda **kw: {"env_added": [], "config_added": []})
def test_cmd_update_tries_extras_first_then_falls_back(monkeypatch, tmp_path):
"""When .[all] fails, update should fall back to . instead of aborting."""
_setup_update_mocks(monkeypatch, tmp_path)
monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/uv" if name == "uv" else None)
recorded = []
def fake_run(cmd, **kwargs):
recorded.append(cmd)
if cmd == ["git", "fetch", "origin"]:
return SimpleNamespace(stdout="", stderr="", returncode=0)
if cmd == ["git", "rev-parse", "--abbrev-ref", "HEAD"]:
return SimpleNamespace(stdout="main\n", stderr="", returncode=0)
if cmd == ["git", "rev-list", "HEAD..origin/main", "--count"]:
return SimpleNamespace(stdout="1\n", stderr="", returncode=0)
if cmd == ["git", "pull", "origin", "main"]:
return SimpleNamespace(stdout="Updating\n", stderr="", returncode=0)
# .[all] fails
if ".[all]" in cmd:
raise CalledProcessError(returncode=1, cmd=cmd)
# bare . succeeds
if cmd == ["/usr/bin/uv", "pip", "install", "-e", ".", "--quiet"]:
return SimpleNamespace(returncode=0)
return SimpleNamespace(returncode=0)
monkeypatch.setattr(hermes_main.subprocess, "run", fake_run)
hermes_main.cmd_update(SimpleNamespace())
install_cmds = [c for c in recorded if "pip" in c and "install" in c]
assert len(install_cmds) == 2
assert ".[all]" in install_cmds[0]
assert "." in install_cmds[1] and ".[all]" not in install_cmds[1]
def test_cmd_update_succeeds_with_extras(monkeypatch, tmp_path):
"""When .[all] succeeds, no fallback should be attempted."""
_setup_update_mocks(monkeypatch, tmp_path)
monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/uv" if name == "uv" else None)
recorded = []
def fake_run(cmd, **kwargs):
recorded.append(cmd)
if cmd == ["git", "fetch", "origin"]:
return SimpleNamespace(stdout="", stderr="", returncode=0)
if cmd == ["git", "rev-parse", "--abbrev-ref", "HEAD"]:
return SimpleNamespace(stdout="main\n", stderr="", returncode=0)
if cmd == ["git", "rev-list", "HEAD..origin/main", "--count"]:
return SimpleNamespace(stdout="1\n", stderr="", returncode=0)
if cmd == ["git", "pull", "origin", "main"]:
return SimpleNamespace(stdout="Updating\n", stderr="", returncode=0)
return SimpleNamespace(returncode=0)
monkeypatch.setattr(hermes_main.subprocess, "run", fake_run)
hermes_main.cmd_update(SimpleNamespace())
install_cmds = [c for c in recorded if "pip" in c and "install" in c]
assert len(install_cmds) == 1
assert ".[all]" in install_cmds[0]

View file

@ -63,11 +63,13 @@ class TestFromEnv:
class TestFromGlobalConfig:
def test_missing_config_falls_back_to_env(self, tmp_path):
config = HonchoClientConfig.from_global_config(
config_path=tmp_path / "nonexistent.json"
)
with patch.dict(os.environ, {}, clear=True):
config = HonchoClientConfig.from_global_config(
config_path=tmp_path / "nonexistent.json"
)
# Should fall back to from_env
assert config.enabled is True or config.api_key is None # depends on env
assert config.enabled is False
assert config.api_key is None
def test_reads_full_config(self, tmp_path):
config_file = tmp_path / "config.json"

View file

@ -3,7 +3,7 @@
Comprehensive Test Suite for Web Tools Module
This script tests all web tools functionality to ensure they work correctly.
Run this after any updates to the web_tools.py module or Firecrawl library.
Run this after any updates to the web_tools.py module or backend libraries.
Usage:
python test_web_tools.py # Run all tests
@ -11,7 +11,7 @@ Usage:
python test_web_tools.py --verbose # Show detailed output
Requirements:
- FIRECRAWL_API_KEY environment variable must be set
- PARALLEL_API_KEY or FIRECRAWL_API_KEY environment variable must be set
- An auxiliary LLM provider (OPENROUTER_API_KEY or Nous Portal auth) (optional, for LLM tests)
"""
@ -28,12 +28,14 @@ from typing import List
# Import the web tools to test (updated path after moving tools/)
from tools.web_tools import (
web_search_tool,
web_extract_tool,
web_search_tool,
web_extract_tool,
web_crawl_tool,
check_firecrawl_api_key,
check_web_api_key,
check_auxiliary_model,
get_debug_session_info
get_debug_session_info,
_get_backend,
)
@ -121,12 +123,13 @@ class WebToolsTester:
"""Test environment setup and API keys"""
print_section("Environment Check")
# Check Firecrawl API key
if not check_firecrawl_api_key():
self.log_result("Firecrawl API Key", "failed", "FIRECRAWL_API_KEY not set")
# Check web backend API key (Parallel or Firecrawl)
if not check_web_api_key():
self.log_result("Web Backend API Key", "failed", "PARALLEL_API_KEY or FIRECRAWL_API_KEY not set")
return False
else:
self.log_result("Firecrawl API Key", "passed", "Found")
backend = _get_backend()
self.log_result("Web Backend API Key", "passed", f"Using {backend} backend")
# Check auxiliary LLM provider (optional)
if not check_auxiliary_model():
@ -578,7 +581,9 @@ class WebToolsTester:
},
"results": self.test_results,
"environment": {
"web_backend": _get_backend() if check_web_api_key() else None,
"firecrawl_api_key": check_firecrawl_api_key(),
"parallel_api_key": bool(os.getenv("PARALLEL_API_KEY")),
"auxiliary_model": check_auxiliary_model(),
"debug_mode": get_debug_session_info()["enabled"]
}

View file

@ -24,6 +24,7 @@ def main() -> int:
parent._interrupt_requested = False
parent._interrupt_message = None
parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True
parent.model = "test/model"
parent.base_url = "http://localhost:1"

View file

@ -0,0 +1,263 @@
"""Unit tests for AIAgent pre/post-LLM-call guardrails.
Covers three static methods on AIAgent (inspired by PR #1321 — @alireza78a):
- _sanitize_api_messages() Phase 1: orphaned tool pair repair
- _cap_delegate_task_calls() Phase 2a: subagent concurrency limit
- _deduplicate_tool_calls() Phase 2b: identical call deduplication
"""
import types
from run_agent import AIAgent
from tools.delegate_tool import MAX_CONCURRENT_CHILDREN
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def make_tc(name: str, arguments: str = "{}") -> types.SimpleNamespace:
"""Create a minimal tool_call SimpleNamespace mirroring the OpenAI SDK object."""
tc = types.SimpleNamespace()
tc.function = types.SimpleNamespace(name=name, arguments=arguments)
return tc
def tool_result(call_id: str, content: str = "ok") -> dict:
return {"role": "tool", "tool_call_id": call_id, "content": content}
def assistant_dict_call(call_id: str, name: str = "terminal") -> dict:
"""Dict-style tool_call (as stored in message history)."""
return {"id": call_id, "function": {"name": name, "arguments": "{}"}}
# ---------------------------------------------------------------------------
# Phase 1 — _sanitize_api_messages
# ---------------------------------------------------------------------------
class TestSanitizeApiMessages:
def test_orphaned_result_removed(self):
msgs = [
{"role": "assistant", "tool_calls": [assistant_dict_call("c1")]},
tool_result("c1"),
tool_result("c_ORPHAN"),
]
out = AIAgent._sanitize_api_messages(msgs)
assert len(out) == 2
assert all(m.get("tool_call_id") != "c_ORPHAN" for m in out)
def test_orphaned_call_gets_stub_result(self):
msgs = [
{"role": "assistant", "tool_calls": [assistant_dict_call("c2")]},
]
out = AIAgent._sanitize_api_messages(msgs)
assert len(out) == 2
stub = out[1]
assert stub["role"] == "tool"
assert stub["tool_call_id"] == "c2"
assert stub["content"]
def test_clean_messages_pass_through(self):
msgs = [
{"role": "user", "content": "hello"},
{"role": "assistant", "tool_calls": [assistant_dict_call("c3")]},
tool_result("c3"),
{"role": "assistant", "content": "done"},
]
out = AIAgent._sanitize_api_messages(msgs)
assert out == msgs
def test_mixed_orphaned_result_and_orphaned_call(self):
msgs = [
{"role": "assistant", "tool_calls": [
assistant_dict_call("c4"),
assistant_dict_call("c5"),
]},
tool_result("c4"),
tool_result("c_DANGLING"),
]
out = AIAgent._sanitize_api_messages(msgs)
ids = [m.get("tool_call_id") for m in out if m.get("role") == "tool"]
assert "c_DANGLING" not in ids
assert "c4" in ids
assert "c5" in ids
def test_empty_list_is_safe(self):
assert AIAgent._sanitize_api_messages([]) == []
def test_no_tool_messages(self):
msgs = [
{"role": "user", "content": "hi"},
{"role": "assistant", "content": "hello"},
]
out = AIAgent._sanitize_api_messages(msgs)
assert out == msgs
def test_sdk_object_tool_calls(self):
tc_obj = types.SimpleNamespace(id="c6", function=types.SimpleNamespace(
name="terminal", arguments="{}"
))
msgs = [
{"role": "assistant", "tool_calls": [tc_obj]},
]
out = AIAgent._sanitize_api_messages(msgs)
assert len(out) == 2
assert out[1]["tool_call_id"] == "c6"
# ---------------------------------------------------------------------------
# Phase 2a — _cap_delegate_task_calls
# ---------------------------------------------------------------------------
class TestCapDelegateTaskCalls:
def test_excess_delegates_truncated(self):
tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 2)]
out = AIAgent._cap_delegate_task_calls(tcs)
delegate_count = sum(1 for tc in out if tc.function.name == "delegate_task")
assert delegate_count == MAX_CONCURRENT_CHILDREN
def test_non_delegate_calls_preserved(self):
tcs = (
[make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 1)]
+ [make_tc("terminal"), make_tc("web_search")]
)
out = AIAgent._cap_delegate_task_calls(tcs)
names = [tc.function.name for tc in out]
assert "terminal" in names
assert "web_search" in names
def test_at_limit_passes_through(self):
tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN)]
out = AIAgent._cap_delegate_task_calls(tcs)
assert out is tcs
def test_below_limit_passes_through(self):
tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN - 1)]
out = AIAgent._cap_delegate_task_calls(tcs)
assert out is tcs
def test_no_delegate_calls_unchanged(self):
tcs = [make_tc("terminal"), make_tc("web_search")]
out = AIAgent._cap_delegate_task_calls(tcs)
assert out is tcs
def test_empty_list_safe(self):
assert AIAgent._cap_delegate_task_calls([]) == []
def test_original_list_not_mutated(self):
tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 2)]
original_len = len(tcs)
AIAgent._cap_delegate_task_calls(tcs)
assert len(tcs) == original_len
def test_interleaved_order_preserved(self):
delegates = [make_tc("delegate_task", f'{{"task":"{i}"}}')
for i in range(MAX_CONCURRENT_CHILDREN + 1)]
t1 = make_tc("terminal", '{"cmd":"ls"}')
w1 = make_tc("web_search", '{"q":"x"}')
tcs = [delegates[0], t1, delegates[1], w1] + delegates[2:]
out = AIAgent._cap_delegate_task_calls(tcs)
expected = [delegates[0], t1, delegates[1], w1] + delegates[2:MAX_CONCURRENT_CHILDREN]
assert len(out) == len(expected)
for i, (actual, exp) in enumerate(zip(out, expected)):
assert actual is exp, f"mismatch at index {i}"
# ---------------------------------------------------------------------------
# Phase 2b — _deduplicate_tool_calls
# ---------------------------------------------------------------------------
class TestDeduplicateToolCalls:
def test_duplicate_pair_deduplicated(self):
tcs = [
make_tc("web_search", '{"query":"foo"}'),
make_tc("web_search", '{"query":"foo"}'),
]
out = AIAgent._deduplicate_tool_calls(tcs)
assert len(out) == 1
def test_multiple_duplicates(self):
tcs = [
make_tc("web_search", '{"q":"a"}'),
make_tc("web_search", '{"q":"a"}'),
make_tc("terminal", '{"cmd":"ls"}'),
make_tc("terminal", '{"cmd":"ls"}'),
make_tc("terminal", '{"cmd":"pwd"}'),
]
out = AIAgent._deduplicate_tool_calls(tcs)
assert len(out) == 3
def test_same_tool_different_args_kept(self):
tcs = [
make_tc("terminal", '{"cmd":"ls"}'),
make_tc("terminal", '{"cmd":"pwd"}'),
]
out = AIAgent._deduplicate_tool_calls(tcs)
assert out is tcs
def test_different_tools_same_args_kept(self):
tcs = [
make_tc("tool_a", '{"x":1}'),
make_tc("tool_b", '{"x":1}'),
]
out = AIAgent._deduplicate_tool_calls(tcs)
assert out is tcs
def test_clean_list_unchanged(self):
tcs = [
make_tc("web_search", '{"q":"x"}'),
make_tc("terminal", '{"cmd":"ls"}'),
]
out = AIAgent._deduplicate_tool_calls(tcs)
assert out is tcs
def test_empty_list_safe(self):
assert AIAgent._deduplicate_tool_calls([]) == []
def test_first_occurrence_kept(self):
tc1 = make_tc("terminal", '{"cmd":"ls"}')
tc2 = make_tc("terminal", '{"cmd":"ls"}')
out = AIAgent._deduplicate_tool_calls([tc1, tc2])
assert len(out) == 1
assert out[0] is tc1
def test_original_list_not_mutated(self):
tcs = [
make_tc("web_search", '{"q":"dup"}'),
make_tc("web_search", '{"q":"dup"}'),
]
original_len = len(tcs)
AIAgent._deduplicate_tool_calls(tcs)
assert len(tcs) == original_len
# ---------------------------------------------------------------------------
# _get_tool_call_id_static
# ---------------------------------------------------------------------------
class TestGetToolCallIdStatic:
def test_dict_with_valid_id(self):
assert AIAgent._get_tool_call_id_static({"id": "call_123"}) == "call_123"
def test_dict_with_none_id(self):
assert AIAgent._get_tool_call_id_static({"id": None}) == ""
def test_dict_without_id_key(self):
assert AIAgent._get_tool_call_id_static({"function": {}}) == ""
def test_object_with_valid_id(self):
tc = types.SimpleNamespace(id="call_456")
assert AIAgent._get_tool_call_id_static(tc) == "call_456"
def test_object_with_none_id(self):
tc = types.SimpleNamespace(id=None)
assert AIAgent._get_tool_call_id_static(tc) == ""
def test_object_without_id_attr(self):
tc = types.SimpleNamespace()
assert AIAgent._get_tool_call_id_static(tc) == ""

View file

@ -98,11 +98,14 @@ class TestProviderRegistry:
# =============================================================================
PROVIDER_ENV_VARS = (
"OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY",
"OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN",
"CLAUDE_CODE_OAUTH_TOKEN",
"GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY",
"KIMI_API_KEY", "KIMI_BASE_URL", "MINIMAX_API_KEY", "MINIMAX_CN_API_KEY",
"AI_GATEWAY_API_KEY", "AI_GATEWAY_BASE_URL",
"KILOCODE_API_KEY", "KILOCODE_BASE_URL",
"DASHSCOPE_API_KEY", "OPENCODE_ZEN_API_KEY", "OPENCODE_GO_API_KEY",
"NOUS_API_KEY",
"OPENAI_BASE_URL",
)
@ -111,6 +114,7 @@ PROVIDER_ENV_VARS = (
def _clear_provider_env(monkeypatch):
for key in PROVIDER_ENV_VARS:
monkeypatch.delenv(key, raising=False)
monkeypatch.setattr("hermes_cli.auth._load_auth_store", lambda: {})
class TestResolveProvider:

View file

@ -43,6 +43,7 @@ class TestCLISubagentInterrupt(unittest.TestCase):
parent._interrupt_requested = False
parent._interrupt_message = None
parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True
parent.model = "test/model"
parent.base_url = "http://localhost:1"
@ -112,21 +113,21 @@ class TestCLISubagentInterrupt(unittest.TestCase):
mock_instance._interrupt_requested = False
mock_instance._interrupt_message = None
mock_instance._active_children = []
mock_instance._active_children_lock = threading.Lock()
mock_instance.quiet_mode = True
mock_instance.run_conversation = mock_child_run_conversation
mock_instance.interrupt = lambda msg=None: setattr(mock_instance, '_interrupt_requested', True) or setattr(mock_instance, '_interrupt_message', msg)
mock_instance.tools = []
MockAgent.return_value = mock_instance
# Register child manually (normally done by _build_child_agent)
parent._active_children.append(mock_instance)
result = _run_single_child(
task_index=0,
goal="Do something slow",
context=None,
toolsets=["terminal"],
model=None,
max_iterations=50,
child=mock_instance,
parent_agent=parent,
task_count=1,
)
delegate_result[0] = result
except Exception as e:

View file

@ -16,6 +16,10 @@ def _make_cli(model: str = "anthropic/claude-sonnet-4-20250514"):
def _attach_agent(
cli_obj,
*,
input_tokens: int | None = None,
output_tokens: int | None = None,
cache_read_tokens: int = 0,
cache_write_tokens: int = 0,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
@ -26,6 +30,12 @@ def _attach_agent(
):
cli_obj.agent = SimpleNamespace(
model=cli_obj.model,
provider="anthropic" if cli_obj.model.startswith("anthropic/") else None,
base_url="",
session_input_tokens=input_tokens if input_tokens is not None else prompt_tokens,
session_output_tokens=output_tokens if output_tokens is not None else completion_tokens,
session_cache_read_tokens=cache_read_tokens,
session_cache_write_tokens=cache_write_tokens,
session_prompt_tokens=prompt_tokens,
session_completion_tokens=completion_tokens,
session_total_tokens=total_tokens,
@ -68,20 +78,19 @@ class TestCLIStatusBar:
assert "$0.06" not in text # cost hidden by default
assert "15m" in text
def test_build_status_bar_text_shows_cost_when_enabled(self):
def test_build_status_bar_text_no_cost_in_status_bar(self):
cli_obj = _attach_agent(
_make_cli(),
prompt_tokens=10000,
completion_tokens=2400,
total_tokens=12400,
completion_tokens=5000,
total_tokens=15000,
api_calls=7,
context_tokens=12400,
context_tokens=50000,
context_length=200_000,
)
cli_obj.show_cost = True
text = cli_obj._build_status_bar_text(width=120)
assert "$" in text # cost is shown when enabled
assert "$" not in text # cost is never shown in status bar
def test_build_status_bar_text_collapses_for_narrow_terminal(self):
cli_obj = _attach_agent(
@ -128,8 +137,8 @@ class TestCLIUsageReport:
output = capsys.readouterr().out
assert "Model:" in output
assert "Input cost:" in output
assert "Output cost:" in output
assert "Cost status:" in output
assert "Cost source:" in output
assert "Total cost:" in output
assert "$" in output
assert "0.064" in output

View file

@ -657,7 +657,7 @@ class TestSchemaInit:
def test_schema_version(self, db):
cursor = db._conn.execute("SELECT version FROM schema_version")
version = cursor.fetchone()[0]
assert version == 4
assert version == 5
def test_title_column_exists(self, db):
"""Verify the title column was created in the sessions table."""
@ -713,12 +713,12 @@ class TestSchemaInit:
conn.commit()
conn.close()
# Open with SessionDB — should migrate to v4
# Open with SessionDB — should migrate to v5
migrated_db = SessionDB(db_path=db_path)
# Verify migration
cursor = migrated_db._conn.execute("SELECT version FROM schema_version")
assert cursor.fetchone()[0] == 4
assert cursor.fetchone()[0] == 5
# Verify title column exists and is NULL for existing sessions
session = migrated_db.get_session("existing")

View file

@ -123,28 +123,16 @@ def populated_db(db):
# =========================================================================
class TestPricing:
def test_exact_match(self):
pricing = _get_pricing("gpt-4o")
assert pricing["input"] == 2.50
assert pricing["output"] == 10.00
def test_provider_prefix_stripped(self):
pricing = _get_pricing("anthropic/claude-sonnet-4-20250514")
assert pricing["input"] == 3.00
assert pricing["output"] == 15.00
def test_prefix_match(self):
pricing = _get_pricing("claude-3-5-sonnet-20241022")
assert pricing["input"] == 3.00
def test_keyword_heuristic_opus(self):
def test_unknown_models_do_not_use_heuristics(self):
pricing = _get_pricing("some-new-opus-model")
assert pricing["input"] == 15.00
assert pricing["output"] == 75.00
def test_keyword_heuristic_haiku(self):
assert pricing == _DEFAULT_PRICING
pricing = _get_pricing("anthropic/claude-haiku-future")
assert pricing["input"] == 0.80
assert pricing == _DEFAULT_PRICING
def test_unknown_model_returns_zero_cost(self):
"""Unknown/custom models should NOT have fabricated costs."""
@ -168,40 +156,12 @@ class TestPricing:
pricing = _get_pricing("")
assert pricing == _DEFAULT_PRICING
def test_deepseek_heuristic(self):
pricing = _get_pricing("deepseek-v3")
assert pricing["input"] == 0.14
def test_gemini_heuristic(self):
pricing = _get_pricing("gemini-3.0-ultra")
assert pricing["input"] == 0.15
def test_dated_model_gpt4o_mini(self):
"""gpt-4o-mini-2024-07-18 should match gpt-4o-mini, NOT gpt-4o."""
pricing = _get_pricing("gpt-4o-mini-2024-07-18")
assert pricing["input"] == 0.15 # gpt-4o-mini price, not gpt-4o's 2.50
def test_dated_model_o3_mini(self):
"""o3-mini-2025-01-31 should match o3-mini, NOT o3."""
pricing = _get_pricing("o3-mini-2025-01-31")
assert pricing["input"] == 1.10 # o3-mini price, not o3's 10.00
def test_dated_model_gpt41_mini(self):
"""gpt-4.1-mini-2025-04-14 should match gpt-4.1-mini, NOT gpt-4.1."""
pricing = _get_pricing("gpt-4.1-mini-2025-04-14")
assert pricing["input"] == 0.40 # gpt-4.1-mini, not gpt-4.1's 2.00
def test_dated_model_gpt41_nano(self):
"""gpt-4.1-nano-2025-04-14 should match gpt-4.1-nano, NOT gpt-4.1."""
pricing = _get_pricing("gpt-4.1-nano-2025-04-14")
assert pricing["input"] == 0.10 # gpt-4.1-nano, not gpt-4.1's 2.00
class TestHasKnownPricing:
def test_known_commercial_model(self):
assert _has_known_pricing("gpt-4o") is True
assert _has_known_pricing("gpt-4o", provider="openai") is True
assert _has_known_pricing("anthropic/claude-sonnet-4-20250514") is True
assert _has_known_pricing("deepseek-chat") is True
assert _has_known_pricing("gpt-4.1", provider="openai") is True
def test_unknown_custom_model(self):
assert _has_known_pricing("FP16_Hermes_4.5") is False
@ -210,26 +170,39 @@ class TestHasKnownPricing:
assert _has_known_pricing("") is False
assert _has_known_pricing(None) is False
def test_heuristic_matched_models(self):
"""Models matched by keyword heuristics should be considered known."""
assert _has_known_pricing("some-opus-model") is True
assert _has_known_pricing("future-sonnet-v2") is True
def test_heuristic_matched_models_are_not_considered_known(self):
assert _has_known_pricing("some-opus-model") is False
assert _has_known_pricing("future-sonnet-v2") is False
class TestEstimateCost:
def test_basic_cost(self):
# gpt-4o: 2.50/M input, 10.00/M output
cost = _estimate_cost("gpt-4o", 1_000_000, 1_000_000)
assert cost == pytest.approx(12.50, abs=0.01)
cost, status = _estimate_cost(
"anthropic/claude-sonnet-4-20250514",
1_000_000,
1_000_000,
provider="anthropic",
)
assert status == "estimated"
assert cost == pytest.approx(18.0, abs=0.01)
def test_zero_tokens(self):
cost = _estimate_cost("gpt-4o", 0, 0)
cost, status = _estimate_cost("gpt-4o", 0, 0, provider="openai")
assert status == "estimated"
assert cost == 0.0
def test_small_usage(self):
cost = _estimate_cost("gpt-4o", 1000, 500)
# 1000 * 2.50/1M + 500 * 10.00/1M = 0.0025 + 0.005 = 0.0075
assert cost == pytest.approx(0.0075, abs=0.0001)
def test_cache_aware_usage(self):
cost, status = _estimate_cost(
"anthropic/claude-sonnet-4-20250514",
1000,
500,
cache_read_tokens=2000,
cache_write_tokens=400,
provider="anthropic",
)
assert status == "estimated"
expected = (1000 * 3.0 + 500 * 15.0 + 2000 * 0.30 + 400 * 3.75) / 1_000_000
assert cost == pytest.approx(expected, abs=0.0001)
# =========================================================================
@ -660,8 +633,13 @@ class TestEdgeCases:
def test_mixed_commercial_and_custom_models(self, db):
"""Mix of commercial and custom models: only commercial ones get costs."""
db.create_session(session_id="s1", source="cli", model="gpt-4o")
db.update_token_counts("s1", input_tokens=10000, output_tokens=5000)
db.create_session(session_id="s1", source="cli", model="anthropic/claude-sonnet-4-20250514")
db.update_token_counts(
"s1",
input_tokens=10000,
output_tokens=5000,
billing_provider="anthropic",
)
db.create_session(session_id="s2", source="cli", model="my-local-llama")
db.update_token_counts("s2", input_tokens=10000, output_tokens=5000)
db._conn.commit()
@ -672,13 +650,13 @@ class TestEdgeCases:
# Cost should only come from gpt-4o, not from the custom model
overview = report["overview"]
assert overview["estimated_cost"] > 0
assert "gpt-4o" in overview["models_with_pricing"] # list now, not set
assert "claude-sonnet-4-20250514" in overview["models_with_pricing"] # list now, not set
assert "my-local-llama" in overview["models_without_pricing"]
# Verify individual model entries
gpt = next(m for m in report["models"] if m["model"] == "gpt-4o")
assert gpt["has_pricing"] is True
assert gpt["cost"] > 0
claude = next(m for m in report["models"] if m["model"] == "claude-sonnet-4-20250514")
assert claude["has_pricing"] is True
assert claude["cost"] > 0
llama = next(m for m in report["models"] if m["model"] == "my-local-llama")
assert llama["has_pricing"] is False

View file

@ -57,6 +57,7 @@ def main() -> int:
parent._interrupt_requested = False
parent._interrupt_message = None
parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True
parent.model = "test/model"
parent.base_url = "http://localhost:1"

View file

@ -30,12 +30,14 @@ class TestInterruptPropagationToChild(unittest.TestCase):
parent._interrupt_requested = False
parent._interrupt_message = None
parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True
child = AIAgent.__new__(AIAgent)
child._interrupt_requested = False
child._interrupt_message = None
child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True
parent._active_children.append(child)
@ -60,6 +62,7 @@ class TestInterruptPropagationToChild(unittest.TestCase):
child._interrupt_message = "msg"
child.quiet_mode = True
child._active_children = []
child._active_children_lock = threading.Lock()
# Global is set
set_interrupt(True)
@ -78,6 +81,7 @@ class TestInterruptPropagationToChild(unittest.TestCase):
child._interrupt_requested = False
child._interrupt_message = None
child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True
child.api_mode = "chat_completions"
child.log_prefix = ""
@ -119,12 +123,14 @@ class TestInterruptPropagationToChild(unittest.TestCase):
parent._interrupt_requested = False
parent._interrupt_message = None
parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True
child = AIAgent.__new__(AIAgent)
child._interrupt_requested = False
child._interrupt_message = None
child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True
# Register child (simulating what _run_single_child does)

View file

@ -47,6 +47,28 @@ class TestCLIQuickCommands:
args = cli.console.print.call_args[0][0]
assert "no output" in args.lower()
def test_alias_command_routes_to_target(self):
"""Alias quick commands rewrite to the target command."""
cli = self._make_cli({"shortcut": {"type": "alias", "target": "/help"}})
with patch.object(cli, "process_command", wraps=cli.process_command) as spy:
cli.process_command("/shortcut")
# Should recursively call process_command with /help
spy.assert_any_call("/help")
def test_alias_command_passes_args(self):
"""Alias quick commands forward user arguments to the target."""
cli = self._make_cli({"sc": {"type": "alias", "target": "/context"}})
with patch.object(cli, "process_command", wraps=cli.process_command) as spy:
cli.process_command("/sc some args")
spy.assert_any_call("/context some args")
def test_alias_no_target_shows_error(self):
cli = self._make_cli({"broken": {"type": "alias", "target": ""}})
cli.process_command("/broken")
cli.console.print.assert_called_once()
args = cli.console.print.call_args[0][0]
assert "no target defined" in args.lower()
def test_unsupported_type_shows_error(self):
cli = self._make_cli({"bad": {"type": "prompt", "command": "echo hi"}})
cli.process_command("/bad")

View file

@ -55,6 +55,7 @@ class TestRealSubagentInterrupt(unittest.TestCase):
parent._interrupt_requested = False
parent._interrupt_message = None
parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True
parent.model = "test/model"
parent.base_url = "http://localhost:1"
@ -103,19 +104,28 @@ class TestRealSubagentInterrupt(unittest.TestCase):
return original_run(self_agent, *args, **kwargs)
with patch.object(AIAgent, 'run_conversation', patched_run):
# Build a real child agent (AIAgent is NOT patched here,
# only run_conversation and _build_system_prompt are)
child = AIAgent(
base_url="http://localhost:1",
api_key="test-key",
model="test/model",
provider="test",
api_mode="chat_completions",
max_iterations=5,
enabled_toolsets=["terminal"],
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
platform="cli",
)
child._delegate_depth = 1
parent._active_children.append(child)
result = _run_single_child(
task_index=0,
goal="Test task",
context=None,
toolsets=["terminal"],
model="test/model",
max_iterations=5,
child=child,
parent_agent=parent,
task_count=1,
override_provider="test",
override_base_url="http://localhost:1",
override_api_key="test",
override_api_mode="chat_completions",
)
result_holder[0] = result
except Exception as e:

View file

@ -12,6 +12,7 @@ Run with: python -m pytest tests/test_delegate.py -v
import json
import os
import sys
import threading
import unittest
from unittest.mock import MagicMock, patch
@ -44,6 +45,7 @@ def _make_mock_parent(depth=0):
parent._session_db = None
parent._delegate_depth = depth
parent._active_children = []
parent._active_children_lock = threading.Lock()
return parent
@ -722,7 +724,12 @@ class TestDelegationProviderIntegration(unittest.TestCase):
}
parent = _make_mock_parent(depth=0)
with patch("tools.delegate_tool._run_single_child") as mock_run:
# Patch _build_child_agent since credentials are now passed there
# (agents are built in the main thread before being handed to workers)
with patch("tools.delegate_tool._build_child_agent") as mock_build, \
patch("tools.delegate_tool._run_single_child") as mock_run:
mock_child = MagicMock()
mock_build.return_value = mock_child
mock_run.return_value = {
"task_index": 0, "status": "completed",
"summary": "Done", "api_calls": 1, "duration_seconds": 1.0
@ -731,7 +738,8 @@ class TestDelegationProviderIntegration(unittest.TestCase):
tasks = [{"goal": "Task A"}, {"goal": "Task B"}]
delegate_task(tasks=tasks, parent_agent=parent)
for call in mock_run.call_args_list:
self.assertEqual(mock_build.call_count, 2)
for call in mock_build.call_args_list:
self.assertEqual(call.kwargs.get("model"), "meta-llama/llama-4-scout")
self.assertEqual(call.kwargs.get("override_provider"), "openrouter")
self.assertEqual(call.kwargs.get("override_base_url"), "https://openrouter.ai/api/v1")

View file

@ -0,0 +1,210 @@
"""Tests for probe_mcp_server_tools() in tools.mcp_tool."""
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@pytest.fixture(autouse=True)
def _reset_mcp_state():
"""Ensure clean MCP module state before/after each test."""
import tools.mcp_tool as mcp
old_loop = mcp._mcp_loop
old_thread = mcp._mcp_thread
old_servers = dict(mcp._servers)
yield
mcp._servers.clear()
mcp._servers.update(old_servers)
mcp._mcp_loop = old_loop
mcp._mcp_thread = old_thread
class TestProbeMcpServerTools:
"""Tests for the lightweight probe_mcp_server_tools function."""
def test_returns_empty_when_mcp_not_available(self):
with patch("tools.mcp_tool._MCP_AVAILABLE", False):
from tools.mcp_tool import probe_mcp_server_tools
result = probe_mcp_server_tools()
assert result == {}
def test_returns_empty_when_no_config(self):
with patch("tools.mcp_tool._load_mcp_config", return_value={}):
from tools.mcp_tool import probe_mcp_server_tools
result = probe_mcp_server_tools()
assert result == {}
def test_returns_empty_when_all_servers_disabled(self):
config = {
"github": {"command": "npx", "enabled": False},
"slack": {"command": "npx", "enabled": "off"},
}
with patch("tools.mcp_tool._load_mcp_config", return_value=config):
from tools.mcp_tool import probe_mcp_server_tools
result = probe_mcp_server_tools()
assert result == {}
def test_returns_tools_from_successful_server(self):
"""Successfully probed server returns its tools list."""
config = {
"github": {"command": "npx", "connect_timeout": 5},
}
mock_tool_1 = SimpleNamespace(name="create_issue", description="Create a new issue")
mock_tool_2 = SimpleNamespace(name="search_repos", description="Search repositories")
mock_server = MagicMock()
mock_server._tools = [mock_tool_1, mock_tool_2]
mock_server.shutdown = AsyncMock()
async def fake_connect(name, cfg):
return mock_server
with patch("tools.mcp_tool._load_mcp_config", return_value=config), \
patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
patch("tools.mcp_tool._ensure_mcp_loop"), \
patch("tools.mcp_tool._run_on_mcp_loop") as mock_run, \
patch("tools.mcp_tool._stop_mcp_loop"):
# Simulate running the async probe
def run_coro(coro, timeout=120):
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(coro)
finally:
loop.close()
mock_run.side_effect = run_coro
from tools.mcp_tool import probe_mcp_server_tools
result = probe_mcp_server_tools()
assert "github" in result
assert len(result["github"]) == 2
assert result["github"][0] == ("create_issue", "Create a new issue")
assert result["github"][1] == ("search_repos", "Search repositories")
mock_server.shutdown.assert_awaited_once()
def test_failed_server_omitted_from_results(self):
"""Servers that fail to connect are silently skipped."""
config = {
"github": {"command": "npx", "connect_timeout": 5},
"broken": {"command": "nonexistent", "connect_timeout": 5},
}
mock_tool = SimpleNamespace(name="create_issue", description="Create")
mock_server = MagicMock()
mock_server._tools = [mock_tool]
mock_server.shutdown = AsyncMock()
async def fake_connect(name, cfg):
if name == "broken":
raise ConnectionError("Server not found")
return mock_server
with patch("tools.mcp_tool._load_mcp_config", return_value=config), \
patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
patch("tools.mcp_tool._ensure_mcp_loop"), \
patch("tools.mcp_tool._run_on_mcp_loop") as mock_run, \
patch("tools.mcp_tool._stop_mcp_loop"):
def run_coro(coro, timeout=120):
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(coro)
finally:
loop.close()
mock_run.side_effect = run_coro
from tools.mcp_tool import probe_mcp_server_tools
result = probe_mcp_server_tools()
assert "github" in result
assert "broken" not in result
def test_handles_tool_without_description(self):
"""Tools without descriptions get empty string."""
config = {"github": {"command": "npx", "connect_timeout": 5}}
mock_tool = SimpleNamespace(name="my_tool") # no description attribute
mock_server = MagicMock()
mock_server._tools = [mock_tool]
mock_server.shutdown = AsyncMock()
async def fake_connect(name, cfg):
return mock_server
with patch("tools.mcp_tool._load_mcp_config", return_value=config), \
patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
patch("tools.mcp_tool._ensure_mcp_loop"), \
patch("tools.mcp_tool._run_on_mcp_loop") as mock_run, \
patch("tools.mcp_tool._stop_mcp_loop"):
def run_coro(coro, timeout=120):
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(coro)
finally:
loop.close()
mock_run.side_effect = run_coro
from tools.mcp_tool import probe_mcp_server_tools
result = probe_mcp_server_tools()
assert result["github"][0] == ("my_tool", "")
def test_cleanup_called_even_on_failure(self):
"""_stop_mcp_loop is called even when probe fails."""
config = {"github": {"command": "npx", "connect_timeout": 5}}
with patch("tools.mcp_tool._load_mcp_config", return_value=config), \
patch("tools.mcp_tool._ensure_mcp_loop"), \
patch("tools.mcp_tool._run_on_mcp_loop", side_effect=RuntimeError("boom")), \
patch("tools.mcp_tool._stop_mcp_loop") as mock_stop:
from tools.mcp_tool import probe_mcp_server_tools
result = probe_mcp_server_tools()
assert result == {}
mock_stop.assert_called_once()
def test_skips_disabled_servers(self):
"""Disabled servers are not probed."""
config = {
"github": {"command": "npx", "connect_timeout": 5},
"disabled_one": {"command": "npx", "enabled": False},
}
mock_tool = SimpleNamespace(name="create_issue", description="Create")
mock_server = MagicMock()
mock_server._tools = [mock_tool]
mock_server.shutdown = AsyncMock()
connect_calls = []
async def fake_connect(name, cfg):
connect_calls.append(name)
return mock_server
with patch("tools.mcp_tool._load_mcp_config", return_value=config), \
patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
patch("tools.mcp_tool._ensure_mcp_loop"), \
patch("tools.mcp_tool._run_on_mcp_loop") as mock_run, \
patch("tools.mcp_tool._stop_mcp_loop"):
def run_coro(coro, timeout=120):
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(coro)
finally:
loop.close()
mock_run.side_effect = run_coro
from tools.mcp_tool import probe_mcp_server_tools
result = probe_mcp_server_tools()
assert "github" in result
assert "disabled_one" not in result
assert "disabled_one" not in connect_calls

View file

@ -2596,17 +2596,19 @@ class TestMCPSelectiveToolLoading:
async def run():
with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
patch.dict("tools.mcp_tool._servers", {}, clear=True), \
patch("tools.registry.registry", mock_registry), \
patch("toolsets.create_custom_toolset"):
return await _discover_and_register_server(
registered = await _discover_and_register_server(
"ink_existing",
{"url": "https://mcp.example.com", "tools": {"include": ["create_service"]}},
)
return registered, _existing_tool_names()
try:
registered = asyncio.run(run())
registered, existing = asyncio.run(run())
assert registered == ["mcp_ink_existing_create_service"]
assert _existing_tool_names() == ["mcp_ink_existing_create_service"]
assert existing == ["mcp_ink_existing_create_service"]
finally:
_servers.pop("ink_existing", None)

View file

@ -294,6 +294,61 @@ class TestCheckpoint:
recovered = registry.recover_from_checkpoint()
assert recovered == 0
def test_write_checkpoint_includes_watcher_metadata(self, registry, tmp_path):
with patch("tools.process_registry.CHECKPOINT_PATH", tmp_path / "procs.json"):
s = _make_session()
s.watcher_platform = "telegram"
s.watcher_chat_id = "999"
s.watcher_thread_id = "42"
s.watcher_interval = 60
registry._running[s.id] = s
registry._write_checkpoint()
data = json.loads((tmp_path / "procs.json").read_text())
assert len(data) == 1
assert data[0]["watcher_platform"] == "telegram"
assert data[0]["watcher_chat_id"] == "999"
assert data[0]["watcher_thread_id"] == "42"
assert data[0]["watcher_interval"] == 60
def test_recover_enqueues_watchers(self, registry, tmp_path):
checkpoint = tmp_path / "procs.json"
checkpoint.write_text(json.dumps([{
"session_id": "proc_live",
"command": "sleep 999",
"pid": os.getpid(), # current process — guaranteed alive
"task_id": "t1",
"session_key": "sk1",
"watcher_platform": "telegram",
"watcher_chat_id": "123",
"watcher_thread_id": "42",
"watcher_interval": 60,
}]))
with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint):
recovered = registry.recover_from_checkpoint()
assert recovered == 1
assert len(registry.pending_watchers) == 1
w = registry.pending_watchers[0]
assert w["session_id"] == "proc_live"
assert w["platform"] == "telegram"
assert w["chat_id"] == "123"
assert w["thread_id"] == "42"
assert w["check_interval"] == 60
def test_recover_skips_watcher_when_no_interval(self, registry, tmp_path):
checkpoint = tmp_path / "procs.json"
checkpoint.write_text(json.dumps([{
"session_id": "proc_live",
"command": "sleep 999",
"pid": os.getpid(),
"task_id": "t1",
"watcher_interval": 0,
}]))
with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint):
recovered = registry.recover_from_checkpoint()
assert recovered == 1
assert len(registry.pending_watchers) == 0
# =========================================================================
# Kill process

View file

@ -25,7 +25,7 @@ def _make_config():
def _install_telegram_mock(monkeypatch, bot):
parse_mode = SimpleNamespace(MARKDOWN_V2="MarkdownV2")
parse_mode = SimpleNamespace(MARKDOWN_V2="MarkdownV2", HTML="HTML")
constants_mod = SimpleNamespace(ParseMode=parse_mode)
telegram_mod = SimpleNamespace(Bot=lambda token: bot, constants=constants_mod)
monkeypatch.setitem(sys.modules, "telegram", telegram_mod)
@ -391,3 +391,97 @@ class TestSendToPlatformChunking:
assert len(sent_calls) >= 3
assert all(call == [] for call in sent_calls[:-1])
assert sent_calls[-1] == media
# ---------------------------------------------------------------------------
# HTML auto-detection in Telegram send
# ---------------------------------------------------------------------------
class TestSendTelegramHtmlDetection:
"""Verify that messages containing HTML tags are sent with parse_mode=HTML
and that plain / markdown messages use MarkdownV2."""
def _make_bot(self):
bot = MagicMock()
bot.send_message = AsyncMock(return_value=SimpleNamespace(message_id=1))
bot.send_photo = AsyncMock()
bot.send_video = AsyncMock()
bot.send_voice = AsyncMock()
bot.send_audio = AsyncMock()
bot.send_document = AsyncMock()
return bot
def test_html_message_uses_html_parse_mode(self, monkeypatch):
bot = self._make_bot()
_install_telegram_mock(monkeypatch, bot)
asyncio.run(
_send_telegram("tok", "123", "<b>Hello</b> world")
)
bot.send_message.assert_awaited_once()
kwargs = bot.send_message.await_args.kwargs
assert kwargs["parse_mode"] == "HTML"
assert kwargs["text"] == "<b>Hello</b> world"
def test_plain_text_uses_markdown_v2(self, monkeypatch):
bot = self._make_bot()
_install_telegram_mock(monkeypatch, bot)
asyncio.run(
_send_telegram("tok", "123", "Just plain text, no tags")
)
bot.send_message.assert_awaited_once()
kwargs = bot.send_message.await_args.kwargs
assert kwargs["parse_mode"] == "MarkdownV2"
def test_html_with_code_and_pre_tags(self, monkeypatch):
bot = self._make_bot()
_install_telegram_mock(monkeypatch, bot)
html = "<pre>code block</pre> and <code>inline</code>"
asyncio.run(_send_telegram("tok", "123", html))
kwargs = bot.send_message.await_args.kwargs
assert kwargs["parse_mode"] == "HTML"
def test_closing_tag_detected(self, monkeypatch):
bot = self._make_bot()
_install_telegram_mock(monkeypatch, bot)
asyncio.run(_send_telegram("tok", "123", "text </div> more"))
kwargs = bot.send_message.await_args.kwargs
assert kwargs["parse_mode"] == "HTML"
def test_angle_brackets_in_math_not_detected(self, monkeypatch):
"""Expressions like 'x < 5' or '3 > 2' should not trigger HTML mode."""
bot = self._make_bot()
_install_telegram_mock(monkeypatch, bot)
asyncio.run(_send_telegram("tok", "123", "if x < 5 then y > 2"))
kwargs = bot.send_message.await_args.kwargs
assert kwargs["parse_mode"] == "MarkdownV2"
def test_html_parse_failure_falls_back_to_plain(self, monkeypatch):
"""If Telegram rejects the HTML, fall back to plain text."""
bot = self._make_bot()
bot.send_message = AsyncMock(
side_effect=[
Exception("Bad Request: can't parse entities: unsupported html tag"),
SimpleNamespace(message_id=2), # plain fallback succeeds
]
)
_install_telegram_mock(monkeypatch, bot)
result = asyncio.run(
_send_telegram("tok", "123", "<invalid>broken html</invalid>")
)
assert result["success"] is True
assert bot.send_message.await_count == 2
second_call = bot.send_message.await_args_list[1].kwargs
assert second_call["parse_mode"] is None

View file

@ -1,8 +1,11 @@
"""Tests for Firecrawl client configuration and singleton behavior.
"""Tests for web backend client configuration and singleton behavior.
Coverage:
_get_firecrawl_client() configuration matrix, singleton caching,
constructor failure recovery, return value verification, edge cases.
_get_backend() backend selection logic with env var combinations.
_get_parallel_client() Parallel client configuration, singleton caching.
check_web_api_key() unified availability check.
"""
import os
@ -117,3 +120,212 @@ class TestFirecrawlClientConfig:
from tools.web_tools import _get_firecrawl_client
with pytest.raises(ValueError):
_get_firecrawl_client()
class TestBackendSelection:
"""Test suite for _get_backend() backend selection logic.
The backend is configured via config.yaml (web.backend), set by
``hermes tools``. Falls back to key-based detection for legacy/manual
setups.
"""
_ENV_KEYS = ("PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "FIRECRAWL_API_URL", "TAVILY_API_KEY")
def setup_method(self):
for key in self._ENV_KEYS:
os.environ.pop(key, None)
def teardown_method(self):
for key in self._ENV_KEYS:
os.environ.pop(key, None)
# ── Config-based selection (web.backend in config.yaml) ───────────
def test_config_parallel(self):
"""web.backend=parallel in config → 'parallel' regardless of keys."""
from tools.web_tools import _get_backend
with patch("tools.web_tools._load_web_config", return_value={"backend": "parallel"}):
assert _get_backend() == "parallel"
def test_config_firecrawl(self):
"""web.backend=firecrawl in config → 'firecrawl' even if Parallel key set."""
from tools.web_tools import _get_backend
with patch("tools.web_tools._load_web_config", return_value={"backend": "firecrawl"}), \
patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
assert _get_backend() == "firecrawl"
def test_config_tavily(self):
"""web.backend=tavily in config → 'tavily' regardless of other keys."""
from tools.web_tools import _get_backend
with patch("tools.web_tools._load_web_config", return_value={"backend": "tavily"}):
assert _get_backend() == "tavily"
def test_config_tavily_overrides_env_keys(self):
"""web.backend=tavily in config → 'tavily' even if Firecrawl key set."""
from tools.web_tools import _get_backend
with patch("tools.web_tools._load_web_config", return_value={"backend": "tavily"}), \
patch.dict(os.environ, {"FIRECRAWL_API_KEY": "fc-test"}):
assert _get_backend() == "tavily"
def test_config_case_insensitive(self):
"""web.backend=Parallel (mixed case) → 'parallel'."""
from tools.web_tools import _get_backend
with patch("tools.web_tools._load_web_config", return_value={"backend": "Parallel"}):
assert _get_backend() == "parallel"
def test_config_tavily_case_insensitive(self):
"""web.backend=Tavily (mixed case) → 'tavily'."""
from tools.web_tools import _get_backend
with patch("tools.web_tools._load_web_config", return_value={"backend": "Tavily"}):
assert _get_backend() == "tavily"
# ── Fallback (no web.backend in config) ───────────────────────────
def test_fallback_parallel_only_key(self):
"""Only PARALLEL_API_KEY set → 'parallel'."""
from tools.web_tools import _get_backend
with patch("tools.web_tools._load_web_config", return_value={}), \
patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
assert _get_backend() == "parallel"
def test_fallback_tavily_only_key(self):
"""Only TAVILY_API_KEY set → 'tavily'."""
from tools.web_tools import _get_backend
with patch("tools.web_tools._load_web_config", return_value={}), \
patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}):
assert _get_backend() == "tavily"
def test_fallback_tavily_with_firecrawl_prefers_firecrawl(self):
"""Tavily + Firecrawl keys, no config → 'firecrawl' (backward compat)."""
from tools.web_tools import _get_backend
with patch("tools.web_tools._load_web_config", return_value={}), \
patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test", "FIRECRAWL_API_KEY": "fc-test"}):
assert _get_backend() == "firecrawl"
def test_fallback_tavily_with_parallel_prefers_parallel(self):
"""Tavily + Parallel keys, no config → 'parallel' (Parallel takes priority over Tavily)."""
from tools.web_tools import _get_backend
with patch("tools.web_tools._load_web_config", return_value={}), \
patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test", "PARALLEL_API_KEY": "par-test"}):
# Parallel + no Firecrawl → parallel
assert _get_backend() == "parallel"
def test_fallback_both_keys_defaults_to_firecrawl(self):
"""Both keys set, no config → 'firecrawl' (backward compat)."""
from tools.web_tools import _get_backend
with patch("tools.web_tools._load_web_config", return_value={}), \
patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key", "FIRECRAWL_API_KEY": "fc-test"}):
assert _get_backend() == "firecrawl"
def test_fallback_firecrawl_only_key(self):
"""Only FIRECRAWL_API_KEY set → 'firecrawl'."""
from tools.web_tools import _get_backend
with patch("tools.web_tools._load_web_config", return_value={}), \
patch.dict(os.environ, {"FIRECRAWL_API_KEY": "fc-test"}):
assert _get_backend() == "firecrawl"
def test_fallback_no_keys_defaults_to_firecrawl(self):
"""No keys, no config → 'firecrawl' (will fail at client init)."""
from tools.web_tools import _get_backend
with patch("tools.web_tools._load_web_config", return_value={}):
assert _get_backend() == "firecrawl"
def test_invalid_config_falls_through_to_fallback(self):
"""web.backend=invalid → ignored, uses key-based fallback."""
from tools.web_tools import _get_backend
with patch("tools.web_tools._load_web_config", return_value={"backend": "nonexistent"}), \
patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
assert _get_backend() == "parallel"
class TestParallelClientConfig:
"""Test suite for Parallel client initialization."""
def setup_method(self):
import tools.web_tools
tools.web_tools._parallel_client = None
os.environ.pop("PARALLEL_API_KEY", None)
def teardown_method(self):
import tools.web_tools
tools.web_tools._parallel_client = None
os.environ.pop("PARALLEL_API_KEY", None)
def test_creates_client_with_key(self):
"""PARALLEL_API_KEY set → creates Parallel client."""
with patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
from tools.web_tools import _get_parallel_client
from parallel import Parallel
client = _get_parallel_client()
assert client is not None
assert isinstance(client, Parallel)
def test_no_key_raises_with_helpful_message(self):
"""No PARALLEL_API_KEY → ValueError with guidance."""
from tools.web_tools import _get_parallel_client
with pytest.raises(ValueError, match="PARALLEL_API_KEY"):
_get_parallel_client()
def test_singleton_returns_same_instance(self):
"""Second call returns cached client."""
with patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
from tools.web_tools import _get_parallel_client
client1 = _get_parallel_client()
client2 = _get_parallel_client()
assert client1 is client2
class TestCheckWebApiKey:
"""Test suite for check_web_api_key() unified availability check."""
_ENV_KEYS = ("PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "FIRECRAWL_API_URL", "TAVILY_API_KEY")
def setup_method(self):
for key in self._ENV_KEYS:
os.environ.pop(key, None)
def teardown_method(self):
for key in self._ENV_KEYS:
os.environ.pop(key, None)
def test_parallel_key_only(self):
with patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}):
from tools.web_tools import check_web_api_key
assert check_web_api_key() is True
def test_firecrawl_key_only(self):
with patch.dict(os.environ, {"FIRECRAWL_API_KEY": "fc-test"}):
from tools.web_tools import check_web_api_key
assert check_web_api_key() is True
def test_firecrawl_url_only(self):
with patch.dict(os.environ, {"FIRECRAWL_API_URL": "http://localhost:3002"}):
from tools.web_tools import check_web_api_key
assert check_web_api_key() is True
def test_tavily_key_only(self):
with patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}):
from tools.web_tools import check_web_api_key
assert check_web_api_key() is True
def test_no_keys_returns_false(self):
from tools.web_tools import check_web_api_key
assert check_web_api_key() is False
def test_both_keys_returns_true(self):
with patch.dict(os.environ, {
"PARALLEL_API_KEY": "test-key",
"FIRECRAWL_API_KEY": "fc-test",
}):
from tools.web_tools import check_web_api_key
assert check_web_api_key() is True
def test_all_three_keys_returns_true(self):
with patch.dict(os.environ, {
"PARALLEL_API_KEY": "test-key",
"FIRECRAWL_API_KEY": "fc-test",
"TAVILY_API_KEY": "tvly-test",
}):
from tools.web_tools import check_web_api_key
assert check_web_api_key() is True

View file

@ -0,0 +1,255 @@
"""Tests for Tavily web backend integration.
Coverage:
_tavily_request() API key handling, endpoint construction, error propagation.
_normalize_tavily_search_results() search response normalization.
_normalize_tavily_documents() extract/crawl response normalization, failed_results.
web_search_tool / web_extract_tool / web_crawl_tool Tavily dispatch paths.
"""
import json
import os
import asyncio
import pytest
from unittest.mock import patch, MagicMock
# ─── _tavily_request ─────────────────────────────────────────────────────────
class TestTavilyRequest:
"""Test suite for the _tavily_request helper."""
def test_raises_without_api_key(self):
"""No TAVILY_API_KEY → ValueError with guidance."""
with patch.dict(os.environ, {}, clear=False):
os.environ.pop("TAVILY_API_KEY", None)
from tools.web_tools import _tavily_request
with pytest.raises(ValueError, match="TAVILY_API_KEY"):
_tavily_request("search", {"query": "test"})
def test_posts_with_api_key_in_body(self):
"""api_key is injected into the JSON payload."""
mock_response = MagicMock()
mock_response.json.return_value = {"results": []}
mock_response.raise_for_status = MagicMock()
with patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test-key"}):
with patch("tools.web_tools.httpx.post", return_value=mock_response) as mock_post:
from tools.web_tools import _tavily_request
result = _tavily_request("search", {"query": "hello"})
mock_post.assert_called_once()
call_kwargs = mock_post.call_args
payload = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json")
assert payload["api_key"] == "tvly-test-key"
assert payload["query"] == "hello"
assert "api.tavily.com/search" in call_kwargs.args[0]
def test_raises_on_http_error(self):
"""Non-2xx responses propagate as httpx.HTTPStatusError."""
import httpx as _httpx
mock_response = MagicMock()
mock_response.raise_for_status.side_effect = _httpx.HTTPStatusError(
"401 Unauthorized", request=MagicMock(), response=mock_response
)
with patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-bad-key"}):
with patch("tools.web_tools.httpx.post", return_value=mock_response):
from tools.web_tools import _tavily_request
with pytest.raises(_httpx.HTTPStatusError):
_tavily_request("search", {"query": "test"})
# ─── _normalize_tavily_search_results ─────────────────────────────────────────
class TestNormalizeTavilySearchResults:
"""Test search result normalization."""
def test_basic_normalization(self):
from tools.web_tools import _normalize_tavily_search_results
raw = {
"results": [
{"title": "Python Docs", "url": "https://docs.python.org", "content": "Official docs", "score": 0.9},
{"title": "Tutorial", "url": "https://example.com", "content": "A tutorial", "score": 0.8},
]
}
result = _normalize_tavily_search_results(raw)
assert result["success"] is True
web = result["data"]["web"]
assert len(web) == 2
assert web[0]["title"] == "Python Docs"
assert web[0]["url"] == "https://docs.python.org"
assert web[0]["description"] == "Official docs"
assert web[0]["position"] == 1
assert web[1]["position"] == 2
def test_empty_results(self):
from tools.web_tools import _normalize_tavily_search_results
result = _normalize_tavily_search_results({"results": []})
assert result["success"] is True
assert result["data"]["web"] == []
def test_missing_fields(self):
from tools.web_tools import _normalize_tavily_search_results
result = _normalize_tavily_search_results({"results": [{}]})
web = result["data"]["web"]
assert web[0]["title"] == ""
assert web[0]["url"] == ""
assert web[0]["description"] == ""
# ─── _normalize_tavily_documents ──────────────────────────────────────────────
class TestNormalizeTavilyDocuments:
"""Test extract/crawl document normalization."""
def test_basic_document(self):
from tools.web_tools import _normalize_tavily_documents
raw = {
"results": [{
"url": "https://example.com",
"title": "Example",
"raw_content": "Full page content here",
}]
}
docs = _normalize_tavily_documents(raw)
assert len(docs) == 1
assert docs[0]["url"] == "https://example.com"
assert docs[0]["title"] == "Example"
assert docs[0]["content"] == "Full page content here"
assert docs[0]["raw_content"] == "Full page content here"
assert docs[0]["metadata"]["sourceURL"] == "https://example.com"
def test_falls_back_to_content_when_no_raw_content(self):
from tools.web_tools import _normalize_tavily_documents
raw = {"results": [{"url": "https://example.com", "content": "Snippet"}]}
docs = _normalize_tavily_documents(raw)
assert docs[0]["content"] == "Snippet"
def test_failed_results_included(self):
from tools.web_tools import _normalize_tavily_documents
raw = {
"results": [],
"failed_results": [
{"url": "https://fail.com", "error": "timeout"},
],
}
docs = _normalize_tavily_documents(raw)
assert len(docs) == 1
assert docs[0]["url"] == "https://fail.com"
assert docs[0]["error"] == "timeout"
assert docs[0]["content"] == ""
def test_failed_urls_included(self):
from tools.web_tools import _normalize_tavily_documents
raw = {
"results": [],
"failed_urls": ["https://bad.com"],
}
docs = _normalize_tavily_documents(raw)
assert len(docs) == 1
assert docs[0]["url"] == "https://bad.com"
assert docs[0]["error"] == "extraction failed"
def test_fallback_url(self):
from tools.web_tools import _normalize_tavily_documents
raw = {"results": [{"content": "data"}]}
docs = _normalize_tavily_documents(raw, fallback_url="https://fallback.com")
assert docs[0]["url"] == "https://fallback.com"
# ─── web_search_tool (Tavily dispatch) ────────────────────────────────────────
class TestWebSearchTavily:
"""Test web_search_tool dispatch to Tavily."""
def test_search_dispatches_to_tavily(self):
mock_response = MagicMock()
mock_response.json.return_value = {
"results": [{"title": "Result", "url": "https://r.com", "content": "desc", "score": 0.9}]
}
mock_response.raise_for_status = MagicMock()
with patch("tools.web_tools._get_backend", return_value="tavily"), \
patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}), \
patch("tools.web_tools.httpx.post", return_value=mock_response), \
patch("tools.interrupt.is_interrupted", return_value=False):
from tools.web_tools import web_search_tool
result = json.loads(web_search_tool("test query", limit=3))
assert result["success"] is True
assert len(result["data"]["web"]) == 1
assert result["data"]["web"][0]["title"] == "Result"
# ─── web_extract_tool (Tavily dispatch) ───────────────────────────────────────
class TestWebExtractTavily:
"""Test web_extract_tool dispatch to Tavily."""
def test_extract_dispatches_to_tavily(self):
mock_response = MagicMock()
mock_response.json.return_value = {
"results": [{"url": "https://example.com", "raw_content": "Extracted content", "title": "Page"}]
}
mock_response.raise_for_status = MagicMock()
with patch("tools.web_tools._get_backend", return_value="tavily"), \
patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}), \
patch("tools.web_tools.httpx.post", return_value=mock_response), \
patch("tools.web_tools.process_content_with_llm", return_value=None):
from tools.web_tools import web_extract_tool
result = json.loads(asyncio.get_event_loop().run_until_complete(
web_extract_tool(["https://example.com"], use_llm_processing=False)
))
assert "results" in result
assert len(result["results"]) == 1
assert result["results"][0]["url"] == "https://example.com"
# ─── web_crawl_tool (Tavily dispatch) ─────────────────────────────────────────
class TestWebCrawlTavily:
"""Test web_crawl_tool dispatch to Tavily."""
def test_crawl_dispatches_to_tavily(self):
mock_response = MagicMock()
mock_response.json.return_value = {
"results": [
{"url": "https://example.com/page1", "raw_content": "Page 1 content", "title": "Page 1"},
{"url": "https://example.com/page2", "raw_content": "Page 2 content", "title": "Page 2"},
]
}
mock_response.raise_for_status = MagicMock()
with patch("tools.web_tools._get_backend", return_value="tavily"), \
patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}), \
patch("tools.web_tools.httpx.post", return_value=mock_response), \
patch("tools.web_tools.check_website_access", return_value=None), \
patch("tools.interrupt.is_interrupted", return_value=False):
from tools.web_tools import web_crawl_tool
result = json.loads(asyncio.get_event_loop().run_until_complete(
web_crawl_tool("https://example.com", use_llm_processing=False)
))
assert "results" in result
assert len(result["results"]) == 2
assert result["results"][0]["title"] == "Page 1"
def test_crawl_sends_instructions(self):
"""Instructions are included in the Tavily crawl payload."""
mock_response = MagicMock()
mock_response.json.return_value = {"results": []}
mock_response.raise_for_status = MagicMock()
with patch("tools.web_tools._get_backend", return_value="tavily"), \
patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}), \
patch("tools.web_tools.httpx.post", return_value=mock_response) as mock_post, \
patch("tools.web_tools.check_website_access", return_value=None), \
patch("tools.interrupt.is_interrupted", return_value=False):
from tools.web_tools import web_crawl_tool
asyncio.get_event_loop().run_until_complete(
web_crawl_tool("https://example.com", instructions="Find docs", use_llm_processing=False)
)
call_kwargs = mock_post.call_args
payload = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json")
assert payload["instructions"] == "Find docs"
assert payload["url"] == "https://example.com"

View file

@ -0,0 +1,495 @@
import json
from pathlib import Path
import pytest
import yaml
from tools.website_policy import WebsitePolicyError, check_website_access, load_website_blocklist
def test_load_website_blocklist_merges_config_and_shared_file(tmp_path):
shared = tmp_path / "community-blocklist.txt"
shared.write_text("# comment\nexample.org\nsub.bad.net\n", encoding="utf-8")
config_path = tmp_path / "config.yaml"
config_path.write_text(
yaml.safe_dump(
{
"security": {
"website_blocklist": {
"enabled": True,
"domains": ["example.com", "https://www.evil.test/path"],
"shared_files": [str(shared)],
}
}
},
sort_keys=False,
),
encoding="utf-8",
)
policy = load_website_blocklist(config_path)
assert policy["enabled"] is True
assert {rule["pattern"] for rule in policy["rules"]} == {
"example.com",
"evil.test",
"example.org",
"sub.bad.net",
}
def test_check_website_access_matches_parent_domain_subdomains(tmp_path):
config_path = tmp_path / "config.yaml"
config_path.write_text(
yaml.safe_dump(
{
"security": {
"website_blocklist": {
"enabled": True,
"domains": ["example.com"],
}
}
},
sort_keys=False,
),
encoding="utf-8",
)
blocked = check_website_access("https://docs.example.com/page", config_path=config_path)
assert blocked is not None
assert blocked["host"] == "docs.example.com"
assert blocked["rule"] == "example.com"
def test_check_website_access_supports_wildcard_subdomains_only(tmp_path):
config_path = tmp_path / "config.yaml"
config_path.write_text(
yaml.safe_dump(
{
"security": {
"website_blocklist": {
"enabled": True,
"domains": ["*.tracking.example"],
}
}
},
sort_keys=False,
),
encoding="utf-8",
)
assert check_website_access("https://a.tracking.example", config_path=config_path) is not None
assert check_website_access("https://www.tracking.example", config_path=config_path) is not None
assert check_website_access("https://tracking.example", config_path=config_path) is None
def test_default_config_exposes_website_blocklist_shape():
from hermes_cli.config import DEFAULT_CONFIG
website_blocklist = DEFAULT_CONFIG["security"]["website_blocklist"]
assert website_blocklist["enabled"] is False
assert website_blocklist["domains"] == []
assert website_blocklist["shared_files"] == []
def test_load_website_blocklist_uses_enabled_default_when_section_missing(tmp_path):
config_path = tmp_path / "config.yaml"
config_path.write_text(yaml.safe_dump({"display": {"tool_progress": "all"}}, sort_keys=False), encoding="utf-8")
policy = load_website_blocklist(config_path)
assert policy == {"enabled": False, "rules": []}
def test_load_website_blocklist_raises_clean_error_for_invalid_domains_type(tmp_path):
config_path = tmp_path / "config.yaml"
config_path.write_text(
yaml.safe_dump(
{
"security": {
"website_blocklist": {
"enabled": True,
"domains": "example.com",
}
}
},
sort_keys=False,
),
encoding="utf-8",
)
with pytest.raises(WebsitePolicyError, match="security.website_blocklist.domains must be a list"):
load_website_blocklist(config_path)
def test_load_website_blocklist_raises_clean_error_for_invalid_shared_files_type(tmp_path):
config_path = tmp_path / "config.yaml"
config_path.write_text(
yaml.safe_dump(
{
"security": {
"website_blocklist": {
"enabled": True,
"shared_files": "community-blocklist.txt",
}
}
},
sort_keys=False,
),
encoding="utf-8",
)
with pytest.raises(WebsitePolicyError, match="security.website_blocklist.shared_files must be a list"):
load_website_blocklist(config_path)
def test_load_website_blocklist_raises_clean_error_for_invalid_top_level_config_type(tmp_path):
config_path = tmp_path / "config.yaml"
config_path.write_text(yaml.safe_dump(["not", "a", "mapping"], sort_keys=False), encoding="utf-8")
with pytest.raises(WebsitePolicyError, match="config root must be a mapping"):
load_website_blocklist(config_path)
def test_load_website_blocklist_raises_clean_error_for_invalid_security_type(tmp_path):
config_path = tmp_path / "config.yaml"
config_path.write_text(yaml.safe_dump({"security": []}, sort_keys=False), encoding="utf-8")
with pytest.raises(WebsitePolicyError, match="security must be a mapping"):
load_website_blocklist(config_path)
def test_load_website_blocklist_raises_clean_error_for_invalid_website_blocklist_type(tmp_path):
config_path = tmp_path / "config.yaml"
config_path.write_text(
yaml.safe_dump(
{
"security": {
"website_blocklist": "block everything",
}
},
sort_keys=False,
),
encoding="utf-8",
)
with pytest.raises(WebsitePolicyError, match="security.website_blocklist must be a mapping"):
load_website_blocklist(config_path)
def test_load_website_blocklist_raises_clean_error_for_invalid_enabled_type(tmp_path):
config_path = tmp_path / "config.yaml"
config_path.write_text(
yaml.safe_dump(
{
"security": {
"website_blocklist": {
"enabled": "false",
}
}
},
sort_keys=False,
),
encoding="utf-8",
)
with pytest.raises(WebsitePolicyError, match="security.website_blocklist.enabled must be a boolean"):
load_website_blocklist(config_path)
def test_load_website_blocklist_raises_clean_error_for_malformed_yaml(tmp_path):
config_path = tmp_path / "config.yaml"
config_path.write_text("security: [oops\n", encoding="utf-8")
with pytest.raises(WebsitePolicyError, match="Invalid config YAML"):
load_website_blocklist(config_path)
def test_load_website_blocklist_wraps_shared_file_read_errors(tmp_path, monkeypatch):
shared = tmp_path / "community-blocklist.txt"
shared.write_text("example.org\n", encoding="utf-8")
config_path = tmp_path / "config.yaml"
config_path.write_text(
yaml.safe_dump(
{
"security": {
"website_blocklist": {
"enabled": True,
"shared_files": [str(shared)],
}
}
},
sort_keys=False,
),
encoding="utf-8",
)
def failing_read_text(self, *args, **kwargs):
raise PermissionError("no permission")
monkeypatch.setattr(Path, "read_text", failing_read_text)
# Unreadable shared files are now warned and skipped (not raised),
# so the blocklist loads successfully but without those rules.
result = load_website_blocklist(config_path)
assert result["enabled"] is True
assert result["rules"] == [] # shared file rules skipped
def test_check_website_access_uses_dynamic_hermes_home(monkeypatch, tmp_path):
hermes_home = tmp_path / "hermes-home"
hermes_home.mkdir()
(hermes_home / "config.yaml").write_text(
yaml.safe_dump(
{
"security": {
"website_blocklist": {
"enabled": True,
"domains": ["dynamic.example"],
}
}
},
sort_keys=False,
),
encoding="utf-8",
)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
blocked = check_website_access("https://dynamic.example/path")
assert blocked is not None
assert blocked["rule"] == "dynamic.example"
def test_check_website_access_blocks_scheme_less_urls(tmp_path):
config_path = tmp_path / "config.yaml"
config_path.write_text(
yaml.safe_dump(
{
"security": {
"website_blocklist": {
"enabled": True,
"domains": ["blocked.test"],
}
}
},
sort_keys=False,
),
encoding="utf-8",
)
blocked = check_website_access("www.blocked.test/path", config_path=config_path)
assert blocked is not None
assert blocked["host"] == "www.blocked.test"
assert blocked["rule"] == "blocked.test"
def test_browser_navigate_returns_policy_block(monkeypatch):
from tools import browser_tool
monkeypatch.setattr(
browser_tool,
"check_website_access",
lambda url: {
"host": "blocked.test",
"rule": "blocked.test",
"source": "config",
"message": "Blocked by website policy",
},
)
monkeypatch.setattr(
browser_tool,
"_run_browser_command",
lambda *args, **kwargs: pytest.fail("browser command should not run for blocked URL"),
)
result = json.loads(browser_tool.browser_navigate("https://blocked.test"))
assert result["success"] is False
assert result["blocked_by_policy"]["rule"] == "blocked.test"
def test_browser_navigate_allows_when_shared_file_missing(monkeypatch, tmp_path):
"""Missing shared blocklist files are warned and skipped, not fatal."""
from tools import browser_tool
config_path = tmp_path / "config.yaml"
config_path.write_text(
yaml.safe_dump(
{
"security": {
"website_blocklist": {
"enabled": True,
"shared_files": ["missing-blocklist.txt"],
}
}
},
sort_keys=False,
),
encoding="utf-8",
)
# check_website_access should return None (allow) — missing file is skipped
result = check_website_access("https://allowed.test", config_path=config_path)
assert result is None
@pytest.mark.asyncio
async def test_web_extract_short_circuits_blocked_url(monkeypatch):
from tools import web_tools
monkeypatch.setattr(
web_tools,
"check_website_access",
lambda url: {
"host": "blocked.test",
"rule": "blocked.test",
"source": "config",
"message": "Blocked by website policy",
},
)
monkeypatch.setattr(
web_tools,
"_get_firecrawl_client",
lambda: pytest.fail("firecrawl should not run for blocked URL"),
)
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False)
result = json.loads(await web_tools.web_extract_tool(["https://blocked.test"], use_llm_processing=False))
assert result["results"][0]["url"] == "https://blocked.test"
assert "Blocked by website policy" in result["results"][0]["error"]
def test_check_website_access_fails_open_on_malformed_config(tmp_path, monkeypatch):
"""Malformed config with default path should fail open (return None), not crash."""
config_path = tmp_path / "config.yaml"
config_path.write_text("security: [oops\n", encoding="utf-8")
# With explicit config_path (test mode), errors propagate
with pytest.raises(WebsitePolicyError):
check_website_access("https://example.com", config_path=config_path)
# Simulate default path by pointing HERMES_HOME to tmp_path
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
from tools import website_policy
website_policy.invalidate_cache()
# With default path, errors are caught and fail open
result = check_website_access("https://example.com")
assert result is None # allowed, not crashed
@pytest.mark.asyncio
async def test_web_extract_blocks_redirected_final_url(monkeypatch):
from tools import web_tools
def fake_check(url):
if url == "https://allowed.test":
return None
if url == "https://blocked.test/final":
return {
"host": "blocked.test",
"rule": "blocked.test",
"source": "config",
"message": "Blocked by website policy",
}
pytest.fail(f"unexpected URL checked: {url}")
class FakeFirecrawlClient:
def scrape(self, url, formats):
return {
"markdown": "secret content",
"metadata": {
"title": "Redirected",
"sourceURL": "https://blocked.test/final",
},
}
monkeypatch.setattr(web_tools, "check_website_access", fake_check)
monkeypatch.setattr(web_tools, "_get_firecrawl_client", lambda: FakeFirecrawlClient())
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False)
result = json.loads(await web_tools.web_extract_tool(["https://allowed.test"], use_llm_processing=False))
assert result["results"][0]["url"] == "https://blocked.test/final"
assert result["results"][0]["content"] == ""
assert result["results"][0]["blocked_by_policy"]["rule"] == "blocked.test"
@pytest.mark.asyncio
async def test_web_crawl_short_circuits_blocked_url(monkeypatch):
from tools import web_tools
# web_crawl_tool checks for Firecrawl env before website policy
monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")
monkeypatch.setattr(
web_tools,
"check_website_access",
lambda url: {
"host": "blocked.test",
"rule": "blocked.test",
"source": "config",
"message": "Blocked by website policy",
},
)
monkeypatch.setattr(
web_tools,
"_get_firecrawl_client",
lambda: pytest.fail("firecrawl should not run for blocked crawl URL"),
)
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False)
result = json.loads(await web_tools.web_crawl_tool("https://blocked.test", use_llm_processing=False))
assert result["results"][0]["url"] == "https://blocked.test"
assert result["results"][0]["blocked_by_policy"]["rule"] == "blocked.test"
@pytest.mark.asyncio
async def test_web_crawl_blocks_redirected_final_url(monkeypatch):
from tools import web_tools
# web_crawl_tool checks for Firecrawl env before website policy
monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")
def fake_check(url):
if url == "https://allowed.test":
return None
if url == "https://blocked.test/final":
return {
"host": "blocked.test",
"rule": "blocked.test",
"source": "config",
"message": "Blocked by website policy",
}
pytest.fail(f"unexpected URL checked: {url}")
class FakeCrawlClient:
def crawl(self, url, **kwargs):
return {
"data": [
{
"markdown": "secret crawl content",
"metadata": {
"title": "Redirected crawl page",
"sourceURL": "https://blocked.test/final",
},
}
]
}
monkeypatch.setattr(web_tools, "check_website_access", fake_check)
monkeypatch.setattr(web_tools, "_get_firecrawl_client", lambda: FakeCrawlClient())
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False)
result = json.loads(await web_tools.web_crawl_tool("https://allowed.test", use_llm_processing=False))
assert result["results"][0]["content"] == ""
assert result["results"][0]["error"] == "Blocked by website policy"
assert result["results"][0]["blocked_by_policy"]["rule"] == "blocked.test"

View file

@ -65,6 +65,11 @@ import requests
from typing import Dict, Any, Optional, List
from pathlib import Path
from agent.auxiliary_client import call_llm
try:
from tools.website_policy import check_website_access
except Exception:
check_website_access = lambda url: None # noqa: E731 — fail-open if policy module unavailable
from tools.browser_providers.base import CloudBrowserProvider
from tools.browser_providers.browserbase import BrowserbaseProvider
from tools.browser_providers.browser_use import BrowserUseProvider
@ -550,6 +555,11 @@ def _get_session_info(task_id: Optional[str] = None) -> Dict[str, str]:
session_info = provider.create_session(task_id)
with _cleanup_lock:
# Double-check: another thread may have created a session while we
# were doing the network call. Use the existing one to avoid leaking
# orphan cloud sessions.
if task_id in _active_sessions:
return _active_sessions[task_id]
_active_sessions[task_id] = session_info
return session_info
@ -901,6 +911,15 @@ def browser_navigate(url: str, task_id: Optional[str] = None) -> str:
Returns:
JSON string with navigation result (includes stealth features info on first nav)
"""
# Website policy check — block before navigating
blocked = check_website_access(url)
if blocked:
return json.dumps({
"success": False,
"error": blocked["message"],
"blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]},
})
effective_task_id = task_id or "default"
# Get session info to check if this is a new session

View file

@ -16,13 +16,10 @@ The parent's context only sees the delegation call and the summary result,
never the child's intermediate tool calls or reasoning.
"""
import contextlib
import io
import json
import logging
logger = logging.getLogger(__name__)
import os
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional
@ -150,7 +147,7 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
return _callback
def _run_single_child(
def _build_child_agent(
task_index: int,
goal: str,
context: Optional[str],
@ -158,16 +155,15 @@ def _run_single_child(
model: Optional[str],
max_iterations: int,
parent_agent,
task_count: int = 1,
# Credential overrides from delegation config (provider:model resolution)
override_provider: Optional[str] = None,
override_base_url: Optional[str] = None,
override_api_key: Optional[str] = None,
override_api_mode: Optional[str] = None,
) -> Dict[str, Any]:
):
"""
Spawn and run a single child agent. Called from within a thread.
Returns a structured result dict.
Build a child AIAgent on the main thread (thread-safe construction).
Returns the constructed child agent without running it.
When override_* params are set (from delegation config), the child uses
those credentials instead of inheriting from the parent. This enables
@ -176,8 +172,6 @@ def _run_single_child(
"""
from run_agent import AIAgent
child_start = time.monotonic()
# When no explicit toolsets given, inherit from parent's enabled toolsets
# so disabled tools (e.g. web) don't leak to subagents.
if toolsets:
@ -188,65 +182,84 @@ def _run_single_child(
child_toolsets = _strip_blocked_tools(DEFAULT_TOOLSETS)
child_prompt = _build_child_system_prompt(goal, context)
# Extract parent's API key so subagents inherit auth (e.g. Nous Portal).
parent_api_key = getattr(parent_agent, "api_key", None)
if (not parent_api_key) and hasattr(parent_agent, "_client_kwargs"):
parent_api_key = parent_agent._client_kwargs.get("api_key")
try:
# Extract parent's API key so subagents inherit auth (e.g. Nous Portal).
parent_api_key = getattr(parent_agent, "api_key", None)
if (not parent_api_key) and hasattr(parent_agent, "_client_kwargs"):
parent_api_key = parent_agent._client_kwargs.get("api_key")
# Build progress callback to relay tool calls to parent display
child_progress_cb = _build_child_progress_callback(task_index, parent_agent)
# Build progress callback to relay tool calls to parent display
child_progress_cb = _build_child_progress_callback(task_index, parent_agent, task_count)
# Share the parent's iteration budget so subagent tool calls
# count toward the session-wide limit.
shared_budget = getattr(parent_agent, "iteration_budget", None)
# Share the parent's iteration budget so subagent tool calls
# count toward the session-wide limit.
shared_budget = getattr(parent_agent, "iteration_budget", None)
# Resolve effective credentials: config override > parent inherit
effective_model = model or parent_agent.model
effective_provider = override_provider or getattr(parent_agent, "provider", None)
effective_base_url = override_base_url or parent_agent.base_url
effective_api_key = override_api_key or parent_api_key
effective_api_mode = override_api_mode or getattr(parent_agent, "api_mode", None)
# Resolve effective credentials: config override > parent inherit
effective_model = model or parent_agent.model
effective_provider = override_provider or getattr(parent_agent, "provider", None)
effective_base_url = override_base_url or parent_agent.base_url
effective_api_key = override_api_key or parent_api_key
effective_api_mode = override_api_mode or getattr(parent_agent, "api_mode", None)
child = AIAgent(
base_url=effective_base_url,
api_key=effective_api_key,
model=effective_model,
provider=effective_provider,
api_mode=effective_api_mode,
max_iterations=max_iterations,
max_tokens=getattr(parent_agent, "max_tokens", None),
reasoning_config=getattr(parent_agent, "reasoning_config", None),
prefill_messages=getattr(parent_agent, "prefill_messages", None),
enabled_toolsets=child_toolsets,
quiet_mode=True,
ephemeral_system_prompt=child_prompt,
log_prefix=f"[subagent-{task_index}]",
platform=parent_agent.platform,
skip_context_files=True,
skip_memory=True,
clarify_callback=None,
session_db=getattr(parent_agent, '_session_db', None),
providers_allowed=parent_agent.providers_allowed,
providers_ignored=parent_agent.providers_ignored,
providers_order=parent_agent.providers_order,
provider_sort=parent_agent.provider_sort,
tool_progress_callback=child_progress_cb,
iteration_budget=shared_budget,
)
child = AIAgent(
base_url=effective_base_url,
api_key=effective_api_key,
model=effective_model,
provider=effective_provider,
api_mode=effective_api_mode,
max_iterations=max_iterations,
max_tokens=getattr(parent_agent, "max_tokens", None),
reasoning_config=getattr(parent_agent, "reasoning_config", None),
prefill_messages=getattr(parent_agent, "prefill_messages", None),
enabled_toolsets=child_toolsets,
quiet_mode=True,
ephemeral_system_prompt=child_prompt,
log_prefix=f"[subagent-{task_index}]",
platform=parent_agent.platform,
skip_context_files=True,
skip_memory=True,
clarify_callback=None,
session_db=getattr(parent_agent, '_session_db', None),
providers_allowed=parent_agent.providers_allowed,
providers_ignored=parent_agent.providers_ignored,
providers_order=parent_agent.providers_order,
provider_sort=parent_agent.provider_sort,
tool_progress_callback=child_progress_cb,
iteration_budget=shared_budget,
)
# Set delegation depth so children can't spawn grandchildren
child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1
# Set delegation depth so children can't spawn grandchildren
child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1
# Register child for interrupt propagation
if hasattr(parent_agent, '_active_children'):
# Register child for interrupt propagation
if hasattr(parent_agent, '_active_children'):
lock = getattr(parent_agent, '_active_children_lock', None)
if lock:
with lock:
parent_agent._active_children.append(child)
else:
parent_agent._active_children.append(child)
# Run with stdout/stderr suppressed to prevent interleaved output
devnull = io.StringIO()
with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
result = child.run_conversation(user_message=goal)
return child
def _run_single_child(
task_index: int,
goal: str,
child=None,
parent_agent=None,
**_kwargs,
) -> Dict[str, Any]:
"""
Run a pre-built child agent. Called from within a thread.
Returns a structured result dict.
"""
child_start = time.monotonic()
# Get the progress callback from the child agent
child_progress_cb = getattr(child, 'tool_progress_callback', None)
try:
result = child.run_conversation(user_message=goal)
# Flush any remaining batched progress to gateway
if child_progress_cb and hasattr(child_progress_cb, '_flush'):
@ -355,11 +368,15 @@ def _run_single_child(
# Unregister child from interrupt propagation
if hasattr(parent_agent, '_active_children'):
try:
parent_agent._active_children.remove(child)
lock = getattr(parent_agent, '_active_children_lock', None)
if lock:
with lock:
parent_agent._active_children.remove(child)
else:
parent_agent._active_children.remove(child)
except (ValueError, UnboundLocalError) as e:
logger.debug("Could not remove child from active_children: %s", e)
def delegate_task(
goal: Optional[str] = None,
context: Optional[str] = None,
@ -428,51 +445,38 @@ def delegate_task(
# Track goal labels for progress display (truncated for readability)
task_labels = [t["goal"][:40] for t in task_list]
if n_tasks == 1:
# Single task -- run directly (no thread pool overhead)
t = task_list[0]
result = _run_single_child(
task_index=0,
goal=t["goal"],
context=t.get("context"),
toolsets=t.get("toolsets") or toolsets,
model=creds["model"],
max_iterations=effective_max_iter,
parent_agent=parent_agent,
task_count=1,
override_provider=creds["provider"],
override_base_url=creds["base_url"],
# Build all child agents on the main thread (thread-safe construction)
children = []
for i, t in enumerate(task_list):
child = _build_child_agent(
task_index=i, goal=t["goal"], context=t.get("context"),
toolsets=t.get("toolsets") or toolsets, model=creds["model"],
max_iterations=effective_max_iter, parent_agent=parent_agent,
override_provider=creds["provider"], override_base_url=creds["base_url"],
override_api_key=creds["api_key"],
override_api_mode=creds["api_mode"],
)
children.append((i, t, child))
if n_tasks == 1:
# Single task -- run directly (no thread pool overhead)
_i, _t, child = children[0]
result = _run_single_child(0, _t["goal"], child, parent_agent)
results.append(result)
else:
# Batch -- run in parallel with per-task progress lines
completed_count = 0
spinner_ref = getattr(parent_agent, '_delegate_spinner', None)
# Save stdout/stderr before the executor — redirect_stdout in child
# threads races on sys.stdout and can leave it as devnull permanently.
_saved_stdout = sys.stdout
_saved_stderr = sys.stderr
with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CHILDREN) as executor:
futures = {}
for i, t in enumerate(task_list):
for i, t, child in children:
future = executor.submit(
_run_single_child,
task_index=i,
goal=t["goal"],
context=t.get("context"),
toolsets=t.get("toolsets") or toolsets,
model=creds["model"],
max_iterations=effective_max_iter,
child=child,
parent_agent=parent_agent,
task_count=n_tasks,
override_provider=creds["provider"],
override_base_url=creds["base_url"],
override_api_key=creds["api_key"],
override_api_mode=creds["api_mode"],
)
futures[future] = i
@ -515,10 +519,6 @@ def delegate_task(
except Exception as e:
logger.debug("Spinner update_text failed: %s", e)
# Restore stdout/stderr in case redirect_stdout race left them as devnull
sys.stdout = _saved_stdout
sys.stderr = _saved_stderr
# Sort by task_index so results match input order
results.sort(key=lambda r: r["task_index"])

View file

@ -82,6 +82,9 @@ def _build_provider_env_blocklist() -> frozenset:
"FIREWORKS_API_KEY", # Fireworks AI
"XAI_API_KEY", # xAI (Grok)
"HELICONE_API_KEY", # LLM Observability proxy
"PARALLEL_API_KEY",
"FIRECRAWL_API_KEY",
"FIRECRAWL_API_URL",
# Gateway/runtime config not represented in OPTIONAL_ENV_VARS.
"TELEGRAM_HOME_CHANNEL",
"TELEGRAM_HOME_CHANNEL_NAME",

View file

@ -94,7 +94,7 @@ def _get_safe_write_root() -> Optional[str]:
def _is_write_denied(path: str) -> bool:
"""Return True if path is on the write deny list."""
resolved = os.path.realpath(os.path.expanduser(path))
resolved = os.path.realpath(os.path.expanduser(str(path)))
# 1) Static deny list
if resolved in WRITE_DENIED_PATHS:

View file

@ -254,10 +254,9 @@ def _strategy_trimmed_boundary(content: str, pattern: str) -> List[Tuple[int, in
if '\n'.join(check_lines) == modified_pattern:
# Found match - calculate original positions
start_pos = sum(len(line) + 1 for line in content_lines[:i])
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
if end_pos >= len(content):
end_pos = len(content)
start_pos, end_pos = _calculate_line_positions(
content_lines, i, i + pattern_line_count, len(content)
)
matches.append((start_pos, end_pos))
return matches
@ -309,9 +308,10 @@ def _strategy_block_anchor(content: str, pattern: str) -> List[Tuple[int, int]]:
if similarity >= threshold:
# Calculate positions using ORIGINAL lines to ensure correct character offsets in the file
start_pos = sum(len(line) + 1 for line in orig_content_lines[:i])
end_pos = sum(len(line) + 1 for line in orig_content_lines[:i + pattern_line_count]) - 1
matches.append((start_pos, min(end_pos, len(content))))
start_pos, end_pos = _calculate_line_positions(
orig_content_lines, i, i + pattern_line_count, len(content)
)
matches.append((start_pos, end_pos))
return matches
@ -343,10 +343,9 @@ def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]]
# Need at least 50% of lines to have high similarity
if high_similarity_count >= len(pattern_lines) * 0.5:
start_pos = sum(len(line) + 1 for line in content_lines[:i])
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
if end_pos >= len(content):
end_pos = len(content)
start_pos, end_pos = _calculate_line_positions(
content_lines, i, i + pattern_line_count, len(content)
)
matches.append((start_pos, end_pos))
return matches
@ -356,6 +355,26 @@ def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]]
# Helper Functions
# =============================================================================
def _calculate_line_positions(content_lines: List[str], start_line: int,
end_line: int, content_length: int) -> Tuple[int, int]:
"""Calculate start and end character positions from line indices.
Args:
content_lines: List of lines (without newlines)
start_line: Starting line index (0-based)
end_line: Ending line index (exclusive, 0-based)
content_length: Total length of the original content string
Returns:
Tuple of (start_pos, end_pos) in the original content
"""
start_pos = sum(len(line) + 1 for line in content_lines[:start_line])
end_pos = sum(len(line) + 1 for line in content_lines[:end_line]) - 1
if end_pos >= content_length:
end_pos = content_length
return start_pos, end_pos
def _find_normalized_matches(content: str, content_lines: List[str],
content_normalized_lines: List[str],
pattern: str, pattern_normalized: str) -> List[Tuple[int, int]]:
@ -383,13 +402,9 @@ def _find_normalized_matches(content: str, content_lines: List[str],
if block == pattern_normalized:
# Found a match - calculate original positions
start_pos = sum(len(line) + 1 for line in content_lines[:i])
end_pos = sum(len(line) + 1 for line in content_lines[:i + num_pattern_lines]) - 1
# Handle case where end is past content
if end_pos >= len(content):
end_pos = len(content)
start_pos, end_pos = _calculate_line_positions(
content_lines, i, i + num_pattern_lines, len(content)
)
matches.append((start_pos, end_pos))
return matches

View file

@ -1624,6 +1624,72 @@ def get_mcp_status() -> List[dict]:
return result
def probe_mcp_server_tools() -> Dict[str, List[tuple]]:
"""Temporarily connect to configured MCP servers and list their tools.
Designed for ``hermes tools`` interactive configuration connects to each
enabled server, grabs tool names and descriptions, then disconnects.
Does NOT register tools in the Hermes registry.
Returns:
Dict mapping server name to list of (tool_name, description) tuples.
Servers that fail to connect are omitted from the result.
"""
if not _MCP_AVAILABLE:
return {}
servers_config = _load_mcp_config()
if not servers_config:
return {}
enabled = {
k: v for k, v in servers_config.items()
if _parse_boolish(v.get("enabled", True), default=True)
}
if not enabled:
return {}
_ensure_mcp_loop()
result: Dict[str, List[tuple]] = {}
probed_servers: List[MCPServerTask] = []
async def _probe_all():
names = list(enabled.keys())
coros = []
for name, cfg in enabled.items():
ct = cfg.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
coros.append(asyncio.wait_for(_connect_server(name, cfg), timeout=ct))
outcomes = await asyncio.gather(*coros, return_exceptions=True)
for name, outcome in zip(names, outcomes):
if isinstance(outcome, Exception):
logger.debug("Probe: failed to connect to '%s': %s", name, outcome)
continue
probed_servers.append(outcome)
tools = []
for t in outcome._tools:
desc = getattr(t, "description", "") or ""
tools.append((t.name, desc))
result[name] = tools
# Shut down all probed connections
await asyncio.gather(
*(s.shutdown() for s in probed_servers),
return_exceptions=True,
)
try:
_run_on_mcp_loop(_probe_all(), timeout=120)
except Exception as exc:
logger.debug("MCP probe failed: %s", exc)
finally:
_stop_mcp_loop()
return result
def shutdown_mcp_servers():
"""Close all MCP server connections and stop the background loop.

View file

@ -23,11 +23,13 @@ Design:
- Frozen snapshot pattern: system prompt is stable, tool responses show live state
"""
import fcntl
import json
import logging
import os
import re
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Any, List, Optional
@ -120,14 +122,43 @@ class MemoryStore:
"user": self._render_block("user", self.user_entries),
}
@staticmethod
@contextmanager
def _file_lock(path: Path):
"""Acquire an exclusive file lock for read-modify-write safety.
Uses a separate .lock file so the memory file itself can still be
atomically replaced via os.replace().
"""
lock_path = path.with_suffix(path.suffix + ".lock")
lock_path.parent.mkdir(parents=True, exist_ok=True)
fd = open(lock_path, "w")
try:
fcntl.flock(fd, fcntl.LOCK_EX)
yield
finally:
fcntl.flock(fd, fcntl.LOCK_UN)
fd.close()
@staticmethod
def _path_for(target: str) -> Path:
if target == "user":
return MEMORY_DIR / "USER.md"
return MEMORY_DIR / "MEMORY.md"
def _reload_target(self, target: str):
"""Re-read entries from disk into in-memory state.
Called under file lock to get the latest state before mutating.
"""
fresh = self._read_file(self._path_for(target))
fresh = list(dict.fromkeys(fresh)) # deduplicate
self._set_entries(target, fresh)
def save_to_disk(self, target: str):
"""Persist entries to the appropriate file. Called after every mutation."""
MEMORY_DIR.mkdir(parents=True, exist_ok=True)
if target == "memory":
self._write_file(MEMORY_DIR / "MEMORY.md", self.memory_entries)
elif target == "user":
self._write_file(MEMORY_DIR / "USER.md", self.user_entries)
self._write_file(self._path_for(target), self._entries_for(target))
def _entries_for(self, target: str) -> List[str]:
if target == "user":
@ -162,33 +193,37 @@ class MemoryStore:
if scan_error:
return {"success": False, "error": scan_error}
entries = self._entries_for(target)
limit = self._char_limit(target)
with self._file_lock(self._path_for(target)):
# Re-read from disk under lock to pick up writes from other sessions
self._reload_target(target)
# Reject exact duplicates
if content in entries:
return self._success_response(target, "Entry already exists (no duplicate added).")
entries = self._entries_for(target)
limit = self._char_limit(target)
# Calculate what the new total would be
new_entries = entries + [content]
new_total = len(ENTRY_DELIMITER.join(new_entries))
# Reject exact duplicates
if content in entries:
return self._success_response(target, "Entry already exists (no duplicate added).")
if new_total > limit:
current = self._char_count(target)
return {
"success": False,
"error": (
f"Memory at {current:,}/{limit:,} chars. "
f"Adding this entry ({len(content)} chars) would exceed the limit. "
f"Replace or remove existing entries first."
),
"current_entries": entries,
"usage": f"{current:,}/{limit:,}",
}
# Calculate what the new total would be
new_entries = entries + [content]
new_total = len(ENTRY_DELIMITER.join(new_entries))
entries.append(content)
self._set_entries(target, entries)
self.save_to_disk(target)
if new_total > limit:
current = self._char_count(target)
return {
"success": False,
"error": (
f"Memory at {current:,}/{limit:,} chars. "
f"Adding this entry ({len(content)} chars) would exceed the limit. "
f"Replace or remove existing entries first."
),
"current_entries": entries,
"usage": f"{current:,}/{limit:,}",
}
entries.append(content)
self._set_entries(target, entries)
self.save_to_disk(target)
return self._success_response(target, "Entry added.")
@ -206,44 +241,47 @@ class MemoryStore:
if scan_error:
return {"success": False, "error": scan_error}
entries = self._entries_for(target)
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
with self._file_lock(self._path_for(target)):
self._reload_target(target)
if len(matches) == 0:
return {"success": False, "error": f"No entry matched '{old_text}'."}
entries = self._entries_for(target)
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
if len(matches) > 1:
# If all matches are identical (exact duplicates), operate on the first one
unique_texts = set(e for _, e in matches)
if len(unique_texts) > 1:
previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches]
if len(matches) == 0:
return {"success": False, "error": f"No entry matched '{old_text}'."}
if len(matches) > 1:
# If all matches are identical (exact duplicates), operate on the first one
unique_texts = set(e for _, e in matches)
if len(unique_texts) > 1:
previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches]
return {
"success": False,
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
"matches": previews,
}
# All identical -- safe to replace just the first
idx = matches[0][0]
limit = self._char_limit(target)
# Check that replacement doesn't blow the budget
test_entries = entries.copy()
test_entries[idx] = new_content
new_total = len(ENTRY_DELIMITER.join(test_entries))
if new_total > limit:
return {
"success": False,
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
"matches": previews,
"error": (
f"Replacement would put memory at {new_total:,}/{limit:,} chars. "
f"Shorten the new content or remove other entries first."
),
}
# All identical -- safe to replace just the first
idx = matches[0][0]
limit = self._char_limit(target)
# Check that replacement doesn't blow the budget
test_entries = entries.copy()
test_entries[idx] = new_content
new_total = len(ENTRY_DELIMITER.join(test_entries))
if new_total > limit:
return {
"success": False,
"error": (
f"Replacement would put memory at {new_total:,}/{limit:,} chars. "
f"Shorten the new content or remove other entries first."
),
}
entries[idx] = new_content
self._set_entries(target, entries)
self.save_to_disk(target)
entries[idx] = new_content
self._set_entries(target, entries)
self.save_to_disk(target)
return self._success_response(target, "Entry replaced.")
@ -253,28 +291,31 @@ class MemoryStore:
if not old_text:
return {"success": False, "error": "old_text cannot be empty."}
entries = self._entries_for(target)
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
with self._file_lock(self._path_for(target)):
self._reload_target(target)
if len(matches) == 0:
return {"success": False, "error": f"No entry matched '{old_text}'."}
entries = self._entries_for(target)
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
if len(matches) > 1:
# If all matches are identical (exact duplicates), remove the first one
unique_texts = set(e for _, e in matches)
if len(unique_texts) > 1:
previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches]
return {
"success": False,
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
"matches": previews,
}
# All identical -- safe to remove just the first
if len(matches) == 0:
return {"success": False, "error": f"No entry matched '{old_text}'."}
idx = matches[0][0]
entries.pop(idx)
self._set_entries(target, entries)
self.save_to_disk(target)
if len(matches) > 1:
# If all matches are identical (exact duplicates), remove the first one
unique_texts = set(e for _, e in matches)
if len(unique_texts) > 1:
previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches]
return {
"success": False,
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
"matches": previews,
}
# All identical -- safe to remove just the first
idx = matches[0][0]
entries.pop(idx)
self._set_entries(target, entries)
self.save_to_disk(target)
return self._success_response(target, "Entry removed.")

View file

@ -78,6 +78,11 @@ class ProcessSession:
output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS)
max_output_chars: int = MAX_OUTPUT_CHARS
detached: bool = False # True if recovered from crash (no pipe)
# Watcher/notification metadata (persisted for crash recovery)
watcher_platform: str = ""
watcher_chat_id: str = ""
watcher_thread_id: str = ""
watcher_interval: int = 0 # 0 = no watcher configured
_lock: threading.Lock = field(default_factory=threading.Lock)
_reader_thread: Optional[threading.Thread] = field(default=None, repr=False)
_pty: Any = field(default=None, repr=False) # ptyprocess handle (when use_pty=True)
@ -709,6 +714,10 @@ class ProcessRegistry:
"started_at": s.started_at,
"task_id": s.task_id,
"session_key": s.session_key,
"watcher_platform": s.watcher_platform,
"watcher_chat_id": s.watcher_chat_id,
"watcher_thread_id": s.watcher_thread_id,
"watcher_interval": s.watcher_interval,
})
# Atomic write to avoid corruption on crash
@ -755,12 +764,27 @@ class ProcessRegistry:
cwd=entry.get("cwd"),
started_at=entry.get("started_at", time.time()),
detached=True, # Can't read output, but can report status + kill
watcher_platform=entry.get("watcher_platform", ""),
watcher_chat_id=entry.get("watcher_chat_id", ""),
watcher_thread_id=entry.get("watcher_thread_id", ""),
watcher_interval=entry.get("watcher_interval", 0),
)
with self._lock:
self._running[session.id] = session
recovered += 1
logger.info("Recovered detached process: %s (pid=%d)", session.command[:60], pid)
# Re-enqueue watcher so gateway can resume notifications
if session.watcher_interval > 0:
self.pending_watchers.append({
"session_id": session.id,
"check_interval": session.watcher_interval,
"session_key": session.session_key,
"platform": session.watcher_platform,
"chat_id": session.watcher_chat_id,
"thread_id": session.watcher_thread_id,
})
# Clear the checkpoint (will be rewritten as processes finish)
try:
from utils import atomic_json_write

View file

@ -355,20 +355,31 @@ async def _send_telegram(token, chat_id, message, media_files=None, thread_id=No
"""Send via Telegram Bot API (one-shot, no polling needed).
Applies markdownMarkdownV2 formatting (same as the gateway adapter)
so that bold, links, and headers render correctly.
so that bold, links, and headers render correctly. If the message
already contains HTML tags, it is sent with ``parse_mode='HTML'``
instead, bypassing MarkdownV2 conversion.
"""
try:
from telegram import Bot
from telegram.constants import ParseMode
# Reuse the gateway adapter's format_message for markdown→MarkdownV2
try:
from gateway.platforms.telegram import TelegramAdapter, _escape_mdv2, _strip_mdv2
_adapter = TelegramAdapter.__new__(TelegramAdapter)
formatted = _adapter.format_message(message)
except Exception:
# Fallback: send as-is if formatting unavailable
# Auto-detect HTML tags — if present, skip MarkdownV2 and send as HTML.
# Inspired by github.com/ashaney — PR #1568.
_has_html = bool(re.search(r'<[a-zA-Z/][^>]*>', message))
if _has_html:
formatted = message
send_parse_mode = ParseMode.HTML
else:
# Reuse the gateway adapter's format_message for markdown→MarkdownV2
try:
from gateway.platforms.telegram import TelegramAdapter, _escape_mdv2, _strip_mdv2
_adapter = TelegramAdapter.__new__(TelegramAdapter)
formatted = _adapter.format_message(message)
except Exception:
# Fallback: send as-is if formatting unavailable
formatted = message
send_parse_mode = ParseMode.MARKDOWN_V2
bot = Bot(token=token)
int_chat_id = int(chat_id)
@ -384,16 +395,19 @@ async def _send_telegram(token, chat_id, message, media_files=None, thread_id=No
try:
last_msg = await bot.send_message(
chat_id=int_chat_id, text=formatted,
parse_mode=ParseMode.MARKDOWN_V2, **thread_kwargs
parse_mode=send_parse_mode, **thread_kwargs
)
except Exception as md_error:
# MarkdownV2 failed, fall back to plain text
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower():
logger.warning("MarkdownV2 parse failed in _send_telegram, falling back to plain text: %s", md_error)
try:
from gateway.platforms.telegram import _strip_mdv2
plain = _strip_mdv2(formatted)
except Exception:
# Parse failed, fall back to plain text
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower() or "html" in str(md_error).lower():
logger.warning("Parse mode %s failed in _send_telegram, falling back to plain text: %s", send_parse_mode, md_error)
if not _has_html:
try:
from gateway.platforms.telegram import _strip_mdv2
plain = _strip_mdv2(formatted)
except Exception:
plain = message
else:
plain = message
last_msg = await bot.send_message(
chat_id=int_chat_id, text=plain,
@ -565,50 +579,55 @@ async def _send_email(extra, chat_id, message):
return {"error": f"Email send failed: {e}"}
async def _send_sms(api_key, chat_id, message):
"""Send via Telnyx SMS REST API (one-shot, no persistent connection needed)."""
async def _send_sms(auth_token, chat_id, message):
"""Send a single SMS via Twilio REST API.
Uses HTTP Basic auth (Account SID : Auth Token) and form-encoded POST.
Chunking is handled by _send_to_platform() before this is called.
"""
try:
import aiohttp
except ImportError:
return {"error": "aiohttp not installed. Run: pip install aiohttp"}
import base64
account_sid = os.getenv("TWILIO_ACCOUNT_SID", "")
from_number = os.getenv("TWILIO_PHONE_NUMBER", "")
if not account_sid or not auth_token or not from_number:
return {"error": "SMS not configured (TWILIO_ACCOUNT_SID, TWILIO_AUTH_TOKEN, TWILIO_PHONE_NUMBER required)"}
# Strip markdown — SMS renders it as literal characters
message = re.sub(r"\*\*(.+?)\*\*", r"\1", message, flags=re.DOTALL)
message = re.sub(r"\*(.+?)\*", r"\1", message, flags=re.DOTALL)
message = re.sub(r"__(.+?)__", r"\1", message, flags=re.DOTALL)
message = re.sub(r"_(.+?)_", r"\1", message, flags=re.DOTALL)
message = re.sub(r"```[a-z]*\n?", "", message)
message = re.sub(r"`(.+?)`", r"\1", message)
message = re.sub(r"^#{1,6}\s+", "", message, flags=re.MULTILINE)
message = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", message)
message = re.sub(r"\n{3,}", "\n\n", message)
message = message.strip()
try:
from_number = os.getenv("TELNYX_FROM_NUMBERS", "").split(",")[0].strip()
if not from_number:
return {"error": "TELNYX_FROM_NUMBERS not configured"}
if not api_key:
api_key = os.getenv("TELNYX_API_KEY", "")
if not api_key:
return {"error": "TELNYX_API_KEY not configured"}
creds = f"{account_sid}:{auth_token}"
encoded = base64.b64encode(creds.encode("ascii")).decode("ascii")
url = f"https://api.twilio.com/2010-04-01/Accounts/{account_sid}/Messages.json"
headers = {"Authorization": f"Basic {encoded}"}
# Strip markdown for SMS
text = re.sub(r"\*\*(.+?)\*\*", r"\1", message, flags=re.DOTALL)
text = re.sub(r"\*(.+?)\*", r"\1", text, flags=re.DOTALL)
text = re.sub(r"```[a-z]*\n?", "", text)
text = re.sub(r"`(.+?)`", r"\1", text)
text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE)
text = text.strip()
# Chunk to 1600 chars
chunks = [text[i:i+1600] for i in range(0, len(text), 1600)] if len(text) > 1600 else [text]
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
message_ids = []
async with aiohttp.ClientSession() as session:
for chunk in chunks:
payload = {"from": from_number, "to": chat_id, "text": chunk}
async with session.post(
"https://api.telnyx.com/v2/messages",
json=payload,
headers=headers,
) as resp:
body = await resp.json()
if resp.status >= 400:
return {"error": f"Telnyx API error ({resp.status}): {body}"}
message_ids.append(body.get("data", {}).get("id", ""))
return {"success": True, "platform": "sms", "chat_id": chat_id, "message_ids": message_ids}
form_data = aiohttp.FormData()
form_data.add_field("From", from_number)
form_data.add_field("To", chat_id)
form_data.add_field("Body", message)
async with session.post(url, data=form_data, headers=headers) as resp:
body = await resp.json()
if resp.status >= 400:
error_msg = body.get("message", str(body))
return {"error": f"Twilio API error ({resp.status}): {error_msg}"}
msg_sid = body.get("sid", "")
return {"success": True, "platform": "sms", "chat_id": chat_id, "message_id": msg_sid}
except Exception as e:
return {"error": f"SMS send failed: {e}"}

View file

@ -1082,13 +1082,23 @@ def terminal_tool(
result_data["check_interval_note"] = (
f"Requested {check_interval}s raised to minimum 30s"
)
watcher_platform = os.getenv("HERMES_SESSION_PLATFORM", "")
watcher_chat_id = os.getenv("HERMES_SESSION_CHAT_ID", "")
watcher_thread_id = os.getenv("HERMES_SESSION_THREAD_ID", "")
# Store on session for checkpoint persistence
proc_session.watcher_platform = watcher_platform
proc_session.watcher_chat_id = watcher_chat_id
proc_session.watcher_thread_id = watcher_thread_id
proc_session.watcher_interval = effective_interval
process_registry.pending_watchers.append({
"session_id": proc_session.id,
"check_interval": effective_interval,
"session_key": session_key,
"platform": os.getenv("HERMES_SESSION_PLATFORM", ""),
"chat_id": os.getenv("HERMES_SESSION_CHAT_ID", ""),
"thread_id": os.getenv("HERMES_SESSION_THREAD_ID", ""),
"platform": watcher_platform,
"chat_id": watcher_chat_id,
"thread_id": watcher_thread_id,
})
return json.dumps(result_data, ensure_ascii=False)

View file

@ -3,16 +3,16 @@
Standalone Web Tools Module
This module provides generic web tools that work with multiple backend providers.
Currently uses Firecrawl as the backend, and the interface makes it easy to swap
providers without changing the function signatures.
Backend is selected during ``hermes tools`` setup (web.backend in config.yaml).
Available tools:
- web_search_tool: Search the web for information
- web_extract_tool: Extract content from specific web pages
- web_crawl_tool: Crawl websites with specific instructions
- web_crawl_tool: Crawl websites with specific instructions (Firecrawl only)
Backend compatibility:
- Firecrawl: https://docs.firecrawl.dev/introduction
- Firecrawl: https://docs.firecrawl.dev/introduction (search, extract, crawl)
- Parallel: https://docs.parallel.ai (search, extract)
LLM Processing:
- Uses OpenRouter API with Gemini 3 Flash Preview for intelligent content extraction
@ -46,12 +46,50 @@ import os
import re
import asyncio
from typing import List, Dict, Any, Optional
import httpx
from firecrawl import Firecrawl
from agent.auxiliary_client import async_call_llm
from tools.debug_helpers import DebugSession
from tools.website_policy import check_website_access
logger = logging.getLogger(__name__)
# ─── Backend Selection ────────────────────────────────────────────────────────
def _load_web_config() -> dict:
"""Load the ``web:`` section from ~/.hermes/config.yaml."""
try:
from hermes_cli.config import load_config
return load_config().get("web", {})
except (ImportError, Exception):
return {}
def _get_backend() -> str:
"""Determine which web backend to use.
Reads ``web.backend`` from config.yaml (set by ``hermes tools``).
Falls back to whichever API key is present for users who configured
keys manually without running setup.
"""
configured = _load_web_config().get("backend", "").lower().strip()
if configured in ("parallel", "firecrawl", "tavily"):
return configured
# Fallback for manual / legacy config — use whichever key is present.
has_firecrawl = bool(os.getenv("FIRECRAWL_API_KEY") or os.getenv("FIRECRAWL_API_URL"))
has_parallel = bool(os.getenv("PARALLEL_API_KEY"))
has_tavily = bool(os.getenv("TAVILY_API_KEY"))
if has_tavily and not has_firecrawl and not has_parallel:
return "tavily"
if has_parallel and not has_firecrawl:
return "parallel"
# Default to firecrawl (backward compat, or when both are set)
return "firecrawl"
# ─── Firecrawl Client ────────────────────────────────────────────────────────
_firecrawl_client = None
def _get_firecrawl_client():
@ -80,6 +118,129 @@ def _get_firecrawl_client():
_firecrawl_client = Firecrawl(**kwargs)
return _firecrawl_client
# ─── Parallel Client ─────────────────────────────────────────────────────────
_parallel_client = None
_async_parallel_client = None
def _get_parallel_client():
"""Get or create the Parallel sync client (lazy initialization).
Requires PARALLEL_API_KEY environment variable.
"""
from parallel import Parallel
global _parallel_client
if _parallel_client is None:
api_key = os.getenv("PARALLEL_API_KEY")
if not api_key:
raise ValueError(
"PARALLEL_API_KEY environment variable not set. "
"Get your API key at https://parallel.ai"
)
_parallel_client = Parallel(api_key=api_key)
return _parallel_client
def _get_async_parallel_client():
"""Get or create the Parallel async client (lazy initialization).
Requires PARALLEL_API_KEY environment variable.
"""
from parallel import AsyncParallel
global _async_parallel_client
if _async_parallel_client is None:
api_key = os.getenv("PARALLEL_API_KEY")
if not api_key:
raise ValueError(
"PARALLEL_API_KEY environment variable not set. "
"Get your API key at https://parallel.ai"
)
_async_parallel_client = AsyncParallel(api_key=api_key)
return _async_parallel_client
# ─── Tavily Client ───────────────────────────────────────────────────────────
_TAVILY_BASE_URL = "https://api.tavily.com"
def _tavily_request(endpoint: str, payload: dict) -> dict:
"""Send a POST request to the Tavily API.
Auth is provided via ``api_key`` in the JSON body (no header-based auth).
Raises ``ValueError`` if ``TAVILY_API_KEY`` is not set.
"""
api_key = os.getenv("TAVILY_API_KEY")
if not api_key:
raise ValueError(
"TAVILY_API_KEY environment variable not set. "
"Get your API key at https://app.tavily.com/home"
)
payload["api_key"] = api_key
url = f"{_TAVILY_BASE_URL}/{endpoint.lstrip('/')}"
logger.info("Tavily %s request to %s", endpoint, url)
response = httpx.post(url, json=payload, timeout=60)
response.raise_for_status()
return response.json()
def _normalize_tavily_search_results(response: dict) -> dict:
"""Normalize Tavily /search response to the standard web search format.
Tavily returns ``{results: [{title, url, content, score, ...}]}``.
We map to ``{success, data: {web: [{title, url, description, position}]}}``.
"""
web_results = []
for i, result in enumerate(response.get("results", [])):
web_results.append({
"title": result.get("title", ""),
"url": result.get("url", ""),
"description": result.get("content", ""),
"position": i + 1,
})
return {"success": True, "data": {"web": web_results}}
def _normalize_tavily_documents(response: dict, fallback_url: str = "") -> List[Dict[str, Any]]:
"""Normalize Tavily /extract or /crawl response to the standard document format.
Maps results to ``{url, title, content, raw_content, metadata}`` and
includes any ``failed_results`` / ``failed_urls`` as error entries.
"""
documents: List[Dict[str, Any]] = []
for result in response.get("results", []):
url = result.get("url", fallback_url)
raw = result.get("raw_content", "") or result.get("content", "")
documents.append({
"url": url,
"title": result.get("title", ""),
"content": raw,
"raw_content": raw,
"metadata": {"sourceURL": url, "title": result.get("title", "")},
})
# Handle failed results
for fail in response.get("failed_results", []):
documents.append({
"url": fail.get("url", fallback_url),
"title": "",
"content": "",
"raw_content": "",
"error": fail.get("error", "extraction failed"),
"metadata": {"sourceURL": fail.get("url", fallback_url)},
})
for fail_url in response.get("failed_urls", []):
url_str = fail_url if isinstance(fail_url, str) else str(fail_url)
documents.append({
"url": url_str,
"title": "",
"content": "",
"raw_content": "",
"error": "extraction failed",
"metadata": {"sourceURL": url_str},
})
return documents
DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION = 5000
# Allow per-task override via env var
@ -427,13 +588,89 @@ def clean_base64_images(text: str) -> str:
return cleaned_text
# ─── Parallel Search & Extract Helpers ────────────────────────────────────────
def _parallel_search(query: str, limit: int = 5) -> dict:
"""Search using the Parallel SDK and return results as a dict."""
from tools.interrupt import is_interrupted
if is_interrupted():
return {"error": "Interrupted", "success": False}
mode = os.getenv("PARALLEL_SEARCH_MODE", "agentic").lower().strip()
if mode not in ("fast", "one-shot", "agentic"):
mode = "agentic"
logger.info("Parallel search: '%s' (mode=%s, limit=%d)", query, mode, limit)
response = _get_parallel_client().beta.search(
search_queries=[query],
objective=query,
mode=mode,
max_results=min(limit, 20),
)
web_results = []
for i, result in enumerate(response.results or []):
excerpts = result.excerpts or []
web_results.append({
"url": result.url or "",
"title": result.title or "",
"description": " ".join(excerpts) if excerpts else "",
"position": i + 1,
})
return {"success": True, "data": {"web": web_results}}
async def _parallel_extract(urls: List[str]) -> List[Dict[str, Any]]:
"""Extract content from URLs using the Parallel async SDK.
Returns a list of result dicts matching the structure expected by the
LLM post-processing pipeline (url, title, content, metadata).
"""
from tools.interrupt import is_interrupted
if is_interrupted():
return [{"url": u, "error": "Interrupted", "title": ""} for u in urls]
logger.info("Parallel extract: %d URL(s)", len(urls))
response = await _get_async_parallel_client().beta.extract(
urls=urls,
full_content=True,
)
results = []
for result in response.results or []:
content = result.full_content or ""
if not content:
content = "\n\n".join(result.excerpts or [])
url = result.url or ""
title = result.title or ""
results.append({
"url": url,
"title": title,
"content": content,
"raw_content": content,
"metadata": {"sourceURL": url, "title": title},
})
for error in response.errors or []:
results.append({
"url": error.url or "",
"title": "",
"content": "",
"error": error.content or error.error_type or "extraction failed",
"metadata": {"sourceURL": error.url or ""},
})
return results
def web_search_tool(query: str, limit: int = 5) -> str:
"""
Search the web for information using available search API backend.
This function provides a generic interface for web search that can work
with multiple backends. Currently uses Firecrawl.
with multiple backends (Parallel or Firecrawl).
Note: This function returns search result metadata only (URLs, titles, descriptions).
Use web_extract_tool to get full content from specific URLs.
@ -477,17 +714,44 @@ def web_search_tool(query: str, limit: int = 5) -> str:
if is_interrupted():
return json.dumps({"error": "Interrupted", "success": False})
# Dispatch to the configured backend
backend = _get_backend()
if backend == "parallel":
response_data = _parallel_search(query, limit)
debug_call_data["results_count"] = len(response_data.get("data", {}).get("web", []))
result_json = json.dumps(response_data, indent=2, ensure_ascii=False)
debug_call_data["final_response_size"] = len(result_json)
_debug.log_call("web_search_tool", debug_call_data)
_debug.save()
return result_json
if backend == "tavily":
logger.info("Tavily search: '%s' (limit: %d)", query, limit)
raw = _tavily_request("search", {
"query": query,
"max_results": min(limit, 20),
"include_raw_content": False,
"include_images": False,
})
response_data = _normalize_tavily_search_results(raw)
debug_call_data["results_count"] = len(response_data.get("data", {}).get("web", []))
result_json = json.dumps(response_data, indent=2, ensure_ascii=False)
debug_call_data["final_response_size"] = len(result_json)
_debug.log_call("web_search_tool", debug_call_data)
_debug.save()
return result_json
logger.info("Searching the web for: '%s' (limit: %d)", query, limit)
response = _get_firecrawl_client().search(
query=query,
limit=limit
)
# The response is a SearchData object with web, news, and images attributes
# When not scraping, the results are directly in these attributes
web_results = []
# Check if response has web attribute (SearchData object)
if hasattr(response, 'web'):
# Response is a SearchData object with web attribute
@ -595,100 +859,137 @@ async def web_extract_tool(
try:
logger.info("Extracting content from %d URL(s)", len(urls))
# Determine requested formats for Firecrawl v2
formats: List[str] = []
if format == "markdown":
formats = ["markdown"]
elif format == "html":
formats = ["html"]
else:
# Default: request markdown for LLM-readiness and include html as backup
formats = ["markdown", "html"]
# Always use individual scraping for simplicity and reliability
# Batch scraping adds complexity without much benefit for small numbers of URLs
results: List[Dict[str, Any]] = []
from tools.interrupt import is_interrupted as _is_interrupted
for url in urls:
if _is_interrupted():
results.append({"url": url, "error": "Interrupted", "title": ""})
continue
try:
logger.info("Scraping: %s", url)
scrape_result = _get_firecrawl_client().scrape(
url=url,
formats=formats
)
# Process the result - properly handle object serialization
metadata = {}
title = ""
content_markdown = None
content_html = None
# Extract data from the scrape result
if hasattr(scrape_result, 'model_dump'):
# Pydantic model - use model_dump to get dict
result_dict = scrape_result.model_dump()
content_markdown = result_dict.get('markdown')
content_html = result_dict.get('html')
metadata = result_dict.get('metadata', {})
elif hasattr(scrape_result, '__dict__'):
# Regular object with attributes
content_markdown = getattr(scrape_result, 'markdown', None)
content_html = getattr(scrape_result, 'html', None)
# Handle metadata - convert to dict if it's an object
metadata_obj = getattr(scrape_result, 'metadata', {})
if hasattr(metadata_obj, 'model_dump'):
metadata = metadata_obj.model_dump()
elif hasattr(metadata_obj, '__dict__'):
metadata = metadata_obj.__dict__
elif isinstance(metadata_obj, dict):
metadata = metadata_obj
else:
metadata = {}
elif isinstance(scrape_result, dict):
# Already a dictionary
content_markdown = scrape_result.get('markdown')
content_html = scrape_result.get('html')
metadata = scrape_result.get('metadata', {})
# Ensure metadata is a dict (not an object)
if not isinstance(metadata, dict):
if hasattr(metadata, 'model_dump'):
metadata = metadata.model_dump()
elif hasattr(metadata, '__dict__'):
metadata = metadata.__dict__
else:
metadata = {}
# Get title from metadata
title = metadata.get("title", "")
# Choose content based on requested format
chosen_content = content_markdown if (format == "markdown" or (format is None and content_markdown)) else content_html or content_markdown or ""
results.append({
"url": metadata.get("sourceURL", url),
"title": title,
"content": chosen_content,
"raw_content": chosen_content,
"metadata": metadata # Now guaranteed to be a dict
})
except Exception as scrape_err:
logger.debug("Scrape failed for %s: %s", url, scrape_err)
results.append({
"url": url,
"title": "",
"content": "",
"raw_content": "",
"error": str(scrape_err)
})
# Dispatch to the configured backend
backend = _get_backend()
if backend == "parallel":
results = await _parallel_extract(urls)
elif backend == "tavily":
logger.info("Tavily extract: %d URL(s)", len(urls))
raw = _tavily_request("extract", {
"urls": urls,
"include_images": False,
})
results = _normalize_tavily_documents(raw, fallback_url=urls[0] if urls else "")
else:
# ── Firecrawl extraction ──
# Determine requested formats for Firecrawl v2
formats: List[str] = []
if format == "markdown":
formats = ["markdown"]
elif format == "html":
formats = ["html"]
else:
# Default: request markdown for LLM-readiness and include html as backup
formats = ["markdown", "html"]
# Always use individual scraping for simplicity and reliability
# Batch scraping adds complexity without much benefit for small numbers of URLs
results: List[Dict[str, Any]] = []
from tools.interrupt import is_interrupted as _is_interrupted
for url in urls:
if _is_interrupted():
results.append({"url": url, "error": "Interrupted", "title": ""})
continue
# Website policy check — block before fetching
blocked = check_website_access(url)
if blocked:
logger.info("Blocked web_extract for %s by rule %s", blocked["host"], blocked["rule"])
results.append({
"url": url, "title": "", "content": "",
"error": blocked["message"],
"blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]},
})
continue
try:
logger.info("Scraping: %s", url)
scrape_result = _get_firecrawl_client().scrape(
url=url,
formats=formats
)
# Process the result - properly handle object serialization
metadata = {}
title = ""
content_markdown = None
content_html = None
# Extract data from the scrape result
if hasattr(scrape_result, 'model_dump'):
# Pydantic model - use model_dump to get dict
result_dict = scrape_result.model_dump()
content_markdown = result_dict.get('markdown')
content_html = result_dict.get('html')
metadata = result_dict.get('metadata', {})
elif hasattr(scrape_result, '__dict__'):
# Regular object with attributes
content_markdown = getattr(scrape_result, 'markdown', None)
content_html = getattr(scrape_result, 'html', None)
# Handle metadata - convert to dict if it's an object
metadata_obj = getattr(scrape_result, 'metadata', {})
if hasattr(metadata_obj, 'model_dump'):
metadata = metadata_obj.model_dump()
elif hasattr(metadata_obj, '__dict__'):
metadata = metadata_obj.__dict__
elif isinstance(metadata_obj, dict):
metadata = metadata_obj
else:
metadata = {}
elif isinstance(scrape_result, dict):
# Already a dictionary
content_markdown = scrape_result.get('markdown')
content_html = scrape_result.get('html')
metadata = scrape_result.get('metadata', {})
# Ensure metadata is a dict (not an object)
if not isinstance(metadata, dict):
if hasattr(metadata, 'model_dump'):
metadata = metadata.model_dump()
elif hasattr(metadata, '__dict__'):
metadata = metadata.__dict__
else:
metadata = {}
# Get title from metadata
title = metadata.get("title", "")
# Re-check final URL after redirect
final_url = metadata.get("sourceURL", url)
final_blocked = check_website_access(final_url)
if final_blocked:
logger.info("Blocked redirected web_extract for %s by rule %s", final_blocked["host"], final_blocked["rule"])
results.append({
"url": final_url, "title": title, "content": "", "raw_content": "",
"error": final_blocked["message"],
"blocked_by_policy": {"host": final_blocked["host"], "rule": final_blocked["rule"], "source": final_blocked["source"]},
})
continue
# Choose content based on requested format
chosen_content = content_markdown if (format == "markdown" or (format is None and content_markdown)) else content_html or content_markdown or ""
results.append({
"url": final_url,
"title": title,
"content": chosen_content,
"raw_content": chosen_content,
"metadata": metadata # Now guaranteed to be a dict
})
except Exception as scrape_err:
logger.debug("Scrape failed for %s: %s", url, scrape_err)
results.append({
"url": url,
"title": "",
"content": "",
"raw_content": "",
"error": str(scrape_err)
})
response = {"results": results}
@ -778,6 +1079,7 @@ async def web_extract_tool(
"title": r.get("title", ""),
"content": r.get("content", ""),
"error": r.get("error"),
**({ "blocked_by_policy": r["blocked_by_policy"]} if "blocked_by_policy" in r else {}),
}
for r in response.get("results", [])
]
@ -862,6 +1164,91 @@ async def web_crawl_tool(
}
try:
backend = _get_backend()
# Tavily supports crawl via its /crawl endpoint
if backend == "tavily":
# Ensure URL has protocol
if not url.startswith(('http://', 'https://')):
url = f'https://{url}'
# Website policy check
blocked = check_website_access(url)
if blocked:
logger.info("Blocked web_crawl for %s by rule %s", blocked["host"], blocked["rule"])
return json.dumps({"results": [{"url": url, "title": "", "content": "", "error": blocked["message"],
"blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]}}]}, ensure_ascii=False)
from tools.interrupt import is_interrupted as _is_int
if _is_int():
return json.dumps({"error": "Interrupted", "success": False})
logger.info("Tavily crawl: %s", url)
payload: Dict[str, Any] = {
"url": url,
"limit": 20,
"extract_depth": depth,
}
if instructions:
payload["instructions"] = instructions
raw = _tavily_request("crawl", payload)
results = _normalize_tavily_documents(raw, fallback_url=url)
response = {"results": results}
# Fall through to the shared LLM processing and trimming below
# (skip the Firecrawl-specific crawl logic)
pages_crawled = len(response.get('results', []))
logger.info("Crawled %d pages", pages_crawled)
debug_call_data["pages_crawled"] = pages_crawled
debug_call_data["original_response_size"] = len(json.dumps(response))
# Process each result with LLM if enabled
if use_llm_processing:
logger.info("Processing crawled content with LLM (parallel)...")
debug_call_data["processing_applied"].append("llm_processing")
async def _process_tavily_crawl(result):
page_url = result.get('url', 'Unknown URL')
title = result.get('title', '')
content = result.get('content', '')
if not content:
return result, None, "no_content"
original_size = len(content)
processed = await process_content_with_llm(content, page_url, title, model, min_length)
if processed:
result['raw_content'] = content
result['content'] = processed
metrics = {"url": page_url, "original_size": original_size, "processed_size": len(processed),
"compression_ratio": len(processed) / original_size if original_size else 1.0, "model_used": model}
return result, metrics, "processed"
metrics = {"url": page_url, "original_size": original_size, "processed_size": original_size,
"compression_ratio": 1.0, "model_used": None, "reason": "content_too_short"}
return result, metrics, "too_short"
tasks = [_process_tavily_crawl(r) for r in response.get('results', [])]
processed_results = await asyncio.gather(*tasks)
for result, metrics, status in processed_results:
if status == "processed":
debug_call_data["compression_metrics"].append(metrics)
debug_call_data["pages_processed_with_llm"] += 1
trimmed_results = [{"url": r.get("url", ""), "title": r.get("title", ""), "content": r.get("content", ""), "error": r.get("error"),
**({ "blocked_by_policy": r["blocked_by_policy"]} if "blocked_by_policy" in r else {})} for r in response.get("results", [])]
result_json = json.dumps({"results": trimmed_results}, indent=2, ensure_ascii=False)
cleaned_result = clean_base64_images(result_json)
debug_call_data["final_response_size"] = len(cleaned_result)
_debug.log_call("web_crawl_tool", debug_call_data)
_debug.save()
return cleaned_result
# web_crawl requires Firecrawl — Parallel has no crawl API
if not (os.getenv("FIRECRAWL_API_KEY") or os.getenv("FIRECRAWL_API_URL")):
return json.dumps({
"error": "web_crawl requires Firecrawl. Set FIRECRAWL_API_KEY, "
"or use web_search + web_extract instead.",
"success": False,
}, ensure_ascii=False)
# Ensure URL has protocol
if not url.startswith(('http://', 'https://')):
url = f'https://{url}'
@ -870,6 +1257,13 @@ async def web_crawl_tool(
instructions_text = f" with instructions: '{instructions}'" if instructions else ""
logger.info("Crawling %s%s", url, instructions_text)
# Website policy check — block before crawling
blocked = check_website_access(url)
if blocked:
logger.info("Blocked web_crawl for %s by rule %s", blocked["host"], blocked["rule"])
return json.dumps({"results": [{"url": url, "title": "", "content": "", "error": blocked["message"],
"blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]}}]}, ensure_ascii=False)
# Use Firecrawl's v2 crawl functionality
# Docs: https://docs.firecrawl.dev/features/crawl
# The crawl() method automatically waits for completion and returns all data
@ -975,6 +1369,17 @@ async def web_crawl_tool(
page_url = metadata.get("sourceURL", metadata.get("url", "Unknown URL"))
title = metadata.get("title", "")
# Re-check crawled page URL against policy
page_blocked = check_website_access(page_url)
if page_blocked:
logger.info("Blocked crawled page %s by rule %s", page_blocked["host"], page_blocked["rule"])
pages.append({
"url": page_url, "title": title, "content": "", "raw_content": "",
"error": page_blocked["message"],
"blocked_by_policy": {"host": page_blocked["host"], "rule": page_blocked["rule"], "source": page_blocked["source"]},
})
continue
# Choose content (prefer markdown)
content = content_markdown or content_html or ""
@ -1070,9 +1475,11 @@ async def web_crawl_tool(
# Trim output to minimal fields per entry: title, content, error
trimmed_results = [
{
"url": r.get("url", ""),
"title": r.get("title", ""),
"content": r.get("content", ""),
"error": r.get("error")
"error": r.get("error"),
**({ "blocked_by_policy": r["blocked_by_policy"]} if "blocked_by_policy" in r else {}),
}
for r in response.get("results", [])
]
@ -1106,13 +1513,23 @@ async def web_crawl_tool(
def check_firecrawl_api_key() -> bool:
"""
Check if the Firecrawl API key is available in environment variables.
Returns:
bool: True if API key is set, False otherwise
"""
return bool(os.getenv("FIRECRAWL_API_KEY"))
def check_web_api_key() -> bool:
"""Check if any web backend API key is available (Parallel, Firecrawl, or Tavily)."""
return bool(
os.getenv("PARALLEL_API_KEY")
or os.getenv("FIRECRAWL_API_KEY")
or os.getenv("FIRECRAWL_API_URL")
or os.getenv("TAVILY_API_KEY")
)
def check_auxiliary_model() -> bool:
"""Check if an auxiliary text model is available for LLM content processing."""
try:
@ -1139,26 +1556,32 @@ if __name__ == "__main__":
print("=" * 40)
# Check if API keys are available
firecrawl_available = check_firecrawl_api_key()
web_available = check_web_api_key()
nous_available = check_auxiliary_model()
if not firecrawl_available:
print("❌ FIRECRAWL_API_KEY environment variable not set")
print("Please set your API key: export FIRECRAWL_API_KEY='your-key-here'")
print("Get API key at: https://firecrawl.dev/")
if web_available:
backend = _get_backend()
print(f"✅ Web backend: {backend}")
if backend == "parallel":
print(" Using Parallel API (https://parallel.ai)")
elif backend == "tavily":
print(" Using Tavily API (https://tavily.com)")
else:
print(" Using Firecrawl API (https://firecrawl.dev)")
else:
print("✅ Firecrawl API key found")
print("❌ No web search backend configured")
print("Set PARALLEL_API_KEY, TAVILY_API_KEY, or FIRECRAWL_API_KEY")
if not nous_available:
print("❌ No auxiliary model available for LLM content processing")
print("Set OPENROUTER_API_KEY, configure Nous Portal, or set OPENAI_BASE_URL + OPENAI_API_KEY")
print("⚠️ Without an auxiliary model, LLM content processing will be disabled")
else:
print(f"✅ Auxiliary model available: {DEFAULT_SUMMARIZER_MODEL}")
if not firecrawl_available:
if not web_available:
exit(1)
print("🛠️ Web tools ready for use!")
if nous_available:
@ -1256,8 +1679,8 @@ registry.register(
toolset="web",
schema=WEB_SEARCH_SCHEMA,
handler=lambda args, **kw: web_search_tool(args.get("query", ""), limit=5),
check_fn=check_firecrawl_api_key,
requires_env=["FIRECRAWL_API_KEY"],
check_fn=check_web_api_key,
requires_env=["PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "TAVILY_API_KEY"],
emoji="🔍",
)
registry.register(
@ -1266,8 +1689,8 @@ registry.register(
schema=WEB_EXTRACT_SCHEMA,
handler=lambda args, **kw: web_extract_tool(
args.get("urls", [])[:5] if isinstance(args.get("urls"), list) else [], "markdown"),
check_fn=check_firecrawl_api_key,
requires_env=["FIRECRAWL_API_KEY"],
check_fn=check_web_api_key,
requires_env=["PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "TAVILY_API_KEY"],
is_async=True,
emoji="📄",
)

285
tools/website_policy.py Normal file
View file

@ -0,0 +1,285 @@
"""Website access policy helpers for URL-capable tools.
This module loads a user-managed website blocklist from ~/.hermes/config.yaml
and optional shared list files. It is intentionally lightweight so web/browser
tools can enforce URL policy without pulling in the heavier CLI config stack.
Policy is cached in memory with a short TTL so config changes take effect
quickly without re-reading the file on every URL check.
"""
from __future__ import annotations
import fnmatch
import logging
import os
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse
logger = logging.getLogger(__name__)
_DEFAULT_WEBSITE_BLOCKLIST = {
"enabled": False,
"domains": [],
"shared_files": [],
}
# Cache: parsed policy + timestamp. Avoids re-reading config.yaml on every
# URL check (a web_crawl with 50 pages would otherwise mean 51 YAML parses).
_CACHE_TTL_SECONDS = 30.0
_cache_lock = threading.Lock()
_cached_policy: Optional[Dict[str, Any]] = None
_cached_policy_path: Optional[str] = None
_cached_policy_time: float = 0.0
def _get_hermes_home() -> Path:
return Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
def _get_default_config_path() -> Path:
return _get_hermes_home() / "config.yaml"
class WebsitePolicyError(Exception):
"""Raised when a website policy file is malformed."""
def _normalize_host(host: str) -> str:
return (host or "").strip().lower().rstrip(".")
def _normalize_rule(rule: Any) -> Optional[str]:
if not isinstance(rule, str):
return None
value = rule.strip().lower()
if not value or value.startswith("#"):
return None
if "://" in value:
parsed = urlparse(value)
value = parsed.netloc or parsed.path
value = value.split("/", 1)[0].strip().rstrip(".")
if value.startswith("www."):
value = value[4:]
return value or None
def _iter_blocklist_file_rules(path: Path) -> List[str]:
"""Load rules from a shared blocklist file.
Missing or unreadable files log a warning and return an empty list
rather than raising a bad file path should not disable all web tools.
"""
try:
raw = path.read_text(encoding="utf-8")
except FileNotFoundError:
logger.warning("Shared blocklist file not found (skipping): %s", path)
return []
except (OSError, UnicodeDecodeError) as exc:
logger.warning("Failed to read shared blocklist file %s (skipping): %s", path, exc)
return []
rules: List[str] = []
for line in raw.splitlines():
stripped = line.strip()
if not stripped or stripped.startswith("#"):
continue
normalized = _normalize_rule(stripped)
if normalized:
rules.append(normalized)
return rules
def _load_policy_config(config_path: Optional[Path] = None) -> Dict[str, Any]:
config_path = config_path or _get_default_config_path()
if not config_path.exists():
return dict(_DEFAULT_WEBSITE_BLOCKLIST)
try:
import yaml
except ImportError:
logger.debug("PyYAML not installed — website blocklist disabled")
return dict(_DEFAULT_WEBSITE_BLOCKLIST)
try:
with open(config_path, encoding="utf-8") as f:
config = yaml.safe_load(f) or {}
except yaml.YAMLError as exc:
raise WebsitePolicyError(f"Invalid config YAML at {config_path}: {exc}") from exc
except OSError as exc:
raise WebsitePolicyError(f"Failed to read config file {config_path}: {exc}") from exc
if not isinstance(config, dict):
raise WebsitePolicyError("config root must be a mapping")
security = config.get("security", {})
if security is None:
security = {}
if not isinstance(security, dict):
raise WebsitePolicyError("security must be a mapping")
website_blocklist = security.get("website_blocklist", {})
if website_blocklist is None:
website_blocklist = {}
if not isinstance(website_blocklist, dict):
raise WebsitePolicyError("security.website_blocklist must be a mapping")
policy = dict(_DEFAULT_WEBSITE_BLOCKLIST)
policy.update(website_blocklist)
return policy
def load_website_blocklist(config_path: Optional[Path] = None) -> Dict[str, Any]:
"""Load and return the parsed website blocklist policy.
Results are cached for ``_CACHE_TTL_SECONDS`` to avoid re-reading
config.yaml on every URL check. Pass an explicit ``config_path``
to bypass the cache (used by tests).
"""
global _cached_policy, _cached_policy_path, _cached_policy_time
resolved_path = str(config_path) if config_path else "__default__"
now = time.monotonic()
# Return cached policy if still fresh and same path
if config_path is None:
with _cache_lock:
if (
_cached_policy is not None
and _cached_policy_path == resolved_path
and (now - _cached_policy_time) < _CACHE_TTL_SECONDS
):
return _cached_policy
config_path = config_path or _get_default_config_path()
policy = _load_policy_config(config_path)
raw_domains = policy.get("domains", []) or []
if not isinstance(raw_domains, list):
raise WebsitePolicyError("security.website_blocklist.domains must be a list")
raw_shared_files = policy.get("shared_files", []) or []
if not isinstance(raw_shared_files, list):
raise WebsitePolicyError("security.website_blocklist.shared_files must be a list")
enabled = policy.get("enabled", True)
if not isinstance(enabled, bool):
raise WebsitePolicyError("security.website_blocklist.enabled must be a boolean")
rules: List[Dict[str, str]] = []
seen: set[Tuple[str, str]] = set()
for raw_rule in raw_domains:
normalized = _normalize_rule(raw_rule)
if normalized and ("config", normalized) not in seen:
rules.append({"pattern": normalized, "source": "config"})
seen.add(("config", normalized))
for shared_file in raw_shared_files:
if not isinstance(shared_file, str) or not shared_file.strip():
continue
path = Path(shared_file).expanduser()
if not path.is_absolute():
path = (_get_hermes_home() / path).resolve()
for normalized in _iter_blocklist_file_rules(path):
key = (str(path), normalized)
if key in seen:
continue
rules.append({"pattern": normalized, "source": str(path)})
seen.add(key)
result = {"enabled": enabled, "rules": rules}
# Cache the result (only for the default path — explicit paths are tests)
if config_path == _get_default_config_path():
with _cache_lock:
_cached_policy = result
_cached_policy_path = "__default__"
_cached_policy_time = now
return result
def invalidate_cache() -> None:
"""Force the next ``check_website_access`` call to re-read config."""
global _cached_policy
with _cache_lock:
_cached_policy = None
def _match_host_against_rule(host: str, pattern: str) -> bool:
if not host or not pattern:
return False
if pattern.startswith("*."):
return fnmatch.fnmatch(host, pattern)
return host == pattern or host.endswith(f".{pattern}")
def _extract_host_from_urlish(url: str) -> str:
parsed = urlparse(url)
host = _normalize_host(parsed.hostname or parsed.netloc)
if host:
return host
if "://" not in url:
schemeless = urlparse(f"//{url}")
host = _normalize_host(schemeless.hostname or schemeless.netloc)
if host:
return host
return ""
def check_website_access(url: str, config_path: Optional[Path] = None) -> Optional[Dict[str, str]]:
"""Check whether a URL is allowed by the website blocklist policy.
Returns ``None`` if access is allowed, or a dict with block metadata
(``host``, ``rule``, ``source``, ``message``) if blocked.
Never raises on policy errors logs a warning and returns ``None``
(fail-open) so a config typo doesn't break all web tools. Pass
``config_path`` explicitly (tests) to get strict error propagation.
"""
# Fast path: if no explicit config_path and the cached policy is disabled
# or empty, skip all work (no YAML read, no host extraction).
if config_path is None:
with _cache_lock:
if _cached_policy is not None and not _cached_policy.get("enabled"):
return None
host = _extract_host_from_urlish(url)
if not host:
return None
try:
policy = load_website_blocklist(config_path)
except WebsitePolicyError as exc:
if config_path is not None:
raise # Tests pass explicit paths — let errors propagate
logger.warning("Website policy config error (failing open): %s", exc)
return None
except Exception as exc:
logger.warning("Unexpected error loading website policy (failing open): %s", exc)
return None
if not policy.get("enabled"):
return None
for rule in policy.get("rules", []):
pattern = rule.get("pattern", "")
if _match_host_against_rule(host, pattern):
logger.info("Blocked URL %s — matched rule '%s' from %s",
url, pattern, rule.get("source", "config"))
return {
"url": url,
"host": host,
"rule": pattern,
"source": rule.get("source", "config"),
"message": (
f"Blocked by website policy: '{host}' matched rule '{pattern}'"
f" from {rule.get('source', 'config')}"
),
}
return None

View file

@ -130,6 +130,12 @@ TOOLSETS = {
"includes": []
},
"messaging": {
"description": "Cross-platform messaging: send messages to Telegram, Discord, Slack, SMS, etc.",
"tools": ["send_message"],
"includes": []
},
"rl": {
"description": "RL training tools for running reinforcement learning on Tinker-Atropos",
"tools": [
@ -293,7 +299,7 @@ TOOLSETS = {
},
"hermes-sms": {
"description": "SMS bot toolset - interact with Hermes via SMS (Telnyx)",
"description": "SMS bot toolset - interact with Hermes via SMS (Twilio)",
"tools": _HERMES_CORE_TOOLS,
"includes": []
},

View file

@ -49,6 +49,9 @@ hermes setup # Or configure everything at once
| **Kimi / Moonshot** | Moonshot-hosted coding and chat models | Set `KIMI_API_KEY` |
| **MiniMax** | International MiniMax endpoint | Set `MINIMAX_API_KEY` |
| **MiniMax China** | China-region MiniMax endpoint | Set `MINIMAX_CN_API_KEY` |
| **Alibaba Cloud** | Qwen models via DashScope | Set `DASHSCOPE_API_KEY` |
| **Kilo Code** | KiloCode-hosted models | Set `KILOCODE_API_KEY` |
| **Vercel AI Gateway** | Vercel AI Gateway routing | Set `AI_GATEWAY_API_KEY` |
| **Custom Endpoint** | VLLM, SGLang, or any OpenAI-compatible API | Set base URL + API key |
:::tip

View file

@ -32,6 +32,8 @@ All variables go in `~/.hermes/.env`. You can also set them with `hermes config
| `KILOCODE_BASE_URL` | Override Kilo Code base URL (default: `https://api.kilo.ai/api/gateway`) |
| `ANTHROPIC_API_KEY` | Anthropic Console API key ([console.anthropic.com](https://console.anthropic.com/)) |
| `ANTHROPIC_TOKEN` | Manual or legacy Anthropic OAuth/setup-token override |
| `DASHSCOPE_API_KEY` | Alibaba Cloud DashScope API key for Qwen models ([modelstudio.console.alibabacloud.com](https://modelstudio.console.alibabacloud.com/)) |
| `DASHSCOPE_BASE_URL` | Custom DashScope base URL (default: international endpoint) |
| `CLAUDE_CODE_OAUTH_TOKEN` | Explicit Claude Code token override if you export one manually |
| `HERMES_MODEL` | Preferred model name (checked before `LLM_MODEL`, used by gateway) |
| `LLM_MODEL` | Default model name (fallback when not set in config.yaml) |
@ -46,7 +48,7 @@ For native Anthropic auth, Hermes prefers Claude Code's own credential files whe
| Variable | Description |
|----------|-------------|
| `HERMES_INFERENCE_PROVIDER` | Override provider selection: `auto`, `openrouter`, `nous`, `openai-codex`, `anthropic`, `zai`, `kimi-coding`, `minimax`, `minimax-cn`, `kilocode` (default: `auto`) |
| `HERMES_INFERENCE_PROVIDER` | Override provider selection: `auto`, `openrouter`, `nous`, `openai-codex`, `anthropic`, `zai`, `kimi-coding`, `minimax`, `minimax-cn`, `kilocode`, `alibaba` (default: `auto`) |
| `HERMES_PORTAL_BASE_URL` | Override Nous Portal URL (for development/testing) |
| `NOUS_INFERENCE_BASE_URL` | Override Nous inference API URL |
| `HERMES_NOUS_MIN_KEY_TTL_SECONDS` | Min agent key TTL before re-mint (default: 1800 = 30min) |
@ -59,10 +61,13 @@ For native Anthropic auth, Hermes prefers Claude Code's own credential files whe
| Variable | Description |
|----------|-------------|
| `PARALLEL_API_KEY` | AI-native web search ([parallel.ai](https://parallel.ai/)) |
| `FIRECRAWL_API_KEY` | Web scraping ([firecrawl.dev](https://firecrawl.dev/)) |
| `FIRECRAWL_API_URL` | Custom Firecrawl API endpoint for self-hosted instances (optional) |
| `BROWSERBASE_API_KEY` | Browser automation ([browserbase.com](https://browserbase.com/)) |
| `BROWSERBASE_PROJECT_ID` | Browserbase project ID |
| `BROWSER_USE_API_KEY` | Browser Use cloud browser API key ([browser-use.com](https://browser-use.com/)) |
| `BROWSER_CDP_URL` | Chrome DevTools Protocol URL for local browser (set via `/browser connect`, e.g. `ws://localhost:9222`) |
| `BROWSER_INACTIVITY_TIMEOUT` | Browser session inactivity timeout in seconds |
| `FAL_KEY` | Image generation ([fal.ai](https://fal.ai/)) |
| `GROQ_API_KEY` | Groq Whisper STT API key ([groq.com](https://groq.com/)) |
@ -151,6 +156,14 @@ For native Anthropic auth, Hermes prefers Claude Code's own credential files whe
| `SIGNAL_HOME_CHANNEL_NAME` | Display name for the Signal home channel |
| `SIGNAL_IGNORE_STORIES` | Ignore Signal stories/status updates |
| `SIGNAL_ALLOW_ALL_USERS` | Allow all Signal users without an allowlist |
| `TWILIO_ACCOUNT_SID` | Twilio Account SID (shared with telephony skill) |
| `TWILIO_AUTH_TOKEN` | Twilio Auth Token (shared with telephony skill) |
| `TWILIO_PHONE_NUMBER` | Twilio phone number in E.164 format (shared with telephony skill) |
| `SMS_WEBHOOK_PORT` | Webhook listener port for inbound SMS (default: `8080`) |
| `SMS_ALLOWED_USERS` | Comma-separated E.164 phone numbers allowed to chat |
| `SMS_ALLOW_ALL_USERS` | Allow all SMS senders without an allowlist |
| `SMS_HOME_CHANNEL` | Phone number for cron job / notification delivery |
| `SMS_HOME_CHANNEL_NAME` | Display name for the SMS home channel |
| `EMAIL_ADDRESS` | Email address for the Email gateway adapter |
| `EMAIL_PASSWORD` | Password or app password for the email account |
| `EMAIL_IMAP_HOST` | IMAP hostname for the email adapter |
@ -162,6 +175,21 @@ For native Anthropic auth, Hermes prefers Claude Code's own credential files whe
| `EMAIL_HOME_ADDRESS_NAME` | Display name for the email home target |
| `EMAIL_POLL_INTERVAL` | Email polling interval in seconds |
| `EMAIL_ALLOW_ALL_USERS` | Allow all inbound email senders |
| `DINGTALK_CLIENT_ID` | DingTalk bot AppKey from developer portal ([open.dingtalk.com](https://open.dingtalk.com)) |
| `DINGTALK_CLIENT_SECRET` | DingTalk bot AppSecret from developer portal |
| `DINGTALK_ALLOWED_USERS` | Comma-separated DingTalk user IDs allowed to message the bot |
| `MATTERMOST_URL` | Mattermost server URL (e.g. `https://mm.example.com`) |
| `MATTERMOST_TOKEN` | Bot token or personal access token for Mattermost |
| `MATTERMOST_ALLOWED_USERS` | Comma-separated Mattermost user IDs allowed to message the bot |
| `MATTERMOST_HOME_CHANNEL` | Channel ID for proactive message delivery (cron, notifications) |
| `MATTERMOST_REPLY_MODE` | Reply style: `thread` (threaded replies) or `off` (flat messages, default) |
| `MATRIX_HOMESERVER` | Matrix homeserver URL (e.g. `https://matrix.org`) |
| `MATRIX_ACCESS_TOKEN` | Matrix access token for bot authentication |
| `MATRIX_USER_ID` | Matrix user ID (e.g. `@hermes:matrix.org`) — required for password login, optional with access token |
| `MATRIX_PASSWORD` | Matrix password (alternative to access token) |
| `MATRIX_ALLOWED_USERS` | Comma-separated Matrix user IDs allowed to message the bot (e.g. `@alice:matrix.org`) |
| `MATRIX_HOME_ROOM` | Room ID for proactive message delivery (e.g. `!abc123:matrix.org`) |
| `MATRIX_ENCRYPTION` | Enable end-to-end encryption (`true`/`false`, default: `false`) |
| `HASS_TOKEN` | Home Assistant Long-Lived Access Token (enables HA platform + tools) |
| `HASS_URL` | Home Assistant URL (default: `http://homeassistant.local:8123`) |
| `MESSAGING_CWD` | Working directory for terminal commands in messaging mode (default: `~`) |

View file

@ -52,8 +52,9 @@ Type `/` in the CLI to open the autocomplete menu. Built-in commands are case-in
| Command | Description |
|---------|-------------|
| `/tools` | List available tools |
| `/tools [list\|disable\|enable] [name...]` | Manage tools: list available tools, or disable/enable specific tools for the current session. Disabling a tool removes it from the agent's toolset and triggers a session reset. |
| `/toolsets` | List available toolsets |
| `/browser [connect\|disconnect\|status]` | Manage local Chrome CDP connection. `connect` attaches browser tools to a running Chrome instance (default: `ws://localhost:9222`). `disconnect` detaches. `status` shows current connection. Auto-launches Chrome if no debugger is detected. |
| `/skills` | Search, install, inspect, or manage skills from online registries |
| `/cron` | Manage scheduled tasks (list, add/create, edit, pause, resume, run, remove) |
| `/reload-mcp` | Reload MCP servers from config.yaml |
@ -118,7 +119,7 @@ The messaging gateway supports the following built-in commands inside Telegram,
## Notes
- `/skin`, `/tools`, `/toolsets`, `/config`, `/prompt`, `/cron`, `/skills`, `/platforms`, `/paste`, and `/verbose` are **CLI-only** commands.
- `/skin`, `/tools`, `/toolsets`, `/browser`, `/config`, `/prompt`, `/cron`, `/skills`, `/platforms`, `/paste`, and `/verbose` are **CLI-only** commands.
- `/status`, `/stop`, `/sethome`, `/resume`, and `/update` are **messaging-only** commands.
- `/background`, `/voice`, `/reload-mcp`, and `/rollback` work in **both** the CLI and the messaging gateway.
- `/voice join`, `/voice channel`, and `/voice leave` are only meaningful on Discord.

View file

@ -70,7 +70,9 @@ You need at least one way to connect to an LLM. Use `hermes model` to switch pro
| **Kimi / Moonshot** | `KIMI_API_KEY` in `~/.hermes/.env` (provider: `kimi-coding`) |
| **MiniMax** | `MINIMAX_API_KEY` in `~/.hermes/.env` (provider: `minimax`) |
| **MiniMax China** | `MINIMAX_CN_API_KEY` in `~/.hermes/.env` (provider: `minimax-cn`) |
| **Alibaba Cloud** | `DASHSCOPE_API_KEY` in `~/.hermes/.env` (provider: `alibaba`, aliases: `dashscope`, `qwen`) |
| **Kilo Code** | `KILOCODE_API_KEY` in `~/.hermes/.env` (provider: `kilocode`) |
| **Alibaba Cloud** | `DASHSCOPE_API_KEY` in `~/.hermes/.env` (provider: `alibaba`) |
| **Custom Endpoint** | `hermes model` (saved in `config.yaml`) or `OPENAI_BASE_URL` + `OPENAI_API_KEY` in `~/.hermes/.env` |
:::info Codex Note
@ -135,16 +137,20 @@ hermes chat --provider minimax --model MiniMax-Text-01
# MiniMax (China endpoint)
hermes chat --provider minimax-cn --model MiniMax-Text-01
# Requires: MINIMAX_CN_API_KEY in ~/.hermes/.env
# Alibaba Cloud / DashScope (Qwen models)
hermes chat --provider alibaba --model qwen-plus
# Requires: DASHSCOPE_API_KEY in ~/.hermes/.env
```
Or set the provider permanently in `config.yaml`:
```yaml
model:
provider: "zai" # or: kimi-coding, minimax, minimax-cn
provider: "zai" # or: kimi-coding, minimax, minimax-cn, alibaba
default: "glm-4-plus"
```
Base URLs can be overridden with `GLM_BASE_URL`, `KIMI_BASE_URL`, `MINIMAX_BASE_URL`, or `MINIMAX_CN_BASE_URL` environment variables.
Base URLs can be overridden with `GLM_BASE_URL`, `KIMI_BASE_URL`, `MINIMAX_BASE_URL`, `MINIMAX_CN_BASE_URL`, or `DASHSCOPE_BASE_URL` environment variables.
## Custom & Self-Hosted LLM Providers
@ -872,6 +878,7 @@ This controls both the `text_to_speech` tool and spoken replies in voice mode (`
display:
tool_progress: all # off | new | all | verbose
skin: default # Built-in or custom CLI skin (see user-guide/features/skins)
theme_mode: auto # auto | light | dark — color scheme for skin-aware rendering
personality: "kawaii" # Legacy cosmetic field still surfaced in some summaries
compact: false # Compact output mode (less whitespace)
resume_display: full # full (show previous messages on resume) | minimal (one-liner only)
@ -881,6 +888,18 @@ display:
background_process_notifications: all # all | result | error | off (gateway only)
```
### Theme mode
The `theme_mode` setting controls whether skins render in light or dark mode:
| Mode | Behavior |
|------|----------|
| `auto` (default) | Detects your terminal's background color automatically. Falls back to `dark` if detection fails. |
| `light` | Forces light-mode skin colors. Skins that define a `colors_light` override use those colors instead of the default dark-mode palette. |
| `dark` | Forces dark-mode skin colors. |
This works with any skin — built-in or custom. Skin authors can provide `colors_light` in their skin definition for optimal light-terminal appearance.
| Mode | What you see |
|------|-------------|
| `off` | Silent — just the final response |
@ -1055,6 +1074,54 @@ browser:
record_sessions: false # Auto-record browser sessions as WebM videos to ~/.hermes/browser_recordings/
```
The browser toolset supports multiple providers. See the [Browser feature page](/docs/user-guide/features/browser) for details on Browserbase, Browser Use, and local Chrome CDP setup.
## Website Blocklist
Block specific domains from being accessed by the agent's web and browser tools:
```yaml
website_blocklist:
enabled: false # Enable URL blocking (default: false)
domains: # List of blocked domain patterns
- "*.internal.company.com"
- "admin.example.com"
- "*.local"
shared_files: # Load additional rules from external files
- "/etc/hermes/blocked-sites.txt"
```
When enabled, any URL matching a blocked domain pattern is rejected before the web or browser tool executes. This applies to `web_search`, `web_extract`, `browser_navigate`, and any tool that accesses URLs.
Domain rules support:
- Exact domains: `admin.example.com`
- Wildcard subdomains: `*.internal.company.com` (blocks all subdomains)
- TLD wildcards: `*.local`
Shared files contain one domain rule per line (blank lines and `#` comments are ignored). Missing or unreadable files log a warning but don't disable other web tools.
The policy is cached for 30 seconds, so config changes take effect quickly without restart.
## Smart Approvals
Control how Hermes handles potentially dangerous commands:
```yaml
approval_mode: ask # ask | smart | off
```
| Mode | Behavior |
|------|----------|
| `ask` (default) | Prompt the user before executing any flagged command. In the CLI, shows an interactive approval dialog. In messaging, queues a pending approval request. |
| `smart` | Use an auxiliary LLM to assess whether a flagged command is actually dangerous. Low-risk commands are auto-approved with session-level persistence. Genuinely risky commands are escalated to the user. |
| `off` | Skip all approval checks. Equivalent to `HERMES_YOLO_MODE=true`. **Use with caution.** |
Smart mode is particularly useful for reducing approval fatigue — it lets the agent work more autonomously on safe operations while still catching genuinely destructive commands.
:::warning
Setting `approval_mode: off` disables all safety checks for terminal commands. Only use this in trusted, sandboxed environments.
:::
## Checkpoints
Automatic filesystem snapshots before destructive file operations. See the [Checkpoints feature page](/docs/user-guide/features/checkpoints) for details.

View file

@ -1,27 +1,30 @@
---
title: Browser Automation
description: Control cloud browsers with Browserbase integration for web interaction, form filling, scraping, and more.
description: Control browsers with multiple providers, local Chrome via CDP, or cloud browsers for web interaction, form filling, scraping, and more.
sidebar_label: Browser
sidebar_position: 5
---
# Browser Automation
Hermes Agent includes a full browser automation toolset that can run in two modes:
Hermes Agent includes a full browser automation toolset with multiple backend options:
- **Browserbase cloud mode** via [Browserbase](https://browserbase.com) for managed cloud browsers and anti-bot tooling
- **Browser Use cloud mode** via [Browser Use](https://browser-use.com) as an alternative cloud browser provider
- **Local Chrome via CDP** — connect browser tools to your own Chrome instance using `/browser connect`
- **Local browser mode** via the `agent-browser` CLI and a local Chromium installation
In both modes, the agent can navigate websites, interact with page elements, fill forms, and extract information.
In all modes, the agent can navigate websites, interact with page elements, fill forms, and extract information.
## Overview
The browser tools use the `agent-browser` CLI. In Browserbase mode, `agent-browser` connects to Browserbase cloud sessions. In local mode, it drives a local Chromium installation. Pages are represented as **accessibility trees** (text-based snapshots), making them ideal for LLM agents. Interactive elements get ref IDs (like `@e1`, `@e2`) that the agent uses for clicking and typing.
Pages are represented as **accessibility trees** (text-based snapshots), making them ideal for LLM agents. Interactive elements get ref IDs (like `@e1`, `@e2`) that the agent uses for clicking and typing.
Key capabilities:
- **Cloud execution** — no local browser needed
- **Built-in stealth** — random fingerprints, CAPTCHA solving, residential proxies
- **Multi-provider cloud execution** — Browserbase or Browser Use, no local browser needed
- **Local Chrome integration** — attach to your running Chrome via CDP for hands-on browsing
- **Built-in stealth** — random fingerprints, CAPTCHA solving, residential proxies (Browserbase)
- **Session isolation** — each task gets its own browser session
- **Automatic cleanup** — inactive sessions are closed after a timeout
- **Vision analysis** — screenshot + AI analysis for visual understanding
@ -40,9 +43,48 @@ BROWSERBASE_PROJECT_ID=your-project-id-here
Get your credentials at [browserbase.com](https://browserbase.com).
### Browser Use cloud mode
To use Browser Use as your cloud browser provider, add:
```bash
# Add to ~/.hermes/.env
BROWSER_USE_API_KEY=***
```
Get your API key at [browser-use.com](https://browser-use.com). Browser Use provides a cloud browser via its REST API. If both Browserbase and Browser Use credentials are set, Browserbase takes priority.
### Local Chrome via CDP (`/browser connect`)
Instead of a cloud provider, you can attach Hermes browser tools to your own running Chrome instance via the Chrome DevTools Protocol (CDP). This is useful when you want to see what the agent is doing in real-time, interact with pages that require your own cookies/sessions, or avoid cloud browser costs.
In the CLI, use:
```
/browser connect # Connect to Chrome at ws://localhost:9222
/browser connect ws://host:port # Connect to a specific CDP endpoint
/browser status # Check current connection
/browser disconnect # Detach and return to cloud/local mode
```
If Chrome isn't already running with remote debugging, Hermes will attempt to auto-launch it with `--remote-debugging-port=9222`.
:::tip
To start Chrome manually with CDP enabled:
```bash
# Linux
google-chrome --remote-debugging-port=9222
# macOS
"/Applications/Google Chrome.app/Contents/MacOS/Google Chrome" --remote-debugging-port=9222
```
:::
When connected via CDP, all browser tools (`browser_navigate`, `browser_click`, etc.) operate on your live Chrome instance instead of spinning up a cloud session.
### Local browser mode
If you do **not** set Browserbase credentials, Hermes can still use the browser tools through a local Chromium install driven by `agent-browser`.
If you do **not** set any cloud credentials and don't use `/browser connect`, Hermes can still use the browser tools through a local Chromium install driven by `agent-browser`.
### Optional Environment Variables
@ -232,10 +274,8 @@ If paid features aren't available on your plan, Hermes automatically falls back
## Limitations
- **Requires Browserbase account** — no local browser fallback
- **Requires `agent-browser` CLI** — must be installed via npm
- **Text-based interaction** — relies on accessibility tree, not pixel coordinates
- **Snapshot size** — large pages may be truncated or LLM-summarized at 8000 characters
- **Session timeout**sessions expire based on your Browserbase plan settings
- **Cost**each session consumes Browserbase credits; use `browser_close` when done
- **Session timeout**cloud sessions expire based on your provider's plan settings
- **Cost**cloud sessions consume provider credits; use `browser_close` when done. Use `/browser connect` for free local browsing.
- **No file downloads** — cannot download files from the browser

View file

@ -0,0 +1,192 @@
---
sidebar_position: 10
title: "DingTalk"
description: "Set up Hermes Agent as a DingTalk chatbot"
---
# DingTalk Setup
Hermes Agent integrates with DingTalk (钉钉) as a chatbot, letting you chat with your AI assistant through direct messages or group chats. The bot connects via DingTalk's Stream Mode — a long-lived WebSocket connection that requires no public URL or webhook server — and replies using markdown-formatted messages through DingTalk's session webhook API.
Before setup, here's the part most people want to know: how Hermes behaves once it's in your DingTalk workspace.
## How Hermes Behaves
| Context | Behavior |
|---------|----------|
| **DMs (1:1 chat)** | Hermes responds to every message. No `@mention` needed. Each DM has its own session. |
| **Group chats** | Hermes responds when you `@mention` it. Without a mention, Hermes ignores the message. |
| **Shared groups with multiple users** | By default, Hermes isolates session history per user inside the group. Two people talking in the same group do not share one transcript unless you explicitly disable that. |
### Session Model in DingTalk
By default:
- each DM gets its own session
- each user in a shared group chat gets their own session inside that group
This is controlled by `config.yaml`:
```yaml
group_sessions_per_user: true
```
Set it to `false` only if you explicitly want one shared conversation for the entire group:
```yaml
group_sessions_per_user: false
```
This guide walks you through the full setup process — from creating your DingTalk bot to sending your first message.
## Prerequisites
Install the required Python packages:
```bash
pip install dingtalk-stream httpx
```
- `dingtalk-stream` — DingTalk's official SDK for Stream Mode (WebSocket-based real-time messaging)
- `httpx` — async HTTP client used for sending replies via session webhooks
## Step 1: Create a DingTalk App
1. Go to the [DingTalk Developer Console](https://open-dev.dingtalk.com/).
2. Log in with your DingTalk admin account.
3. Click **Application Development****Custom Apps****Create App via H5 Micro-App** (or **Robot** depending on your console version).
4. Fill in:
- **App Name**: e.g., `Hermes Agent`
- **Description**: optional
5. After creating, navigate to **Credentials & Basic Info** to find your **Client ID** (AppKey) and **Client Secret** (AppSecret). Copy both.
:::warning[Credentials shown only once]
The Client Secret is only displayed once when you create the app. If you lose it, you'll need to regenerate it. Never share these credentials publicly or commit them to Git.
:::
## Step 2: Enable the Robot Capability
1. In your app's settings page, go to **Add Capability****Robot**.
2. Enable the robot capability.
3. Under **Message Reception Mode**, select **Stream Mode** (recommended — no public URL needed).
:::tip
Stream Mode is the recommended setup. It uses a long-lived WebSocket connection initiated from your machine, so you don't need a public IP, domain name, or webhook endpoint. This works behind NAT, firewalls, and on local machines.
:::
## Step 3: Find Your DingTalk User ID
Hermes Agent uses your DingTalk User ID to control who can interact with the bot. DingTalk User IDs are alphanumeric strings set by your organization's admin.
To find yours:
1. Ask your DingTalk organization admin — User IDs are configured in the DingTalk admin console under **Contacts****Members**.
2. Alternatively, the bot logs the `sender_id` for each incoming message. Start the gateway, send the bot a message, then check the logs for your ID.
## Step 4: Configure Hermes Agent
### Option A: Interactive Setup (Recommended)
Run the guided setup command:
```bash
hermes gateway setup
```
Select **DingTalk** when prompted, then paste your Client ID, Client Secret, and allowed user IDs when asked.
### Option B: Manual Configuration
Add the following to your `~/.hermes/.env` file:
```bash
# Required
DINGTALK_CLIENT_ID=your-app-key
DINGTALK_CLIENT_SECRET=your-app-secret
# Security: restrict who can interact with the bot
DINGTALK_ALLOWED_USERS=user-id-1
# Multiple allowed users (comma-separated)
# DINGTALK_ALLOWED_USERS=user-id-1,user-id-2
```
Optional behavior settings in `~/.hermes/config.yaml`:
```yaml
group_sessions_per_user: true
```
- `group_sessions_per_user: true` keeps each participant's context isolated inside shared group chats
### Start the Gateway
Once configured, start the DingTalk gateway:
```bash
hermes gateway
```
The bot should connect to DingTalk's Stream Mode within a few seconds. Send it a message — either a DM or in a group where it's been added — to test.
:::tip
You can run `hermes gateway` in the background or as a systemd service for persistent operation. See the deployment docs for details.
:::
## Troubleshooting
### Bot is not responding to messages
**Cause**: The robot capability isn't enabled, or `DINGTALK_ALLOWED_USERS` doesn't include your User ID.
**Fix**: Verify the robot capability is enabled in your app settings and that Stream Mode is selected. Check that your User ID is in `DINGTALK_ALLOWED_USERS`. Restart the gateway.
### "dingtalk-stream not installed" error
**Cause**: The `dingtalk-stream` Python package is not installed.
**Fix**: Install it:
```bash
pip install dingtalk-stream httpx
```
### "DINGTALK_CLIENT_ID and DINGTALK_CLIENT_SECRET required"
**Cause**: The credentials aren't set in your environment or `.env` file.
**Fix**: Verify `DINGTALK_CLIENT_ID` and `DINGTALK_CLIENT_SECRET` are set correctly in `~/.hermes/.env`. The Client ID is your AppKey, and the Client Secret is your AppSecret from the DingTalk Developer Console.
### Stream disconnects / reconnection loops
**Cause**: Network instability, DingTalk platform maintenance, or credential issues.
**Fix**: The adapter automatically reconnects with exponential backoff (2s → 5s → 10s → 30s → 60s). Check that your credentials are valid and your app hasn't been deactivated. Verify your network allows outbound WebSocket connections.
### Bot is offline
**Cause**: The Hermes gateway isn't running, or it failed to connect.
**Fix**: Check that `hermes gateway` is running. Look at the terminal output for error messages. Common issues: wrong credentials, app deactivated, `dingtalk-stream` or `httpx` not installed.
### "No session_webhook available"
**Cause**: The bot tried to reply but doesn't have a session webhook URL. This typically happens if the webhook expired or the bot was restarted between receiving the message and sending the reply.
**Fix**: Send a new message to the bot — each incoming message provides a fresh session webhook for replies. This is a normal DingTalk limitation; the bot can only reply to messages it has received recently.
## Security
:::warning
Always set `DINGTALK_ALLOWED_USERS` to restrict who can interact with the bot. Without it, the gateway denies all users by default as a safety measure. Only add User IDs of people you trust — authorized users have full access to the agent's capabilities, including tool use and system access.
:::
For more information on securing your Hermes Agent deployment, see the [Security Guide](../security.md).
## Notes
- **Stream Mode**: No public URL, domain name, or webhook server needed. The connection is initiated from your machine via WebSocket, so it works behind NAT and firewalls.
- **Markdown responses**: Replies are formatted in DingTalk's markdown format for rich text display.
- **Message deduplication**: The adapter deduplicates messages with a 5-minute window to prevent processing the same message twice.
- **Auto-reconnection**: If the stream connection drops, the adapter automatically reconnects with exponential backoff.
- **Message length limit**: Responses are capped at 20,000 characters per message. Longer responses are truncated.

View file

@ -1,12 +1,12 @@
---
sidebar_position: 1
title: "Messaging Gateway"
description: "Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, Email, Home Assistant, or your browser — architecture and setup overview"
description: "Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, SMS, Email, Home Assistant, Mattermost, Matrix, DingTalk, or your browser — architecture and setup overview"
---
# Messaging Gateway
Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, Email, Home Assistant, or your browser. The gateway is a single background process that connects to all your configured platforms, handles sessions, runs cron jobs, and delivers voice messages.
Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, SMS, Email, Home Assistant, Mattermost, Matrix, DingTalk, or your browser. The gateway is a single background process that connects to all your configured platforms, handles sessions, runs cron jobs, and delivers voice messages.
For the full voice feature set — including CLI microphone mode, spoken replies in messaging, and Discord voice-channel conversations — see [Voice Mode](/docs/user-guide/features/voice-mode) and [Use Voice Mode with Hermes](/docs/guides/use-voice-mode-with-hermes).
@ -21,8 +21,12 @@ flowchart TB
wa[WhatsApp]
sl[Slack]
sig[Signal]
sms[SMS]
em[Email]
ha[Home Assistant]
mm[Mattermost]
mx[Matrix]
dt[DingTalk]
end
store["Session store<br/>per chat"]
@ -35,8 +39,12 @@ flowchart TB
wa --> store
sl --> store
sig --> store
sms --> store
em --> store
ha --> store
mm --> store
mx --> store
dt --> store
store --> agent
cron --> store
```
@ -129,7 +137,11 @@ Configure per-platform overrides in `~/.hermes/gateway.json`:
TELEGRAM_ALLOWED_USERS=123456789,987654321
DISCORD_ALLOWED_USERS=123456789012345678
SIGNAL_ALLOWED_USERS=+155****4567,+155****6543
SMS_ALLOWED_USERS=+155****4567,+155****6543
EMAIL_ALLOWED_USERS=trusted@example.com,colleague@work.com
MATTERMOST_ALLOWED_USERS=3uo8dkh1p7g1mfk49ear5fzs5c
MATRIX_ALLOWED_USERS=@alice:matrix.org
DINGTALK_ALLOWED_USERS=user-id-1
# Or allow
GATEWAY_ALLOWED_USERS=123456789,987654321
@ -288,8 +300,12 @@ Each platform has its own toolset:
| WhatsApp | `hermes-whatsapp` | Full tools including terminal |
| Slack | `hermes-slack` | Full tools including terminal |
| Signal | `hermes-signal` | Full tools including terminal |
| SMS | `hermes-sms` | Full tools including terminal |
| Email | `hermes-email` | Full tools including terminal |
| Home Assistant | `hermes-homeassistant` | Full tools + HA device control (ha_list_entities, ha_get_state, ha_call_service, ha_list_services) |
| Mattermost | `hermes-mattermost` | Full tools including terminal |
| Matrix | `hermes-matrix` | Full tools including terminal |
| DingTalk | `hermes-dingtalk` | Full tools including terminal |
## Next Steps
@ -298,5 +314,9 @@ Each platform has its own toolset:
- [Slack Setup](slack.md)
- [WhatsApp Setup](whatsapp.md)
- [Signal Setup](signal.md)
- [SMS Setup (Twilio)](sms.md)
- [Email Setup](email.md)
- [Home Assistant Integration](homeassistant.md)
- [Mattermost Setup](mattermost.md)
- [Matrix Setup](matrix.md)
- [DingTalk Setup](dingtalk.md)

View file

@ -0,0 +1,354 @@
---
sidebar_position: 9
title: "Matrix"
description: "Set up Hermes Agent as a Matrix bot"
---
# Matrix Setup
Hermes Agent integrates with Matrix, the open, federated messaging protocol. Matrix lets you run your own homeserver or use a public one like matrix.org — either way, you keep control of your communications. The bot connects via the `matrix-nio` Python SDK, processes messages through the Hermes Agent pipeline (including tool use, memory, and reasoning), and responds in real time. It supports text, file attachments, images, audio, video, and optional end-to-end encryption (E2EE).
Hermes works with any Matrix homeserver — Synapse, Conduit, Dendrite, or matrix.org.
Before setup, here's the part most people want to know: how Hermes behaves once it's connected.
## How Hermes Behaves
| Context | Behavior |
|---------|----------|
| **DMs** | Hermes responds to every message. No `@mention` needed. Each DM has its own session. |
| **Rooms** | Hermes responds to all messages in rooms it has joined. Room invites are auto-accepted. |
| **Threads** | Hermes supports Matrix threads (MSC3440). If you reply in a thread, Hermes keeps the thread context isolated from the main room timeline. |
| **Shared rooms with multiple users** | By default, Hermes isolates session history per user inside the room. Two people talking in the same room do not share one transcript unless you explicitly disable that. |
:::tip
The bot automatically joins rooms when invited. Just invite the bot's Matrix user to any room and it will join and start responding.
:::
### Session Model in Matrix
By default:
- each DM gets its own session
- each thread gets its own session namespace
- each user in a shared room gets their own session inside that room
This is controlled by `config.yaml`:
```yaml
group_sessions_per_user: true
```
Set it to `false` only if you explicitly want one shared conversation for the entire room:
```yaml
group_sessions_per_user: false
```
Shared sessions can be useful for a collaborative room, but they also mean:
- users share context growth and token costs
- one person's long tool-heavy task can bloat everyone else's context
- one person's in-flight run can interrupt another person's follow-up in the same room
This guide walks you through the full setup process — from creating your bot account to sending your first message.
## Step 1: Create a Bot Account
You need a Matrix user account for the bot. There are several ways to do this:
### Option A: Register on Your Homeserver (Recommended)
If you run your own homeserver (Synapse, Conduit, Dendrite):
1. Use the admin API or registration tool to create a new user:
```bash
# Synapse example
register_new_matrix_user -c /etc/synapse/homeserver.yaml http://localhost:8008
```
2. Choose a username like `hermes` — the full user ID will be `@hermes:your-server.org`.
### Option B: Use matrix.org or Another Public Homeserver
1. Go to [Element Web](https://app.element.io) and create a new account.
2. Pick a username for your bot (e.g., `hermes-bot`).
### Option C: Use Your Own Account
You can also run Hermes as your own user. This means the bot posts as you — useful for personal assistants.
## Step 2: Get an Access Token
Hermes needs an access token to authenticate with the homeserver. You have two options:
### Option A: Access Token (Recommended)
The most reliable way to get a token:
**Via Element:**
1. Log in to [Element](https://app.element.io) with the bot account.
2. Go to **Settings****Help & About**.
3. Scroll down and expand **Advanced** — the access token is displayed there.
4. **Copy it immediately.**
**Via the API:**
```bash
curl -X POST https://your-server/_matrix/client/v3/login \
-H "Content-Type: application/json" \
-d '{
"type": "m.login.password",
"user": "@hermes:your-server.org",
"password": "your-password"
}'
```
The response includes an `access_token` field — copy it.
:::warning[Keep your access token safe]
The access token gives full access to the bot's Matrix account. Never share it publicly or commit it to Git. If compromised, revoke it by logging out all sessions for that user.
:::
### Option B: Password Login
Instead of providing an access token, you can give Hermes the bot's user ID and password. Hermes will log in automatically on startup. This is simpler but means the password is stored in your `.env` file.
```bash
MATRIX_USER_ID=@hermes:your-server.org
MATRIX_PASSWORD=your-password
```
## Step 3: Find Your Matrix User ID
Hermes Agent uses your Matrix User ID to control who can interact with the bot. Matrix User IDs follow the format `@username:server`.
To find yours:
1. Open [Element](https://app.element.io) (or your preferred Matrix client).
2. Click your avatar → **Settings**.
3. Your User ID is displayed at the top of the profile (e.g., `@alice:matrix.org`).
:::tip
Matrix User IDs always start with `@` and contain a `:` followed by the server name. For example: `@alice:matrix.org`, `@bob:your-server.com`.
:::
## Step 4: Configure Hermes Agent
### Option A: Interactive Setup (Recommended)
Run the guided setup command:
```bash
hermes gateway setup
```
Select **Matrix** when prompted, then provide your homeserver URL, access token (or user ID + password), and allowed user IDs when asked.
### Option B: Manual Configuration
Add the following to your `~/.hermes/.env` file:
**Using an access token:**
```bash
# Required
MATRIX_HOMESERVER=https://matrix.example.org
MATRIX_ACCESS_TOKEN=***
# Optional: user ID (auto-detected from token if omitted)
# MATRIX_USER_ID=@hermes:matrix.example.org
# Security: restrict who can interact with the bot
MATRIX_ALLOWED_USERS=@alice:matrix.example.org
# Multiple allowed users (comma-separated)
# MATRIX_ALLOWED_USERS=@alice:matrix.example.org,@bob:matrix.example.org
```
**Using password login:**
```bash
# Required
MATRIX_HOMESERVER=https://matrix.example.org
MATRIX_USER_ID=@hermes:matrix.example.org
MATRIX_PASSWORD=***
# Security
MATRIX_ALLOWED_USERS=@alice:matrix.example.org
```
Optional behavior settings in `~/.hermes/config.yaml`:
```yaml
group_sessions_per_user: true
```
- `group_sessions_per_user: true` keeps each participant's context isolated inside shared rooms
### Start the Gateway
Once configured, start the Matrix gateway:
```bash
hermes gateway
```
The bot should connect to your homeserver and start syncing within a few seconds. Send it a message — either a DM or in a room it has joined — to test.
:::tip
You can run `hermes gateway` in the background or as a systemd service for persistent operation. See the deployment docs for details.
:::
## End-to-End Encryption (E2EE)
Hermes supports Matrix end-to-end encryption, so you can chat with your bot in encrypted rooms.
### Requirements
E2EE requires the `matrix-nio` library with encryption extras and the `libolm` C library:
```bash
# Install matrix-nio with E2EE support
pip install 'matrix-nio[e2e]'
# Or install with hermes extras
pip install 'hermes-agent[matrix]'
```
You also need `libolm` installed on your system:
```bash
# Debian/Ubuntu
sudo apt install libolm-dev
# macOS
brew install libolm
# Fedora
sudo dnf install libolm-devel
```
### Enable E2EE
Add to your `~/.hermes/.env`:
```bash
MATRIX_ENCRYPTION=true
```
When E2EE is enabled, Hermes:
- Stores encryption keys in `~/.hermes/matrix/store/`
- Uploads device keys on first connection
- Decrypts incoming messages and encrypts outgoing messages automatically
- Auto-joins encrypted rooms when invited
:::warning
If you delete the `~/.hermes/matrix/store/` directory, the bot loses its encryption keys. You'll need to verify the device again in your Matrix client. Back up this directory if you want to preserve encrypted sessions.
:::
:::info
If `matrix-nio[e2e]` is not installed or `libolm` is missing, the bot falls back to a plain (unencrypted) client automatically. You'll see a warning in the logs.
:::
## Home Room
You can designate a "home room" where the bot sends proactive messages (such as cron job output, reminders, and notifications). There are two ways to set it:
### Using the Slash Command
Type `/sethome` in any Matrix room where the bot is present. That room becomes the home room.
### Manual Configuration
Add this to your `~/.hermes/.env`:
```bash
MATRIX_HOME_ROOM=!abc123def456:matrix.example.org
```
:::tip
To find a Room ID: in Element, go to the room → **Settings****Advanced** → the **Internal room ID** is shown there (starts with `!`).
:::
## Troubleshooting
### Bot is not responding to messages
**Cause**: The bot hasn't joined the room, or `MATRIX_ALLOWED_USERS` doesn't include your User ID.
**Fix**: Invite the bot to the room — it auto-joins on invite. Verify your User ID is in `MATRIX_ALLOWED_USERS` (use the full `@user:server` format). Restart the gateway.
### "Failed to authenticate" / "whoami failed" on startup
**Cause**: The access token or homeserver URL is incorrect.
**Fix**: Verify `MATRIX_HOMESERVER` points to your homeserver (include `https://`, no trailing slash). Check that `MATRIX_ACCESS_TOKEN` is valid — try it with curl:
```bash
curl -H "Authorization: Bearer YOUR_TOKEN" \
https://your-server/_matrix/client/v3/account/whoami
```
If this returns your user info, the token is valid. If it returns an error, generate a new token.
### "matrix-nio not installed" error
**Cause**: The `matrix-nio` Python package is not installed.
**Fix**: Install it:
```bash
pip install 'matrix-nio[e2e]'
```
Or with Hermes extras:
```bash
pip install 'hermes-agent[matrix]'
```
### Encryption errors / "could not decrypt event"
**Cause**: Missing encryption keys, `libolm` not installed, or the bot's device isn't trusted.
**Fix**:
1. Verify `libolm` is installed on your system (see the E2EE section above).
2. Make sure `MATRIX_ENCRYPTION=true` is set in your `.env`.
3. In your Matrix client (Element), go to the bot's profile → **Sessions** → verify/trust the bot's device.
4. If the bot just joined an encrypted room, it can only decrypt messages sent *after* it joined. Older messages are inaccessible.
### Sync issues / bot falls behind
**Cause**: Long-running tool executions can delay the sync loop, or the homeserver is slow.
**Fix**: The sync loop automatically retries every 5 seconds on error. Check the Hermes logs for sync-related warnings. If the bot consistently falls behind, ensure your homeserver has adequate resources.
### Bot is offline
**Cause**: The Hermes gateway isn't running, or it failed to connect.
**Fix**: Check that `hermes gateway` is running. Look at the terminal output for error messages. Common issues: wrong homeserver URL, expired access token, homeserver unreachable.
### "User not allowed" / Bot ignores you
**Cause**: Your User ID isn't in `MATRIX_ALLOWED_USERS`.
**Fix**: Add your User ID to `MATRIX_ALLOWED_USERS` in `~/.hermes/.env` and restart the gateway. Use the full `@user:server` format.
## Security
:::warning
Always set `MATRIX_ALLOWED_USERS` to restrict who can interact with the bot. Without it, the gateway denies all users by default as a safety measure. Only add User IDs of people you trust — authorized users have full access to the agent's capabilities, including tool use and system access.
:::
For more information on securing your Hermes Agent deployment, see the [Security Guide](../security.md).
## Notes
- **Any homeserver**: Works with Synapse, Conduit, Dendrite, matrix.org, or any spec-compliant Matrix homeserver. No specific homeserver software required.
- **Federation**: If you're on a federated homeserver, the bot can communicate with users from other servers — just add their full `@user:server` IDs to `MATRIX_ALLOWED_USERS`.
- **Auto-join**: The bot automatically accepts room invites and joins. It starts responding immediately after joining.
- **Media support**: Hermes can send and receive images, audio, video, and file attachments. Media is uploaded to your homeserver using the Matrix content repository API.

View file

@ -0,0 +1,277 @@
---
sidebar_position: 8
title: "Mattermost"
description: "Set up Hermes Agent as a Mattermost bot"
---
# Mattermost Setup
Hermes Agent integrates with Mattermost as a bot, letting you chat with your AI assistant through direct messages or team channels. Mattermost is a self-hosted, open-source Slack alternative — you run it on your own infrastructure, keeping full control of your data. The bot connects via Mattermost's REST API (v4) and WebSocket for real-time events, processes messages through the Hermes Agent pipeline (including tool use, memory, and reasoning), and responds in real time. It supports text, file attachments, images, and slash commands.
No external Mattermost library is required — the adapter uses `aiohttp`, which is already a Hermes dependency.
Before setup, here's the part most people want to know: how Hermes behaves once it's in your Mattermost instance.
## How Hermes Behaves
| Context | Behavior |
|---------|----------|
| **DMs** | Hermes responds to every message. No `@mention` needed. Each DM has its own session. |
| **Public/private channels** | Hermes responds when you `@mention` it. Without a mention, Hermes ignores the message. |
| **Threads** | If `MATTERMOST_REPLY_MODE=thread`, Hermes replies in a thread under your message. Thread context stays isolated from the parent channel. |
| **Shared channels with multiple users** | By default, Hermes isolates session history per user inside the channel. Two people talking in the same channel do not share one transcript unless you explicitly disable that. |
:::tip
If you want Hermes to reply as threaded conversations (nested under your original message), set `MATTERMOST_REPLY_MODE=thread`. The default is `off`, which sends flat messages in the channel.
:::
### Session Model in Mattermost
By default:
- each DM gets its own session
- each thread gets its own session namespace
- each user in a shared channel gets their own session inside that channel
This is controlled by `config.yaml`:
```yaml
group_sessions_per_user: true
```
Set it to `false` only if you explicitly want one shared conversation for the entire channel:
```yaml
group_sessions_per_user: false
```
Shared sessions can be useful for a collaborative channel, but they also mean:
- users share context growth and token costs
- one person's long tool-heavy task can bloat everyone else's context
- one person's in-flight run can interrupt another person's follow-up in the same channel
This guide walks you through the full setup process — from creating your bot on Mattermost to sending your first message.
## Step 1: Enable Bot Accounts
Bot accounts must be enabled on your Mattermost server before you can create one.
1. Log in to Mattermost as a **System Admin**.
2. Go to **System Console****Integrations****Bot Accounts**.
3. Set **Enable Bot Account Creation** to **true**.
4. Click **Save**.
:::info
If you don't have System Admin access, ask your Mattermost administrator to enable bot accounts and create one for you.
:::
## Step 2: Create a Bot Account
1. In Mattermost, click the **☰** menu (top-left) → **Integrations****Bot Accounts**.
2. Click **Add Bot Account**.
3. Fill in the details:
- **Username**: e.g., `hermes`
- **Display Name**: e.g., `Hermes Agent`
- **Description**: optional
- **Role**: `Member` is sufficient
4. Click **Create Bot Account**.
5. Mattermost will display the **bot token**. **Copy it immediately.**
:::warning[Token shown only once]
The bot token is only displayed once when you create the bot account. If you lose it, you'll need to regenerate it from the bot account settings. Never share your token publicly or commit it to Git — anyone with this token has full control of the bot.
:::
Store the token somewhere safe (a password manager, for example). You'll need it in Step 5.
:::tip
You can also use a **personal access token** instead of a bot account. Go to **Profile****Security****Personal Access Tokens****Create Token**. This is useful if you want Hermes to post as your own user rather than a separate bot user.
:::
## Step 3: Add the Bot to Channels
The bot needs to be a member of any channel where you want it to respond:
1. Open the channel where you want the bot.
2. Click the channel name → **Add Members**.
3. Search for your bot username (e.g., `hermes`) and add it.
For DMs, simply open a direct message with the bot — it will be able to respond immediately.
## Step 4: Find Your Mattermost User ID
Hermes Agent uses your Mattermost User ID to control who can interact with the bot. To find it:
1. Click your **avatar** (top-left corner) → **Profile**.
2. Your User ID is displayed in the profile dialog — click it to copy.
Your User ID is a 26-character alphanumeric string like `3uo8dkh1p7g1mfk49ear5fzs5c`.
:::warning
Your User ID is **not** your username. The username is what appears after `@` (e.g., `@alice`). The User ID is a long alphanumeric identifier that Mattermost uses internally.
:::
**Alternative**: You can also get your User ID via the API:
```bash
curl -H "Authorization: Bearer YOUR_TOKEN" \
https://your-mattermost-server/api/v4/users/me | jq .id
```
:::tip
To get a **Channel ID**: click the channel name → **View Info**. The Channel ID is shown in the info panel. You'll need this if you want to set a home channel manually.
:::
## Step 5: Configure Hermes Agent
### Option A: Interactive Setup (Recommended)
Run the guided setup command:
```bash
hermes gateway setup
```
Select **Mattermost** when prompted, then paste your server URL, bot token, and user ID when asked.
### Option B: Manual Configuration
Add the following to your `~/.hermes/.env` file:
```bash
# Required
MATTERMOST_URL=https://mm.example.com
MATTERMOST_TOKEN=***
MATTERMOST_ALLOWED_USERS=3uo8dkh1p7g1mfk49ear5fzs5c
# Multiple allowed users (comma-separated)
# MATTERMOST_ALLOWED_USERS=3uo8dkh1p7g1mfk49ear5fzs5c,8fk2jd9s0a7bncm1xqw4tp6r3e
# Optional: reply mode (thread or off, default: off)
# MATTERMOST_REPLY_MODE=thread
```
Optional behavior settings in `~/.hermes/config.yaml`:
```yaml
group_sessions_per_user: true
```
- `group_sessions_per_user: true` keeps each participant's context isolated inside shared channels and threads
### Start the Gateway
Once configured, start the Mattermost gateway:
```bash
hermes gateway
```
The bot should connect to your Mattermost server within a few seconds. Send it a message — either a DM or in a channel where it's been added — to test.
:::tip
You can run `hermes gateway` in the background or as a systemd service for persistent operation. See the deployment docs for details.
:::
## Home Channel
You can designate a "home channel" where the bot sends proactive messages (such as cron job output, reminders, and notifications). There are two ways to set it:
### Using the Slash Command
Type `/sethome` in any Mattermost channel where the bot is present. That channel becomes the home channel.
### Manual Configuration
Add this to your `~/.hermes/.env`:
```bash
MATTERMOST_HOME_CHANNEL=abc123def456ghi789jkl012mn
```
Replace the ID with the actual channel ID (click the channel name → View Info → copy the ID).
## Reply Mode
The `MATTERMOST_REPLY_MODE` setting controls how Hermes posts responses:
| Mode | Behavior |
|------|----------|
| `off` (default) | Hermes posts flat messages in the channel, like a normal user. |
| `thread` | Hermes replies in a thread under your original message. Keeps channels clean when there's lots of back-and-forth. |
Set it in your `~/.hermes/.env`:
```bash
MATTERMOST_REPLY_MODE=thread
```
## Troubleshooting
### Bot is not responding to messages
**Cause**: The bot is not a member of the channel, or `MATTERMOST_ALLOWED_USERS` doesn't include your User ID.
**Fix**: Add the bot to the channel (channel name → Add Members → search for the bot). Verify your User ID is in `MATTERMOST_ALLOWED_USERS`. Restart the gateway.
### 403 Forbidden errors
**Cause**: The bot token is invalid, or the bot doesn't have permission to post in the channel.
**Fix**: Check that `MATTERMOST_TOKEN` in your `.env` file is correct. Make sure the bot account hasn't been deactivated. Verify the bot has been added to the channel. If using a personal access token, ensure your account has the required permissions.
### WebSocket disconnects / reconnection loops
**Cause**: Network instability, Mattermost server restarts, or firewall/proxy issues with WebSocket connections.
**Fix**: The adapter automatically reconnects with exponential backoff (2s → 60s). Check your server's WebSocket configuration — reverse proxies (nginx, Apache) need WebSocket upgrade headers configured. Verify no firewall is blocking WebSocket connections on your Mattermost server.
For nginx, ensure your config includes:
```nginx
location /api/v4/websocket {
proxy_pass http://mattermost-backend;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_read_timeout 600s;
}
```
### "Failed to authenticate" on startup
**Cause**: The token or server URL is incorrect.
**Fix**: Verify `MATTERMOST_URL` points to your Mattermost server (include `https://`, no trailing slash). Check that `MATTERMOST_TOKEN` is valid — try it with curl:
```bash
curl -H "Authorization: Bearer YOUR_TOKEN" \
https://your-server/api/v4/users/me
```
If this returns your bot's user info, the token is valid. If it returns an error, regenerate the token.
### Bot is offline
**Cause**: The Hermes gateway isn't running, or it failed to connect.
**Fix**: Check that `hermes gateway` is running. Look at the terminal output for error messages. Common issues: wrong URL, expired token, Mattermost server unreachable.
### "User not allowed" / Bot ignores you
**Cause**: Your User ID isn't in `MATTERMOST_ALLOWED_USERS`.
**Fix**: Add your User ID to `MATTERMOST_ALLOWED_USERS` in `~/.hermes/.env` and restart the gateway. Remember: the User ID is a 26-character alphanumeric string, not your `@username`.
## Security
:::warning
Always set `MATTERMOST_ALLOWED_USERS` to restrict who can interact with the bot. Without it, the gateway denies all users by default as a safety measure. Only add User IDs of people you trust — authorized users have full access to the agent's capabilities, including tool use and system access.
:::
For more information on securing your Hermes Agent deployment, see the [Security Guide](../security.md).
## Notes
- **Self-hosted friendly**: Works with any self-hosted Mattermost instance. No Mattermost Cloud account or subscription required.
- **No extra dependencies**: The adapter uses `aiohttp` for HTTP and WebSocket, which is already included with Hermes Agent.
- **Team Edition compatible**: Works with both Mattermost Team Edition (free) and Enterprise Edition.

View file

@ -0,0 +1,175 @@
---
sidebar_position: 8
title: "SMS (Twilio)"
description: "Set up Hermes Agent as an SMS chatbot via Twilio"
---
# SMS Setup (Twilio)
Hermes connects to SMS through the [Twilio](https://www.twilio.com/) API. People text your Twilio phone number and get AI responses back — same conversational experience as Telegram or Discord, but over standard text messages.
:::info Shared Credentials
The SMS gateway shares credentials with the optional [telephony skill](/docs/reference/skills-catalog). If you've already set up Twilio for voice calls or one-off SMS, the gateway works with the same `TWILIO_ACCOUNT_SID`, `TWILIO_AUTH_TOKEN`, and `TWILIO_PHONE_NUMBER`.
:::
---
## Prerequisites
- **Twilio account** — [Sign up at twilio.com](https://www.twilio.com/try-twilio) (free trial available)
- **A Twilio phone number** with SMS capability
- **A publicly accessible server** — Twilio sends webhooks to your server when SMS arrives
- **aiohttp**`pip install 'hermes-agent[sms]'`
---
## Step 1: Get Your Twilio Credentials
1. Go to the [Twilio Console](https://console.twilio.com/)
2. Copy your **Account SID** and **Auth Token** from the dashboard
3. Go to **Phone Numbers → Manage → Active Numbers** — note your phone number in E.164 format (e.g., `+15551234567`)
---
## Step 2: Configure Hermes
### Interactive setup (recommended)
```bash
hermes gateway setup
```
Select **SMS (Twilio)** from the platform list. The wizard will prompt for your credentials.
### Manual setup
Add to `~/.hermes/.env`:
```bash
TWILIO_ACCOUNT_SID=ACxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
TWILIO_AUTH_TOKEN=your_auth_token_here
TWILIO_PHONE_NUMBER=+15551234567
# Security: restrict to specific phone numbers (recommended)
SMS_ALLOWED_USERS=+15559876543,+15551112222
# Optional: set a home channel for cron job delivery
SMS_HOME_CHANNEL=+15559876543
```
---
## Step 3: Configure Twilio Webhook
Twilio needs to know where to send incoming messages. In the [Twilio Console](https://console.twilio.com/):
1. Go to **Phone Numbers → Manage → Active Numbers**
2. Click your phone number
3. Under **Messaging → A MESSAGE COMES IN**, set:
- **Webhook**: `https://your-server:8080/webhooks/twilio`
- **HTTP Method**: `POST`
:::tip Exposing Your Webhook
If you're running Hermes locally, use a tunnel to expose the webhook:
```bash
# Using cloudflared
cloudflared tunnel --url http://localhost:8080
# Using ngrok
ngrok http 8080
```
Set the resulting public URL as your Twilio webhook.
:::
The webhook port defaults to `8080`. Override with:
```bash
SMS_WEBHOOK_PORT=3000
```
---
## Step 4: Start the Gateway
```bash
hermes gateway
```
You should see:
```
[sms] Twilio webhook server listening on port 8080, from: +1555***4567
```
Text your Twilio number — Hermes will respond via SMS.
---
## Environment Variables
| Variable | Required | Description |
|----------|----------|-------------|
| `TWILIO_ACCOUNT_SID` | Yes | Twilio Account SID (starts with `AC`) |
| `TWILIO_AUTH_TOKEN` | Yes | Twilio Auth Token |
| `TWILIO_PHONE_NUMBER` | Yes | Your Twilio phone number (E.164 format) |
| `SMS_WEBHOOK_PORT` | No | Webhook listener port (default: `8080`) |
| `SMS_ALLOWED_USERS` | No | Comma-separated E.164 phone numbers allowed to chat |
| `SMS_ALLOW_ALL_USERS` | No | Set to `true` to allow anyone (not recommended) |
| `SMS_HOME_CHANNEL` | No | Phone number for cron job / notification delivery |
| `SMS_HOME_CHANNEL_NAME` | No | Display name for the home channel (default: `Home`) |
---
## SMS-Specific Behavior
- **Plain text only** — Markdown is automatically stripped since SMS renders it as literal characters
- **1600 character limit** — Longer responses are split across multiple messages at natural boundaries (newlines, then spaces)
- **Echo prevention** — Messages from your own Twilio number are ignored to prevent loops
- **Phone number redaction** — Phone numbers are redacted in logs for privacy
---
## Security
**The gateway denies all users by default.** Configure an allowlist:
```bash
# Recommended: restrict to specific phone numbers
SMS_ALLOWED_USERS=+15559876543,+15551112222
# Or allow all (NOT recommended for bots with terminal access)
SMS_ALLOW_ALL_USERS=true
```
:::warning
SMS has no built-in encryption. Don't use SMS for sensitive operations unless you understand the security implications. For sensitive use cases, prefer Signal or Telegram.
:::
---
## Troubleshooting
### Messages not arriving
1. Check your Twilio webhook URL is correct and publicly accessible
2. Verify `TWILIO_ACCOUNT_SID` and `TWILIO_AUTH_TOKEN` are correct
3. Check the Twilio Console → **Monitor → Logs → Messaging** for delivery errors
4. Ensure your phone number is in `SMS_ALLOWED_USERS` (or `SMS_ALLOW_ALL_USERS=true`)
### Replies not sending
1. Check `TWILIO_PHONE_NUMBER` is set correctly (E.164 format with `+`)
2. Verify your Twilio account has SMS-capable numbers
3. Check Hermes gateway logs for Twilio API errors
### Webhook port conflicts
If port 8080 is already in use, change it:
```bash
SMS_WEBHOOK_PORT=3001
```
Update the webhook URL in Twilio Console to match.

View file

@ -277,6 +277,25 @@ Error messages from MCP tools are sanitized before being returned to the LLM. Th
- Bearer tokens
- `token=`, `key=`, `API_KEY=`, `password=`, `secret=` parameters
### Website Access Policy
You can restrict which websites the agent can access through its web and browser tools. This is useful for preventing the agent from accessing internal services, admin panels, or other sensitive URLs.
```yaml
# In ~/.hermes/config.yaml
website_blocklist:
enabled: true
domains:
- "*.internal.company.com"
- "admin.example.com"
shared_files:
- "/etc/hermes/blocked-sites.txt"
```
When a blocked URL is requested, the tool returns an error explaining the domain is blocked by policy. The blocklist is enforced across `web_search`, `web_extract`, `browser_navigate`, and all URL-capable tools.
See [Website Blocklist](/docs/user-guide/configuration#website-blocklist) in the configuration guide for full details.
### Context File Injection Protection
Context files (AGENTS.md, .cursorrules, SOUL.md) are scanned for prompt injection before being included in the system prompt. The scanner checks for:

View file

@ -48,6 +48,9 @@ const sidebars: SidebarsConfig = {
'user-guide/messaging/signal',
'user-guide/messaging/email',
'user-guide/messaging/homeassistant',
'user-guide/messaging/mattermost',
'user-guide/messaging/matrix',
'user-guide/messaging/dingtalk',
],
},
{