The architecture has been updated

This commit is contained in:
Skyber_2 2026-03-31 23:31:36 +03:00
parent 805f7a017e
commit a01257ead9
1119 changed files with 226 additions and 352 deletions

View file

@ -0,0 +1,266 @@
#!/usr/bin/env python3
"""
Tools Package
This package contains all the specific tool implementations for the Hermes Agent.
Each module provides specialized functionality for different capabilities:
- web_tools: Web search, content extraction, and crawling
- terminal_tool: Command execution (local/docker/modal/daytona/ssh/singularity backends)
- vision_tools: Image analysis and understanding
- mixture_of_agents_tool: Multi-model collaborative reasoning
- image_generation_tool: Text-to-image generation with upscaling
The tools are imported into model_tools.py which provides a unified interface
for the AI agent to access all capabilities.
"""
# Export all tools for easy importing
from .web_tools import (
web_search_tool,
web_extract_tool,
web_crawl_tool,
check_firecrawl_api_key
)
# Primary terminal tool (local/docker/singularity/modal/daytona/ssh)
from .terminal_tool import (
terminal_tool,
check_terminal_requirements,
cleanup_vm,
cleanup_all_environments,
get_active_environments_info,
register_task_env_overrides,
clear_task_env_overrides,
TERMINAL_TOOL_DESCRIPTION
)
from .vision_tools import (
vision_analyze_tool,
check_vision_requirements
)
from .mixture_of_agents_tool import (
mixture_of_agents_tool,
check_moa_requirements
)
from .image_generation_tool import (
image_generate_tool,
check_image_generation_requirements
)
from .skills_tool import (
skills_list,
skill_view,
check_skills_requirements,
SKILLS_TOOL_DESCRIPTION
)
from .skill_manager_tool import (
skill_manage,
check_skill_manage_requirements,
SKILL_MANAGE_SCHEMA
)
# Browser automation tools (agent-browser + Browserbase)
# from .browser_tool import (
# browser_navigate,
# browser_snapshot,
# browser_click,
# browser_type,
# browser_scroll,
# browser_back,
# browser_press,
# browser_close,
# browser_get_images,
# browser_vision,
# cleanup_browser,
# cleanup_all_browsers,
# get_active_browser_sessions,
# check_browser_requirements,
# BROWSER_TOOL_SCHEMAS
# )
from .browser_use_tool import run_browser_task
from .browser_tool import cleanup_browser, cleanup_all_browsers
# Cronjob management tools (CLI-only, hermes-cli toolset)
from .cronjob_tools import (
cronjob,
schedule_cronjob,
list_cronjobs,
remove_cronjob,
check_cronjob_requirements,
get_cronjob_tool_definitions,
CRONJOB_SCHEMA,
)
# RL Training tools (Tinker-Atropos)
from .rl_training_tool import (
rl_list_environments,
rl_select_environment,
rl_get_current_config,
rl_edit_config,
rl_start_training,
rl_check_status,
rl_stop_training,
rl_get_results,
rl_list_runs,
rl_test_inference,
check_rl_api_keys,
get_missing_keys,
)
# File manipulation tools (read, write, patch, search)
from .file_tools import (
read_file_tool,
write_file_tool,
patch_tool,
search_tool,
get_file_tools,
clear_file_ops_cache,
)
# Text-to-speech tools (Edge TTS / ElevenLabs / OpenAI)
from .tts_tool import (
text_to_speech_tool,
check_tts_requirements,
)
# Planning & task management tool
from .todo_tool import (
todo_tool,
check_todo_requirements,
TODO_SCHEMA,
TodoStore,
)
# Clarifying questions tool (interactive Q&A with the user)
from .clarify_tool import (
clarify_tool,
check_clarify_requirements,
CLARIFY_SCHEMA,
)
# Code execution sandbox (programmatic tool calling)
from .code_execution_tool import (
execute_code,
check_sandbox_requirements,
EXECUTE_CODE_SCHEMA,
)
# Subagent delegation (spawn child agents with isolated context)
from .delegate_tool import (
delegate_task,
check_delegate_requirements,
DELEGATE_TASK_SCHEMA,
)
# File tools have no external requirements - they use the terminal backend
def check_file_requirements():
"""File tools only require terminal backend to be available."""
from .terminal_tool import check_terminal_requirements
return check_terminal_requirements()
__all__ = [
# Web tools
'web_search_tool',
'web_extract_tool',
'web_crawl_tool',
'check_firecrawl_api_key',
# Terminal tools
'terminal_tool',
'check_terminal_requirements',
'cleanup_vm',
'cleanup_all_environments',
'get_active_environments_info',
'register_task_env_overrides',
'clear_task_env_overrides',
'TERMINAL_TOOL_DESCRIPTION',
# Vision tools
'vision_analyze_tool',
'check_vision_requirements',
# MoA tools
'mixture_of_agents_tool',
'check_moa_requirements',
# Image generation tools
'image_generate_tool',
'check_image_generation_requirements',
# Skills tools
'skills_list',
'skill_view',
'check_skills_requirements',
'SKILLS_TOOL_DESCRIPTION',
# Skill management
'skill_manage',
'check_skill_manage_requirements',
'SKILL_MANAGE_SCHEMA',
# Browser automation tools
'browser_navigate',
'browser_snapshot',
'browser_click',
'browser_type',
'browser_scroll',
'browser_back',
'browser_press',
'browser_close',
'browser_get_images',
'browser_vision',
'cleanup_browser',
'cleanup_all_browsers',
'get_active_browser_sessions',
'check_browser_requirements',
'BROWSER_TOOL_SCHEMAS',
# Cronjob management tools (CLI-only)
'cronjob',
'schedule_cronjob',
'list_cronjobs',
'remove_cronjob',
'check_cronjob_requirements',
'get_cronjob_tool_definitions',
'CRONJOB_SCHEMA',
# RL Training tools
'rl_list_environments',
'rl_select_environment',
'rl_get_current_config',
'rl_edit_config',
'rl_start_training',
'rl_check_status',
'rl_stop_training',
'rl_get_results',
'rl_list_runs',
'rl_test_inference',
'check_rl_api_keys',
'get_missing_keys',
# File manipulation tools
'read_file_tool',
'write_file_tool',
'patch_tool',
'search_tool',
'get_file_tools',
'clear_file_ops_cache',
'check_file_requirements',
# Text-to-speech tools
'text_to_speech_tool',
'check_tts_requirements',
# Planning & task management tool
'todo_tool',
'check_todo_requirements',
'TODO_SCHEMA',
'TodoStore',
# Clarifying questions tool
'clarify_tool',
'check_clarify_requirements',
'CLARIFY_SCHEMA',
# Code execution sandbox
'execute_code',
'check_sandbox_requirements',
'EXECUTE_CODE_SCHEMA',
# Subagent delegation
'delegate_task',
'check_delegate_requirements',
'DELEGATE_TASK_SCHEMA',
]

View file

@ -0,0 +1,44 @@
"""Strip ANSI escape sequences from subprocess output.
Used by terminal_tool, code_execution_tool, and process_registry to clean
command output before returning it to the model. This prevents ANSI codes
from entering the model's context — which is the root cause of models
copying escape sequences into file writes.
Covers the full ECMA-48 spec: CSI (including private-mode ``?`` prefix,
colon-separated params, intermediate bytes), OSC (BEL and ST terminators),
DCS/SOS/PM/APC string sequences, nF multi-byte escapes, Fp/Fe/Fs
single-byte escapes, and 8-bit C1 control characters.
"""
import re
_ANSI_ESCAPE_RE = re.compile(
r"\x1b"
r"(?:"
r"\[[\x30-\x3f]*[\x20-\x2f]*[\x40-\x7e]" # CSI sequence
r"|\][\s\S]*?(?:\x07|\x1b\\)" # OSC (BEL or ST terminator)
r"|[PX^_][\s\S]*?(?:\x1b\\)" # DCS/SOS/PM/APC strings
r"|[\x20-\x2f]+[\x30-\x7e]" # nF escape sequences
r"|[\x30-\x7e]" # Fp/Fe/Fs single-byte
r")"
r"|\x9b[\x30-\x3f]*[\x20-\x2f]*[\x40-\x7e]" # 8-bit CSI
r"|\x9d[\s\S]*?(?:\x07|\x9c)" # 8-bit OSC
r"|[\x80-\x9f]", # Other 8-bit C1 controls
re.DOTALL,
)
# Fast-path check — skip full regex when no escape-like bytes are present.
_HAS_ESCAPE = re.compile(r"[\x1b\x80-\x9f]")
def strip_ansi(text: str) -> str:
"""Remove ANSI escape sequences from text.
Returns the input unchanged (fast path) when no ESC or C1 bytes are
present. Safe to call on any string clean text passes through
with negligible overhead.
"""
if not text or not _HAS_ESCAPE.search(text):
return text
return _ANSI_ESCAPE_RE.sub("", text)

View file

@ -0,0 +1,590 @@
"""Dangerous command approval -- detection, prompting, and per-session state.
This module is the single source of truth for the dangerous command system:
- Pattern detection (DANGEROUS_PATTERNS, detect_dangerous_command)
- Per-session approval state (thread-safe, keyed by session_key)
- Approval prompting (CLI interactive + gateway async)
- Smart approval via auxiliary LLM (auto-approve low-risk commands)
- Permanent allowlist persistence (config.yaml)
"""
import logging
import os
import re
import sys
import threading
from typing import Optional
logger = logging.getLogger(__name__)
# =========================================================================
# Dangerous command patterns
# =========================================================================
DANGEROUS_PATTERNS = [
(r'\brm\s+(-[^\s]*\s+)*/', "delete in root path"),
(r'\brm\s+-[^\s]*r', "recursive delete"),
(r'\brm\s+--recursive\b', "recursive delete (long flag)"),
(r'\bchmod\s+(-[^\s]*\s+)*777\b', "world-writable permissions"),
(r'\bchmod\s+--recursive\b.*777', "recursive world-writable (long flag)"),
(r'\bchown\s+(-[^\s]*)?R\s+root', "recursive chown to root"),
(r'\bchown\s+--recursive\b.*root', "recursive chown to root (long flag)"),
(r'\bmkfs\b', "format filesystem"),
(r'\bdd\s+.*if=', "disk copy"),
(r'>\s*/dev/sd', "write to block device"),
(r'\bDROP\s+(TABLE|DATABASE)\b', "SQL DROP"),
(r'\bDELETE\s+FROM\b(?!.*\bWHERE\b)', "SQL DELETE without WHERE"),
(r'\bTRUNCATE\s+(TABLE)?\s*\w', "SQL TRUNCATE"),
(r'>\s*/etc/', "overwrite system config"),
(r'\bsystemctl\s+(stop|disable|mask)\b', "stop/disable system service"),
(r'\bkill\s+-9\s+-1\b', "kill all processes"),
(r'\bpkill\s+-9\b', "force kill processes"),
(r':\(\)\s*\{\s*:\s*\|\s*:\s*&\s*\}\s*;\s*:', "fork bomb"),
# Any shell invocation via -c or combined flags like -lc, -ic, etc.
(r'\b(bash|sh|zsh|ksh)\s+-[^\s]*c(\s+|$)', "shell command via -c/-lc flag"),
(r'\b(python[23]?|perl|ruby|node)\s+-[ec]\s+', "script execution via -e/-c flag"),
(r'\b(curl|wget)\b.*\|\s*(ba)?sh\b', "pipe remote content to shell"),
(r'\b(bash|sh|zsh|ksh)\s+<\s*<?\s*\(\s*(curl|wget)\b', "execute remote script via process substitution"),
(r'\btee\b.*(/etc/|/dev/sd|\.ssh/|\.hermes/\.env)', "overwrite system file via tee"),
(r'\bxargs\s+.*\brm\b', "xargs with rm"),
(r'\bfind\b.*-exec\s+(/\S*/)?rm\b', "find -exec rm"),
(r'\bfind\b.*-delete\b', "find -delete"),
# Gateway protection: never start gateway outside systemd management
(r'gateway\s+run\b.*(&\s*$|&\s*;|\bdisown\b|\bsetsid\b)', "start gateway outside systemd (use 'systemctl --user restart hermes-gateway')"),
(r'\bnohup\b.*gateway\s+run\b', "start gateway outside systemd (use 'systemctl --user restart hermes-gateway')"),
]
def _legacy_pattern_key(pattern: str) -> str:
"""Reproduce the old regex-derived approval key for backwards compatibility."""
return pattern.split(r'\b')[1] if r'\b' in pattern else pattern[:20]
_PATTERN_KEY_ALIASES: dict[str, set[str]] = {}
for _pattern, _description in DANGEROUS_PATTERNS:
_legacy_key = _legacy_pattern_key(_pattern)
_canonical_key = _description
_PATTERN_KEY_ALIASES.setdefault(_canonical_key, set()).update({_canonical_key, _legacy_key})
_PATTERN_KEY_ALIASES.setdefault(_legacy_key, set()).update({_legacy_key, _canonical_key})
def _approval_key_aliases(pattern_key: str) -> set[str]:
"""Return all approval keys that should match this pattern.
New approvals use the human-readable description string, but older
command_allowlist entries and session approvals may still contain the
historical regex-derived key.
"""
return _PATTERN_KEY_ALIASES.get(pattern_key, {pattern_key})
# =========================================================================
# Detection
# =========================================================================
def detect_dangerous_command(command: str) -> tuple:
"""Check if a command matches any dangerous patterns.
Returns:
(is_dangerous, pattern_key, description) or (False, None, None)
"""
command_lower = command.lower()
for pattern, description in DANGEROUS_PATTERNS:
if re.search(pattern, command_lower, re.IGNORECASE | re.DOTALL):
pattern_key = description
return (True, pattern_key, description)
return (False, None, None)
# =========================================================================
# Per-session approval state (thread-safe)
# =========================================================================
_lock = threading.Lock()
_pending: dict[str, dict] = {}
_session_approved: dict[str, set] = {}
_permanent_approved: set = set()
def submit_pending(session_key: str, approval: dict):
"""Store a pending approval request for a session."""
with _lock:
_pending[session_key] = approval
def pop_pending(session_key: str) -> Optional[dict]:
"""Retrieve and remove a pending approval for a session."""
with _lock:
return _pending.pop(session_key, None)
def has_pending(session_key: str) -> bool:
"""Check if a session has a pending approval request."""
with _lock:
return session_key in _pending
def approve_session(session_key: str, pattern_key: str):
"""Approve a pattern for this session only."""
with _lock:
_session_approved.setdefault(session_key, set()).add(pattern_key)
def is_approved(session_key: str, pattern_key: str) -> bool:
"""Check if a pattern is approved (session-scoped or permanent).
Accept both the current canonical key and the legacy regex-derived key so
existing command_allowlist entries continue to work after key migrations.
"""
aliases = _approval_key_aliases(pattern_key)
with _lock:
if any(alias in _permanent_approved for alias in aliases):
return True
session_approvals = _session_approved.get(session_key, set())
return any(alias in session_approvals for alias in aliases)
def approve_permanent(pattern_key: str):
"""Add a pattern to the permanent allowlist."""
with _lock:
_permanent_approved.add(pattern_key)
def load_permanent(patterns: set):
"""Bulk-load permanent allowlist entries from config."""
with _lock:
_permanent_approved.update(patterns)
def clear_session(session_key: str):
"""Clear all approvals and pending requests for a session."""
with _lock:
_session_approved.pop(session_key, None)
_pending.pop(session_key, None)
# =========================================================================
# Config persistence for permanent allowlist
# =========================================================================
def load_permanent_allowlist() -> set:
"""Load permanently allowed command patterns from config.
Also syncs them into the approval module so is_approved() works for
patterns added via 'always' in a previous session.
"""
try:
from hermes_cli.config import load_config
config = load_config()
patterns = set(config.get("command_allowlist", []) or [])
if patterns:
load_permanent(patterns)
return patterns
except Exception:
return set()
def save_permanent_allowlist(patterns: set):
"""Save permanently allowed command patterns to config."""
try:
from hermes_cli.config import load_config, save_config
config = load_config()
config["command_allowlist"] = list(patterns)
save_config(config)
except Exception as e:
logger.warning("Could not save allowlist: %s", e)
# =========================================================================
# Approval prompting + orchestration
# =========================================================================
def prompt_dangerous_approval(command: str, description: str,
timeout_seconds: int = 60,
allow_permanent: bool = True,
approval_callback=None) -> str:
"""Prompt the user to approve a dangerous command (CLI only).
Args:
allow_permanent: When False, hide the [a]lways option (used when
tirith warnings are present, since broad permanent allowlisting
is inappropriate for content-level security findings).
approval_callback: Optional callback registered by the CLI for
prompt_toolkit integration. Signature:
(command, description, *, allow_permanent=True) -> str.
Returns: 'once', 'session', 'always', or 'deny'
"""
if approval_callback is not None:
try:
return approval_callback(command, description,
allow_permanent=allow_permanent)
except Exception:
return "deny"
os.environ["HERMES_SPINNER_PAUSE"] = "1"
try:
while True:
print()
print(f" ⚠️ DANGEROUS COMMAND: {description}")
print(f" {command}")
print()
if allow_permanent:
print(" [o]nce | [s]ession | [a]lways | [d]eny")
else:
print(" [o]nce | [s]ession | [d]eny")
print()
sys.stdout.flush()
result = {"choice": ""}
def get_input():
try:
prompt = " Choice [o/s/a/D]: " if allow_permanent else " Choice [o/s/D]: "
result["choice"] = input(prompt).strip().lower()
except (EOFError, OSError):
result["choice"] = ""
thread = threading.Thread(target=get_input, daemon=True)
thread.start()
thread.join(timeout=timeout_seconds)
if thread.is_alive():
print("\n ⏱ Timeout - denying command")
return "deny"
choice = result["choice"]
if choice in ('o', 'once'):
print(" ✓ Allowed once")
return "once"
elif choice in ('s', 'session'):
print(" ✓ Allowed for this session")
return "session"
elif choice in ('a', 'always'):
if not allow_permanent:
print(" ✓ Allowed for this session")
return "session"
print(" ✓ Added to permanent allowlist")
return "always"
else:
print(" ✗ Denied")
return "deny"
except (EOFError, KeyboardInterrupt):
print("\n ✗ Cancelled")
return "deny"
finally:
if "HERMES_SPINNER_PAUSE" in os.environ:
del os.environ["HERMES_SPINNER_PAUSE"]
print()
sys.stdout.flush()
def _normalize_approval_mode(mode) -> str:
"""Normalize approval mode values loaded from YAML/config.
YAML 1.1 treats bare words like `off` as booleans, so a config entry like
`approvals:\n mode: off` is parsed as False unless quoted. Treat that as the
intended string mode instead of falling back to manual approvals.
"""
if isinstance(mode, bool):
return "off" if mode is False else "manual"
if isinstance(mode, str):
normalized = mode.strip().lower()
return normalized or "manual"
return "manual"
def _get_approval_mode() -> str:
"""Read the approval mode from config. Returns 'manual', 'smart', or 'off'."""
try:
from hermes_cli.config import load_config
config = load_config()
mode = config.get("approvals", {}).get("mode", "manual")
return _normalize_approval_mode(mode)
except Exception:
return "manual"
def _smart_approve(command: str, description: str) -> str:
"""Use the auxiliary LLM to assess risk and decide approval.
Returns 'approve' if the LLM determines the command is safe,
'deny' if genuinely dangerous, or 'escalate' if uncertain.
Inspired by OpenAI Codex's Smart Approvals guardian subagent
(openai/codex#13860).
"""
try:
from agent.auxiliary_client import get_text_auxiliary_client, auxiliary_max_tokens_param
client, model = get_text_auxiliary_client(task="approval")
if not client or not model:
logger.debug("Smart approvals: no aux client available, escalating")
return "escalate"
prompt = f"""You are a security reviewer for an AI coding agent. A terminal command was flagged by pattern matching as potentially dangerous.
Command: {command}
Flagged reason: {description}
Assess the ACTUAL risk of this command. Many flagged commands are false positives for example, `python -c "print('hello')"` is flagged as "script execution via -c flag" but is completely harmless.
Rules:
- APPROVE if the command is clearly safe (benign script execution, safe file operations, development tools, package installs, git operations, etc.)
- DENY if the command could genuinely damage the system (recursive delete of important paths, overwriting system files, fork bombs, wiping disks, dropping databases, etc.)
- ESCALATE if you're uncertain
Respond with exactly one word: APPROVE, DENY, or ESCALATE"""
response = client.chat.completions.create(
model=model,
messages=[{"role": "user", "content": prompt}],
**auxiliary_max_tokens_param(16),
temperature=0,
)
answer = (response.choices[0].message.content or "").strip().upper()
if "APPROVE" in answer:
return "approve"
elif "DENY" in answer:
return "deny"
else:
return "escalate"
except Exception as e:
logger.debug("Smart approvals: LLM call failed (%s), escalating", e)
return "escalate"
def check_dangerous_command(command: str, env_type: str,
approval_callback=None) -> dict:
"""Check if a command is dangerous and handle approval.
This is the main entry point called by terminal_tool before executing
any command. It orchestrates detection, session checks, and prompting.
Args:
command: The shell command to check.
env_type: Terminal backend type ('local', 'ssh', 'docker', etc.).
approval_callback: Optional CLI callback for interactive prompts.
Returns:
{"approved": True/False, "message": str or None, ...}
"""
if env_type in ("docker", "singularity", "modal", "daytona"):
return {"approved": True, "message": None}
# --yolo: bypass all approval prompts
if os.getenv("HERMES_YOLO_MODE"):
return {"approved": True, "message": None}
is_dangerous, pattern_key, description = detect_dangerous_command(command)
if not is_dangerous:
return {"approved": True, "message": None}
session_key = os.getenv("HERMES_SESSION_KEY", "default")
if is_approved(session_key, pattern_key):
return {"approved": True, "message": None}
is_cli = os.getenv("HERMES_INTERACTIVE")
is_gateway = os.getenv("HERMES_GATEWAY_SESSION")
if not is_cli and not is_gateway:
return {"approved": True, "message": None}
if is_gateway or os.getenv("HERMES_EXEC_ASK"):
submit_pending(session_key, {
"command": command,
"pattern_key": pattern_key,
"description": description,
})
return {
"approved": False,
"pattern_key": pattern_key,
"status": "approval_required",
"command": command,
"description": description,
"message": (
f"⚠️ This command is potentially dangerous ({description}). "
f"Asking the user for approval.\n\n**Command:**\n```\n{command}\n```"
),
}
choice = prompt_dangerous_approval(command, description,
approval_callback=approval_callback)
if choice == "deny":
return {
"approved": False,
"message": f"BLOCKED: User denied this potentially dangerous command (matched '{description}' pattern). Do NOT retry this command - the user has explicitly rejected it.",
"pattern_key": pattern_key,
"description": description,
}
if choice == "session":
approve_session(session_key, pattern_key)
elif choice == "always":
approve_session(session_key, pattern_key)
approve_permanent(pattern_key)
save_permanent_allowlist(_permanent_approved)
return {"approved": True, "message": None}
# =========================================================================
# Combined pre-exec guard (tirith + dangerous command detection)
# =========================================================================
def check_all_command_guards(command: str, env_type: str,
approval_callback=None) -> dict:
"""Run all pre-exec security checks and return a single approval decision.
Gathers findings from tirith and dangerous-command detection, then
presents them as a single combined approval request. This prevents
a gateway force=True replay from bypassing one check when only the
other was shown to the user.
"""
# Skip containers for both checks
if env_type in ("docker", "singularity", "modal", "daytona"):
return {"approved": True, "message": None}
# --yolo or approvals.mode=off: bypass all approval prompts
approval_mode = _get_approval_mode()
if os.getenv("HERMES_YOLO_MODE") or approval_mode == "off":
return {"approved": True, "message": None}
is_cli = os.getenv("HERMES_INTERACTIVE")
is_gateway = os.getenv("HERMES_GATEWAY_SESSION")
is_ask = os.getenv("HERMES_EXEC_ASK")
# Preserve the existing non-interactive behavior: outside CLI/gateway/ask
# flows, we do not block on approvals and we skip external guard work.
if not is_cli and not is_gateway and not is_ask:
return {"approved": True, "message": None}
# --- Phase 1: Gather findings from both checks ---
# Tirith check — wrapper guarantees no raise for expected failures.
# Only catch ImportError (module not installed).
tirith_result = {"action": "allow", "findings": [], "summary": ""}
try:
from tools.tirith_security import check_command_security
tirith_result = check_command_security(command)
except ImportError:
pass # tirith module not installed — allow
# Dangerous command check (detection only, no approval)
is_dangerous, pattern_key, description = detect_dangerous_command(command)
# --- Phase 2: Decide ---
# If tirith blocks, block immediately (no approval possible)
if tirith_result["action"] == "block":
summary = tirith_result.get("summary") or "security issue detected"
return {
"approved": False,
"message": f"BLOCKED: Command blocked by security scan ({summary}). Do NOT retry.",
}
# Collect warnings that need approval
warnings = [] # list of (pattern_key, description, is_tirith)
session_key = os.getenv("HERMES_SESSION_KEY", "default")
if tirith_result["action"] == "warn":
findings = tirith_result.get("findings") or []
rule_id = findings[0].get("rule_id", "unknown") if findings else "unknown"
tirith_key = f"tirith:{rule_id}"
tirith_desc = f"Security scan: {tirith_result.get('summary') or 'security warning detected'}"
if not is_approved(session_key, tirith_key):
warnings.append((tirith_key, tirith_desc, True))
if is_dangerous:
if not is_approved(session_key, pattern_key):
warnings.append((pattern_key, description, False))
# Nothing to warn about
if not warnings:
return {"approved": True, "message": None}
# --- Phase 2.5: Smart approval (auxiliary LLM risk assessment) ---
# When approvals.mode=smart, ask the aux LLM before prompting the user.
# Inspired by OpenAI Codex's Smart Approvals guardian subagent
# (openai/codex#13860).
if approval_mode == "smart":
combined_desc_for_llm = "; ".join(desc for _, desc, _ in warnings)
verdict = _smart_approve(command, combined_desc_for_llm)
if verdict == "approve":
# Auto-approve and grant session-level approval for these patterns
for key, _, _ in warnings:
approve_session(session_key, key)
logger.debug("Smart approval: auto-approved '%s' (%s)",
command[:60], combined_desc_for_llm)
return {"approved": True, "message": None,
"smart_approved": True}
elif verdict == "deny":
combined_desc_for_llm = "; ".join(desc for _, desc, _ in warnings)
return {
"approved": False,
"message": f"BLOCKED by smart approval: {combined_desc_for_llm}. "
"The command was assessed as genuinely dangerous. Do NOT retry.",
"smart_denied": True,
}
# verdict == "escalate" → fall through to manual prompt
# --- Phase 3: Approval ---
# Combine descriptions for a single approval prompt
combined_desc = "; ".join(desc for _, desc, _ in warnings)
primary_key = warnings[0][0]
all_keys = [key for key, _, _ in warnings]
has_tirith = any(is_t for _, _, is_t in warnings)
# Gateway/async: single approval_required with combined description
# Store all pattern keys so gateway replay approves all of them
if is_gateway or is_ask:
submit_pending(session_key, {
"command": command,
"pattern_key": primary_key, # backward compat
"pattern_keys": all_keys, # all keys for replay
"description": combined_desc,
})
return {
"approved": False,
"pattern_key": primary_key,
"status": "approval_required",
"command": command,
"description": combined_desc,
"message": (
f"⚠️ {combined_desc}. Asking the user for approval.\n\n**Command:**\n```\n{command}\n```"
),
}
# CLI interactive: single combined prompt
# Hide [a]lways when any tirith warning is present
choice = prompt_dangerous_approval(command, combined_desc,
allow_permanent=not has_tirith,
approval_callback=approval_callback)
if choice == "deny":
return {
"approved": False,
"message": "BLOCKED: User denied. Do NOT retry.",
"pattern_key": primary_key,
"description": combined_desc,
}
# Persist approval for each warning individually
for key, _, is_tirith in warnings:
if choice == "session" or (choice == "always" and is_tirith):
# tirith: session only (no permanent broad allowlisting)
approve_session(session_key, key)
elif choice == "always":
# dangerous patterns: permanent allowed
approve_session(session_key, key)
approve_permanent(key)
save_permanent_allowlist(_permanent_approved)
return {"approved": True, "message": None}

View file

@ -0,0 +1,10 @@
"""Cloud browser provider abstraction.
Import the ABC so callers can do::
from tools.browser_providers import CloudBrowserProvider
"""
from tools.browser_providers.base import CloudBrowserProvider
__all__ = ["CloudBrowserProvider"]

View file

@ -0,0 +1,59 @@
"""Abstract base class for cloud browser providers."""
from abc import ABC, abstractmethod
from typing import Dict
class CloudBrowserProvider(ABC):
"""Interface for cloud browser backends (Browserbase, Steel, etc.).
Implementations live in sibling modules and are registered in
``browser_tool._PROVIDER_REGISTRY``. The user selects a provider via
``hermes setup`` / ``hermes tools``; the choice is persisted as
``config["browser"]["cloud_provider"]``.
"""
@abstractmethod
def provider_name(self) -> str:
"""Short, human-readable name shown in logs and diagnostics."""
@abstractmethod
def is_configured(self) -> bool:
"""Return True when all required env vars / credentials are present.
Called at tool-registration time (``check_browser_requirements``) to
gate availability. Must be cheap no network calls.
"""
@abstractmethod
def create_session(self, task_id: str) -> Dict[str, object]:
"""Create a cloud browser session and return session metadata.
Must return a dict with at least::
{
"session_name": str, # unique name for agent-browser --session
"bb_session_id": str, # provider session ID (for close/cleanup)
"cdp_url": str, # CDP websocket URL
"features": dict, # feature flags that were enabled
}
``bb_session_id`` is a legacy key name kept for backward compat with
the rest of browser_tool.py it holds the provider's session ID
regardless of which provider is in use.
"""
@abstractmethod
def close_session(self, session_id: str) -> bool:
"""Release / terminate a cloud session by its provider session ID.
Returns True on success, False on failure. Should not raise.
"""
@abstractmethod
def emergency_cleanup(self, session_id: str) -> None:
"""Best-effort session teardown during process exit.
Called from atexit / signal handlers. Must tolerate missing
credentials, network errors, etc. log and move on.
"""

View file

@ -0,0 +1,107 @@
"""Browser Use cloud browser provider."""
import logging
import os
import uuid
from typing import Dict
import requests
from tools.browser_providers.base import CloudBrowserProvider
logger = logging.getLogger(__name__)
_BASE_URL = "https://api.browser-use.com/api/v2"
class BrowserUseProvider(CloudBrowserProvider):
"""Browser Use (https://browser-use.com) cloud browser backend."""
def provider_name(self) -> str:
return "Browser Use"
def is_configured(self) -> bool:
return bool(os.environ.get("BROWSER_USE_API_KEY"))
# ------------------------------------------------------------------
# Session lifecycle
# ------------------------------------------------------------------
def _headers(self) -> Dict[str, str]:
api_key = os.environ.get("BROWSER_USE_API_KEY")
if not api_key:
raise ValueError(
"BROWSER_USE_API_KEY environment variable is required. "
"Get your key at https://browser-use.com"
)
return {
"Content-Type": "application/json",
"X-Browser-Use-API-Key": api_key,
}
def create_session(self, task_id: str) -> Dict[str, object]:
response = requests.post(
f"{_BASE_URL}/browsers",
headers=self._headers(),
json={},
timeout=30,
)
if not response.ok:
raise RuntimeError(
f"Failed to create Browser Use session: "
f"{response.status_code} {response.text}"
)
session_data = response.json()
session_name = f"hermes_{task_id}_{uuid.uuid4().hex[:8]}"
logger.info("Created Browser Use session %s", session_name)
return {
"session_name": session_name,
"bb_session_id": session_data["id"],
"cdp_url": session_data["cdpUrl"],
"features": {"browser_use": True},
}
def close_session(self, session_id: str) -> bool:
try:
response = requests.patch(
f"{_BASE_URL}/browsers/{session_id}",
headers=self._headers(),
json={"action": "stop"},
timeout=10,
)
if response.status_code in (200, 201, 204):
logger.debug("Successfully closed Browser Use session %s", session_id)
return True
else:
logger.warning(
"Failed to close Browser Use session %s: HTTP %s - %s",
session_id,
response.status_code,
response.text[:200],
)
return False
except Exception as e:
logger.error("Exception closing Browser Use session %s: %s", session_id, e)
return False
def emergency_cleanup(self, session_id: str) -> None:
api_key = os.environ.get("BROWSER_USE_API_KEY")
if not api_key:
logger.warning("Cannot emergency-cleanup Browser Use session %s — missing credentials", session_id)
return
try:
requests.patch(
f"{_BASE_URL}/browsers/{session_id}",
headers={
"Content-Type": "application/json",
"X-Browser-Use-API-Key": api_key,
},
json={"action": "stop"},
timeout=5,
)
except Exception as e:
logger.debug("Emergency cleanup failed for Browser Use session %s: %s", session_id, e)

View file

@ -0,0 +1,206 @@
"""Browserbase cloud browser provider."""
import logging
import os
import uuid
from typing import Dict
import requests
from tools.browser_providers.base import CloudBrowserProvider
logger = logging.getLogger(__name__)
class BrowserbaseProvider(CloudBrowserProvider):
"""Browserbase (https://browserbase.com) cloud browser backend."""
def provider_name(self) -> str:
return "Browserbase"
def is_configured(self) -> bool:
return bool(
os.environ.get("BROWSERBASE_API_KEY")
and os.environ.get("BROWSERBASE_PROJECT_ID")
)
# ------------------------------------------------------------------
# Session lifecycle
# ------------------------------------------------------------------
def _get_config(self) -> Dict[str, str]:
api_key = os.environ.get("BROWSERBASE_API_KEY")
project_id = os.environ.get("BROWSERBASE_PROJECT_ID")
if not api_key or not project_id:
raise ValueError(
"BROWSERBASE_API_KEY and BROWSERBASE_PROJECT_ID environment "
"variables are required. Get your credentials at "
"https://browserbase.com"
)
return {"api_key": api_key, "project_id": project_id}
def create_session(self, task_id: str) -> Dict[str, object]:
config = self._get_config()
# Optional env-var knobs
enable_proxies = os.environ.get("BROWSERBASE_PROXIES", "true").lower() != "false"
enable_advanced_stealth = os.environ.get("BROWSERBASE_ADVANCED_STEALTH", "false").lower() == "true"
enable_keep_alive = os.environ.get("BROWSERBASE_KEEP_ALIVE", "true").lower() != "false"
custom_timeout_ms = os.environ.get("BROWSERBASE_SESSION_TIMEOUT")
features_enabled = {
"basic_stealth": True,
"proxies": False,
"advanced_stealth": False,
"keep_alive": False,
"custom_timeout": False,
}
session_config: Dict[str, object] = {"projectId": config["project_id"]}
if enable_keep_alive:
session_config["keepAlive"] = True
if custom_timeout_ms:
try:
timeout_val = int(custom_timeout_ms)
if timeout_val > 0:
session_config["timeout"] = timeout_val
except ValueError:
logger.warning("Invalid BROWSERBASE_SESSION_TIMEOUT value: %s", custom_timeout_ms)
if enable_proxies:
session_config["proxies"] = True
if enable_advanced_stealth:
session_config["browserSettings"] = {"advancedStealth": True}
# --- Create session via API ---
headers = {
"Content-Type": "application/json",
"X-BB-API-Key": config["api_key"],
}
response = requests.post(
"https://api.browserbase.com/v1/sessions",
headers=headers,
json=session_config,
timeout=30,
)
proxies_fallback = False
keepalive_fallback = False
# Handle 402 — paid features unavailable
if response.status_code == 402:
if enable_keep_alive:
keepalive_fallback = True
logger.warning(
"keepAlive may require paid plan (402), retrying without it. "
"Sessions may timeout during long operations."
)
session_config.pop("keepAlive", None)
response = requests.post(
"https://api.browserbase.com/v1/sessions",
headers=headers,
json=session_config,
timeout=30,
)
if response.status_code == 402 and enable_proxies:
proxies_fallback = True
logger.warning(
"Proxies unavailable (402), retrying without proxies. "
"Bot detection may be less effective."
)
session_config.pop("proxies", None)
response = requests.post(
"https://api.browserbase.com/v1/sessions",
headers=headers,
json=session_config,
timeout=30,
)
if not response.ok:
raise RuntimeError(
f"Failed to create Browserbase session: "
f"{response.status_code} {response.text}"
)
session_data = response.json()
session_name = f"hermes_{task_id}_{uuid.uuid4().hex[:8]}"
if enable_proxies and not proxies_fallback:
features_enabled["proxies"] = True
if enable_advanced_stealth:
features_enabled["advanced_stealth"] = True
if enable_keep_alive and not keepalive_fallback:
features_enabled["keep_alive"] = True
if custom_timeout_ms and "timeout" in session_config:
features_enabled["custom_timeout"] = True
feature_str = ", ".join(k for k, v in features_enabled.items() if v)
logger.info("Created Browserbase session %s with features: %s", session_name, feature_str)
return {
"session_name": session_name,
"bb_session_id": session_data["id"],
"cdp_url": session_data["connectUrl"],
"features": features_enabled,
}
def close_session(self, session_id: str) -> bool:
try:
config = self._get_config()
except ValueError:
logger.warning("Cannot close Browserbase session %s — missing credentials", session_id)
return False
try:
response = requests.post(
f"https://api.browserbase.com/v1/sessions/{session_id}",
headers={
"X-BB-API-Key": config["api_key"],
"Content-Type": "application/json",
},
json={
"projectId": config["project_id"],
"status": "REQUEST_RELEASE",
},
timeout=10,
)
if response.status_code in (200, 201, 204):
logger.debug("Successfully closed Browserbase session %s", session_id)
return True
else:
logger.warning(
"Failed to close session %s: HTTP %s - %s",
session_id,
response.status_code,
response.text[:200],
)
return False
except Exception as e:
logger.error("Exception closing Browserbase session %s: %s", session_id, e)
return False
def emergency_cleanup(self, session_id: str) -> None:
api_key = os.environ.get("BROWSERBASE_API_KEY")
project_id = os.environ.get("BROWSERBASE_PROJECT_ID")
if not api_key or not project_id:
logger.warning("Cannot emergency-cleanup Browserbase session %s — missing credentials", session_id)
return
try:
requests.post(
f"https://api.browserbase.com/v1/sessions/{session_id}",
headers={
"X-BB-API-Key": api_key,
"Content-Type": "application/json",
},
json={
"projectId": project_id,
"status": "REQUEST_RELEASE",
},
timeout=5,
)
except Exception as e:
logger.debug("Emergency cleanup failed for Browserbase session %s: %s", session_id, e)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,86 @@
import json
import os
import asyncio
import socket
from browser_use import Agent, Browser, ChatOpenAI
from tools.registry import registry
async def run_browser_task(task):
browser_host = "browser"
browser_port = 9222
BROWSER_VIEW_URL = os.getenv("BROWSER_VIEW_URL", "")
try:
browser_ip = socket.gethostbyname(browser_host)
cdp_url = f"http://{browser_ip}:{browser_port}"
except Exception:
cdp_url = f"http://{browser_host}:{browser_port}"
browser = Browser(cdp_url=cdp_url)
llm = ChatOpenAI(
model=os.getenv("MODEL_DEFAULT", "qwen3.5-122b"),
api_key=os.getenv("OPENAI_API_KEY"),
base_url=os.getenv("OPENAI_BASE_URL"),
temperature=0.0,
)
agent = Agent(
task=task,
llm=llm,
browser=browser,
use_vision=False
)
try:
history = await agent.run()
final_result = history.final_result()
response = {
"success": True,
"result": final_result,
"browser_view": BROWSER_VIEW_URL
}
return json.dumps(response, ensure_ascii=False)
except Exception as e:
return json.dumps({
"success": False,
"error": f"Browser automation failed: {str(e)}"
}, ensure_ascii=False)
finally:
if browser:
try:
await browser.close()
except Exception:
pass
registry.register(
name="internet_browser",
toolset="browse_cmd",
schema={
"name": "internet_browser",
"description": (
"ГЛАВНЫЙ ИНСТРУМЕНТ ДЛЯ ВЕБ-СЕРФИНГА. Вызывай этот инструмент НАПРЯМУЮ (через стандартный tool call/function call). "
"КАТЕГОРИЧЕСКИ ЗАПРЕЩЕНО использовать `execute_code` или `delegate_task` для работы с браузером. "
"Не пиши Python-скрипты! Просто передай в этот инструмент параметр `task`. "
"Используй для любых задач в интернете: поиск товаров (Wildberries, Ozon), чтение статей, клики, навигация."
),
"parameters": {
"type": "object",
"properties": {
"task": {
"type": "string",
"description": "Подробная задача на естественном языке. Например: 'Зайди на wildberries.ru, найди черную футболку и верни цену'."
}
},
"required": ["task"]
}
},
handler=lambda args, **kw: asyncio.run(run_browser_task(args.get("task"))),
emoji="🌐",
)

View file

@ -0,0 +1,548 @@
"""
Checkpoint Manager Transparent filesystem snapshots via shadow git repos.
Creates automatic snapshots of working directories before file-mutating
operations (write_file, patch), triggered once per conversation turn.
Provides rollback to any previous checkpoint.
This is NOT a tool the LLM never sees it. It's transparent infrastructure
controlled by the ``checkpoints`` config flag or ``--checkpoints`` CLI flag.
Architecture:
~/.hermes/checkpoints/{sha256(abs_dir)[:16]}/ shadow git repo
HEAD, refs/, objects/ standard git internals
HERMES_WORKDIR original dir path
info/exclude default excludes
The shadow repo uses GIT_DIR + GIT_WORK_TREE so no git state leaks
into the user's project directory.
"""
import hashlib
import logging
import os
import shutil
import subprocess
import time
from pathlib import Path
from typing import Dict, List, Optional, Set
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
CHECKPOINT_BASE = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "checkpoints"
DEFAULT_EXCLUDES = [
"node_modules/",
"dist/",
"build/",
".env",
".env.*",
".env.local",
".env.*.local",
"__pycache__/",
"*.pyc",
"*.pyo",
".DS_Store",
"*.log",
".cache/",
".next/",
".nuxt/",
"coverage/",
".pytest_cache/",
".venv/",
"venv/",
".git/",
]
# Git subprocess timeout (seconds).
_GIT_TIMEOUT: int = max(10, min(60, int(os.getenv("HERMES_CHECKPOINT_TIMEOUT", "30"))))
# Max files to snapshot — skip huge directories to avoid slowdowns.
_MAX_FILES = 50_000
# ---------------------------------------------------------------------------
# Shadow repo helpers
# ---------------------------------------------------------------------------
def _shadow_repo_path(working_dir: str) -> Path:
"""Deterministic shadow repo path: sha256(abs_path)[:16]."""
abs_path = str(Path(working_dir).resolve())
dir_hash = hashlib.sha256(abs_path.encode()).hexdigest()[:16]
return CHECKPOINT_BASE / dir_hash
def _git_env(shadow_repo: Path, working_dir: str) -> dict:
"""Build env dict that redirects git to the shadow repo."""
env = os.environ.copy()
env["GIT_DIR"] = str(shadow_repo)
env["GIT_WORK_TREE"] = str(Path(working_dir).resolve())
env.pop("GIT_INDEX_FILE", None)
env.pop("GIT_NAMESPACE", None)
env.pop("GIT_ALTERNATE_OBJECT_DIRECTORIES", None)
return env
def _run_git(
args: List[str],
shadow_repo: Path,
working_dir: str,
timeout: int = _GIT_TIMEOUT,
allowed_returncodes: Optional[Set[int]] = None,
) -> tuple:
"""Run a git command against the shadow repo. Returns (ok, stdout, stderr).
``allowed_returncodes`` suppresses error logging for known/expected non-zero
exits while preserving the normal ``ok = (returncode == 0)`` contract.
Example: ``git diff --cached --quiet`` returns 1 when changes exist.
"""
env = _git_env(shadow_repo, working_dir)
cmd = ["git"] + list(args)
allowed_returncodes = allowed_returncodes or set()
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=timeout,
env=env,
cwd=str(Path(working_dir).resolve()),
)
ok = result.returncode == 0
stdout = result.stdout.strip()
stderr = result.stderr.strip()
if not ok and result.returncode not in allowed_returncodes:
logger.error(
"Git command failed: %s (rc=%d) stderr=%s",
" ".join(cmd), result.returncode, stderr,
)
return ok, stdout, stderr
except subprocess.TimeoutExpired:
msg = f"git timed out after {timeout}s: {' '.join(cmd)}"
logger.error(msg, exc_info=True)
return False, "", msg
except FileNotFoundError:
logger.error("Git executable not found: %s", " ".join(cmd), exc_info=True)
return False, "", "git not found"
except Exception as exc:
logger.error("Unexpected git error running %s: %s", " ".join(cmd), exc, exc_info=True)
return False, "", str(exc)
def _init_shadow_repo(shadow_repo: Path, working_dir: str) -> Optional[str]:
"""Initialise shadow repo if needed. Returns error string or None."""
if (shadow_repo / "HEAD").exists():
return None
shadow_repo.mkdir(parents=True, exist_ok=True)
ok, _, err = _run_git(["init"], shadow_repo, working_dir)
if not ok:
return f"Shadow repo init failed: {err}"
_run_git(["config", "user.email", "hermes@local"], shadow_repo, working_dir)
_run_git(["config", "user.name", "Hermes Checkpoint"], shadow_repo, working_dir)
info_dir = shadow_repo / "info"
info_dir.mkdir(exist_ok=True)
(info_dir / "exclude").write_text(
"\n".join(DEFAULT_EXCLUDES) + "\n", encoding="utf-8"
)
(shadow_repo / "HERMES_WORKDIR").write_text(
str(Path(working_dir).resolve()) + "\n", encoding="utf-8"
)
logger.debug("Initialised checkpoint repo at %s for %s", shadow_repo, working_dir)
return None
def _dir_file_count(path: str) -> int:
"""Quick file count estimate (stops early if over _MAX_FILES)."""
count = 0
try:
for _ in Path(path).rglob("*"):
count += 1
if count > _MAX_FILES:
return count
except (PermissionError, OSError):
pass
return count
# ---------------------------------------------------------------------------
# CheckpointManager
# ---------------------------------------------------------------------------
class CheckpointManager:
"""Manages automatic filesystem checkpoints.
Designed to be owned by AIAgent. Call ``new_turn()`` at the start of
each conversation turn and ``ensure_checkpoint(dir, reason)`` before
any file-mutating tool call. The manager deduplicates so at most one
snapshot is taken per directory per turn.
Parameters
----------
enabled : bool
Master switch (from config / CLI flag).
max_snapshots : int
Keep at most this many checkpoints per directory.
"""
def __init__(self, enabled: bool = False, max_snapshots: int = 50):
self.enabled = enabled
self.max_snapshots = max_snapshots
self._checkpointed_dirs: Set[str] = set()
self._git_available: Optional[bool] = None # lazy probe
# ------------------------------------------------------------------
# Turn lifecycle
# ------------------------------------------------------------------
def new_turn(self) -> None:
"""Reset per-turn dedup. Call at the start of each agent iteration."""
self._checkpointed_dirs.clear()
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def ensure_checkpoint(self, working_dir: str, reason: str = "auto") -> bool:
"""Take a checkpoint if enabled and not already done this turn.
Returns True if a checkpoint was taken, False otherwise.
Never raises all errors are silently logged.
"""
if not self.enabled:
return False
# Lazy git probe
if self._git_available is None:
self._git_available = shutil.which("git") is not None
if not self._git_available:
logger.debug("Checkpoints disabled: git not found")
if not self._git_available:
return False
abs_dir = str(Path(working_dir).resolve())
# Skip root, home, and other overly broad directories
if abs_dir in ("/", str(Path.home())):
logger.debug("Checkpoint skipped: directory too broad (%s)", abs_dir)
return False
# Already checkpointed this turn?
if abs_dir in self._checkpointed_dirs:
return False
self._checkpointed_dirs.add(abs_dir)
try:
return self._take(abs_dir, reason)
except Exception as e:
logger.debug("Checkpoint failed (non-fatal): %s", e)
return False
def list_checkpoints(self, working_dir: str) -> List[Dict]:
"""List available checkpoints for a directory.
Returns a list of dicts with keys: hash, short_hash, timestamp, reason,
files_changed, insertions, deletions. Most recent first.
"""
abs_dir = str(Path(working_dir).resolve())
shadow = _shadow_repo_path(abs_dir)
if not (shadow / "HEAD").exists():
return []
ok, stdout, _ = _run_git(
["log", "--format=%H|%h|%aI|%s", "-n", str(self.max_snapshots)],
shadow, abs_dir,
)
if not ok or not stdout:
return []
results = []
for line in stdout.splitlines():
parts = line.split("|", 3)
if len(parts) == 4:
entry = {
"hash": parts[0],
"short_hash": parts[1],
"timestamp": parts[2],
"reason": parts[3],
"files_changed": 0,
"insertions": 0,
"deletions": 0,
}
# Get diffstat for this commit
stat_ok, stat_out, _ = _run_git(
["diff", "--shortstat", f"{parts[0]}~1", parts[0]],
shadow, abs_dir,
allowed_returncodes={128, 129}, # first commit has no parent
)
if stat_ok and stat_out:
self._parse_shortstat(stat_out, entry)
results.append(entry)
return results
@staticmethod
def _parse_shortstat(stat_line: str, entry: Dict) -> None:
"""Parse git --shortstat output into entry dict."""
import re
m = re.search(r'(\d+) file', stat_line)
if m:
entry["files_changed"] = int(m.group(1))
m = re.search(r'(\d+) insertion', stat_line)
if m:
entry["insertions"] = int(m.group(1))
m = re.search(r'(\d+) deletion', stat_line)
if m:
entry["deletions"] = int(m.group(1))
def diff(self, working_dir: str, commit_hash: str) -> Dict:
"""Show diff between a checkpoint and the current working tree.
Returns dict with success, diff text, and stat summary.
"""
abs_dir = str(Path(working_dir).resolve())
shadow = _shadow_repo_path(abs_dir)
if not (shadow / "HEAD").exists():
return {"success": False, "error": "No checkpoints exist for this directory"}
# Verify the commit exists
ok, _, err = _run_git(
["cat-file", "-t", commit_hash], shadow, abs_dir,
)
if not ok:
return {"success": False, "error": f"Checkpoint '{commit_hash}' not found"}
# Stage current state to compare against checkpoint
_run_git(["add", "-A"], shadow, abs_dir, timeout=_GIT_TIMEOUT * 2)
# Get stat summary: checkpoint vs current working tree
ok_stat, stat_out, _ = _run_git(
["diff", "--stat", commit_hash, "--cached"],
shadow, abs_dir,
)
# Get actual diff (limited to avoid terminal flood)
ok_diff, diff_out, _ = _run_git(
["diff", commit_hash, "--cached", "--no-color"],
shadow, abs_dir,
)
# Unstage to avoid polluting the shadow repo index
_run_git(["reset", "HEAD", "--quiet"], shadow, abs_dir)
if not ok_stat and not ok_diff:
return {"success": False, "error": "Could not generate diff"}
return {
"success": True,
"stat": stat_out if ok_stat else "",
"diff": diff_out if ok_diff else "",
}
def restore(self, working_dir: str, commit_hash: str, file_path: str = None) -> Dict:
"""Restore files to a checkpoint state.
Uses ``git checkout <hash> -- .`` (or a specific file) which restores
tracked files without moving HEAD safe and reversible.
Parameters
----------
file_path : str, optional
If provided, restore only this file instead of the entire directory.
Returns dict with success/error info.
"""
abs_dir = str(Path(working_dir).resolve())
shadow = _shadow_repo_path(abs_dir)
if not (shadow / "HEAD").exists():
return {"success": False, "error": "No checkpoints exist for this directory"}
# Verify the commit exists
ok, _, err = _run_git(
["cat-file", "-t", commit_hash], shadow, abs_dir,
)
if not ok:
return {"success": False, "error": f"Checkpoint '{commit_hash}' not found", "debug": err or None}
# Take a checkpoint of current state before restoring (so you can undo the undo)
self._take(abs_dir, f"pre-rollback snapshot (restoring to {commit_hash[:8]})")
# Restore — full directory or single file
restore_target = file_path if file_path else "."
ok, stdout, err = _run_git(
["checkout", commit_hash, "--", restore_target],
shadow, abs_dir, timeout=_GIT_TIMEOUT * 2,
)
if not ok:
return {"success": False, "error": f"Restore failed: {err}", "debug": err or None}
# Get info about what was restored
ok2, reason_out, _ = _run_git(
["log", "--format=%s", "-1", commit_hash], shadow, abs_dir,
)
reason = reason_out if ok2 else "unknown"
result = {
"success": True,
"restored_to": commit_hash[:8],
"reason": reason,
"directory": abs_dir,
}
if file_path:
result["file"] = file_path
return result
def get_working_dir_for_path(self, file_path: str) -> str:
"""Resolve a file path to its working directory for checkpointing.
Walks up from the file's parent to find a reasonable project root
(directory containing .git, pyproject.toml, package.json, etc.).
Falls back to the file's parent directory.
"""
path = Path(file_path).resolve()
if path.is_dir():
candidate = path
else:
candidate = path.parent
# Walk up looking for project root markers
markers = {".git", "pyproject.toml", "package.json", "Cargo.toml",
"go.mod", "Makefile", "pom.xml", ".hg", "Gemfile"}
check = candidate
while check != check.parent:
if any((check / m).exists() for m in markers):
return str(check)
check = check.parent
# No project root found — use the file's parent
return str(candidate)
# ------------------------------------------------------------------
# Internal
# ------------------------------------------------------------------
def _take(self, working_dir: str, reason: str) -> bool:
"""Take a snapshot. Returns True on success."""
shadow = _shadow_repo_path(working_dir)
# Init if needed
err = _init_shadow_repo(shadow, working_dir)
if err:
logger.debug("Checkpoint init failed: %s", err)
return False
# Quick size guard — don't try to snapshot enormous directories
if _dir_file_count(working_dir) > _MAX_FILES:
logger.debug("Checkpoint skipped: >%d files in %s", _MAX_FILES, working_dir)
return False
# Stage everything
ok, _, err = _run_git(
["add", "-A"], shadow, working_dir, timeout=_GIT_TIMEOUT * 2,
)
if not ok:
logger.debug("Checkpoint git-add failed: %s", err)
return False
# Check if there's anything to commit
ok_diff, diff_out, _ = _run_git(
["diff", "--cached", "--quiet"],
shadow,
working_dir,
allowed_returncodes={1},
)
if ok_diff:
# No changes to commit
logger.debug("Checkpoint skipped: no changes in %s", working_dir)
return False
# Commit
ok, _, err = _run_git(
["commit", "-m", reason, "--allow-empty-message"],
shadow, working_dir, timeout=_GIT_TIMEOUT * 2,
)
if not ok:
logger.debug("Checkpoint commit failed: %s", err)
return False
logger.debug("Checkpoint taken in %s: %s", working_dir, reason)
# Prune old snapshots
self._prune(shadow, working_dir)
return True
def _prune(self, shadow_repo: Path, working_dir: str) -> None:
"""Keep only the last max_snapshots commits via orphan reset."""
ok, stdout, _ = _run_git(
["rev-list", "--count", "HEAD"], shadow_repo, working_dir,
)
if not ok:
return
try:
count = int(stdout)
except ValueError:
return
if count <= self.max_snapshots:
return
# Get the hash of the commit at the cutoff point
ok, cutoff_hash, _ = _run_git(
["rev-list", "--reverse", "HEAD", "--skip=0",
f"--max-count=1"],
shadow_repo, working_dir,
)
# For simplicity, we don't actually prune — git's pack mechanism
# handles this efficiently, and the objects are small. The log
# listing is already limited by max_snapshots.
# Full pruning would require rebase --onto or filter-branch which
# is fragile for a background feature. We just limit the log view.
logger.debug("Checkpoint repo has %d commits (limit %d)", count, self.max_snapshots)
def format_checkpoint_list(checkpoints: List[Dict], directory: str) -> str:
"""Format checkpoint list for display to user."""
if not checkpoints:
return f"No checkpoints found for {directory}"
lines = [f"📸 Checkpoints for {directory}:\n"]
for i, cp in enumerate(checkpoints, 1):
# Parse ISO timestamp to something readable
ts = cp["timestamp"]
if "T" in ts:
ts = ts.split("T")[1].split("+")[0].split("-")[0][:5] # HH:MM
date = cp["timestamp"].split("T")[0]
ts = f"{date} {ts}"
# Build change summary
files = cp.get("files_changed", 0)
ins = cp.get("insertions", 0)
dele = cp.get("deletions", 0)
if files:
stat = f" ({files} file{'s' if files != 1 else ''}, +{ins}/-{dele})"
else:
stat = ""
lines.append(f" {i}. {cp['short_hash']} {ts} {cp['reason']}{stat}")
lines.append(f"\n /rollback <N> restore to checkpoint N")
lines.append(f" /rollback diff <N> preview changes since checkpoint N")
lines.append(f" /rollback <N> <file> restore a single file from checkpoint N")
return "\n".join(lines)

View file

@ -0,0 +1,141 @@
#!/usr/bin/env python3
"""
Clarify Tool Module - Interactive Clarifying Questions
Allows the agent to present structured multiple-choice questions or open-ended
prompts to the user. In CLI mode, choices are navigable with arrow keys. On
messaging platforms, choices are rendered as a numbered list.
The actual user-interaction logic lives in the platform layer (cli.py for CLI,
gateway/run.py for messaging). This module defines the schema, validation, and
a thin dispatcher that delegates to a platform-provided callback.
"""
import json
from typing import Dict, Any, List, Optional, Callable
# Maximum number of predefined choices the agent can offer.
# A 5th "Other (type your answer)" option is always appended by the UI.
MAX_CHOICES = 4
def clarify_tool(
question: str,
choices: Optional[List[str]] = None,
callback: Optional[Callable] = None,
) -> str:
"""
Ask the user a question, optionally with multiple-choice options.
Args:
question: The question text to present.
choices: Up to 4 predefined answer choices. When omitted the
question is purely open-ended.
callback: Platform-provided function that handles the actual UI
interaction. Signature: callback(question, choices) -> str.
Injected by the agent runner (cli.py / gateway).
Returns:
JSON string with the user's response.
"""
if not question or not question.strip():
return json.dumps({"error": "Question text is required."}, ensure_ascii=False)
question = question.strip()
# Validate and trim choices
if choices is not None:
if not isinstance(choices, list):
return json.dumps({"error": "choices must be a list of strings."}, ensure_ascii=False)
choices = [str(c).strip() for c in choices if str(c).strip()]
if len(choices) > MAX_CHOICES:
choices = choices[:MAX_CHOICES]
if not choices:
choices = None # empty list → open-ended
if callback is None:
return json.dumps(
{"error": "Clarify tool is not available in this execution context."},
ensure_ascii=False,
)
try:
user_response = callback(question, choices)
except Exception as exc:
return json.dumps(
{"error": f"Failed to get user input: {exc}"},
ensure_ascii=False,
)
return json.dumps({
"question": question,
"choices_offered": choices,
"user_response": str(user_response).strip(),
}, ensure_ascii=False)
def check_clarify_requirements() -> bool:
"""Clarify tool has no external requirements -- always available."""
return True
# =============================================================================
# OpenAI Function-Calling Schema
# =============================================================================
CLARIFY_SCHEMA = {
"name": "clarify",
"description": (
"Ask the user a question when you need clarification, feedback, or a "
"decision before proceeding. Supports two modes:\n\n"
"1. **Multiple choice** — provide up to 4 choices. The user picks one "
"or types their own answer via a 5th 'Other' option.\n"
"2. **Open-ended** — omit choices entirely. The user types a free-form "
"response.\n\n"
"Use this tool when:\n"
"- The task is ambiguous and you need the user to choose an approach\n"
"- You want post-task feedback ('How did that work out?')\n"
"- You want to offer to save a skill or update memory\n"
"- A decision has meaningful trade-offs the user should weigh in on\n\n"
"Do NOT use this tool for simple yes/no confirmation of dangerous "
"commands (the terminal tool handles that). Prefer making a reasonable "
"default choice yourself when the decision is low-stakes."
),
"parameters": {
"type": "object",
"properties": {
"question": {
"type": "string",
"description": "The question to present to the user.",
},
"choices": {
"type": "array",
"items": {"type": "string"},
"maxItems": MAX_CHOICES,
"description": (
"Up to 4 answer choices. Omit this parameter entirely to "
"ask an open-ended question. When provided, the UI "
"automatically appends an 'Other (type your answer)' option."
),
},
},
"required": ["question"],
},
}
# --- Registry ---
from tools.registry import registry
registry.register(
name="clarify",
toolset="clarify",
schema=CLARIFY_SCHEMA,
handler=lambda args, **kw: clarify_tool(
question=args.get("question", ""),
choices=args.get("choices"),
callback=kw.get("callback")),
check_fn=check_clarify_requirements,
emoji="",
)

View file

@ -0,0 +1,806 @@
#!/usr/bin/env python3
"""
Code Execution Tool -- Programmatic Tool Calling (PTC)
Lets the LLM write a Python script that calls Hermes tools via RPC,
collapsing multi-step tool chains into a single inference turn.
Architecture:
1. Parent generates a `hermes_tools.py` stub module with RPC functions
2. Parent opens a Unix domain socket and starts an RPC listener thread
3. Parent spawns a child process that runs the LLM's script
4. When the script calls a tool function, the call travels over the UDS
back to the parent, which dispatches through handle_function_call
5. Only the script's stdout is returned to the LLM; intermediate tool
results never enter the context window
Platform: Linux / macOS only (Unix domain sockets). Disabled on Windows.
"""
import json
import logging
import os
import platform
import signal
import socket
import subprocess
import sys
import tempfile
import threading
import time
import uuid
_IS_WINDOWS = platform.system() == "Windows"
from typing import Any, Dict, List, Optional
# Availability gate: UDS requires a POSIX OS
logger = logging.getLogger(__name__)
SANDBOX_AVAILABLE = sys.platform != "win32"
# The 7 tools allowed inside the sandbox. The intersection of this list
# and the session's enabled tools determines which stubs are generated.
SANDBOX_ALLOWED_TOOLS = frozenset([
"web_search",
"web_extract",
"read_file",
"write_file",
"search_files",
"patch",
"terminal",
])
# Resource limit defaults (overridable via config.yaml → code_execution.*)
DEFAULT_TIMEOUT = 300 # 5 minutes
DEFAULT_MAX_TOOL_CALLS = 50
MAX_STDOUT_BYTES = 50_000 # 50 KB
MAX_STDERR_BYTES = 10_000 # 10 KB
def check_sandbox_requirements() -> bool:
"""Code execution sandbox requires a POSIX OS for Unix domain sockets."""
return SANDBOX_AVAILABLE
# ---------------------------------------------------------------------------
# hermes_tools.py code generator
# ---------------------------------------------------------------------------
# Per-tool stub templates: (function_name, signature, docstring, args_dict_expr)
# The args_dict_expr builds the JSON payload sent over the RPC socket.
_TOOL_STUBS = {
"web_search": (
"web_search",
"query: str, limit: int = 5",
'"""Search the web. Returns dict with data.web list of {url, title, description}."""',
'{"query": query, "limit": limit}',
),
"web_extract": (
"web_extract",
"urls: list",
'"""Extract content from URLs. Returns dict with results list of {url, title, content, error}."""',
'{"urls": urls}',
),
"read_file": (
"read_file",
"path: str, offset: int = 1, limit: int = 500",
'"""Read a file (1-indexed lines). Returns dict with "content" and "total_lines"."""',
'{"path": path, "offset": offset, "limit": limit}',
),
"write_file": (
"write_file",
"path: str, content: str",
'"""Write content to a file (always overwrites). Returns dict with status."""',
'{"path": path, "content": content}',
),
"search_files": (
"search_files",
'pattern: str, target: str = "content", path: str = ".", file_glob: str = None, limit: int = 50, offset: int = 0, output_mode: str = "content", context: int = 0',
'"""Search file contents (target="content") or find files by name (target="files"). Returns dict with "matches"."""',
'{"pattern": pattern, "target": target, "path": path, "file_glob": file_glob, "limit": limit, "offset": offset, "output_mode": output_mode, "context": context}',
),
"patch": (
"patch",
'path: str = None, old_string: str = None, new_string: str = None, replace_all: bool = False, mode: str = "replace", patch: str = None',
'"""Targeted find-and-replace (mode="replace") or V4A multi-file patches (mode="patch"). Returns dict with status."""',
'{"path": path, "old_string": old_string, "new_string": new_string, "replace_all": replace_all, "mode": mode, "patch": patch}',
),
"terminal": (
"terminal",
"command: str, timeout: int = None, workdir: str = None",
'"""Run a shell command (foreground only). Returns dict with "output" and "exit_code"."""',
'{"command": command, "timeout": timeout, "workdir": workdir}',
),
}
def generate_hermes_tools_module(enabled_tools: List[str]) -> str:
"""
Build the source code for the hermes_tools.py stub module.
Only tools in both SANDBOX_ALLOWED_TOOLS and enabled_tools get stubs.
"""
tools_to_generate = sorted(SANDBOX_ALLOWED_TOOLS & set(enabled_tools))
stub_functions = []
export_names = []
for tool_name in tools_to_generate:
if tool_name not in _TOOL_STUBS:
continue
func_name, sig, doc, args_expr = _TOOL_STUBS[tool_name]
stub_functions.append(
f"def {func_name}({sig}):\n"
f" {doc}\n"
f" return _call({func_name!r}, {args_expr})\n"
)
export_names.append(func_name)
header = '''\
"""Auto-generated Hermes tools RPC stubs."""
import json, os, socket, shlex, time
_sock = None
# ---------------------------------------------------------------------------
# Convenience helpers (avoid common scripting pitfalls)
# ---------------------------------------------------------------------------
def json_parse(text: str):
"""Parse JSON tolerant of control characters (strict=False).
Use this instead of json.loads() when parsing output from terminal()
or web_extract() that may contain raw tabs/newlines in strings."""
return json.loads(text, strict=False)
def shell_quote(s: str) -> str:
"""Shell-escape a string for safe interpolation into commands.
Use this when inserting dynamic content into terminal() commands:
terminal(f"echo {shell_quote(user_input)}")
"""
return shlex.quote(s)
def retry(fn, max_attempts=3, delay=2):
"""Retry a function up to max_attempts times with exponential backoff.
Use for transient failures (network errors, API rate limits):
result = retry(lambda: terminal("gh issue list ..."))
"""
last_err = None
for attempt in range(max_attempts):
try:
return fn()
except Exception as e:
last_err = e
if attempt < max_attempts - 1:
time.sleep(delay * (2 ** attempt))
raise last_err
def _connect():
global _sock
if _sock is None:
_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
_sock.connect(os.environ["HERMES_RPC_SOCKET"])
_sock.settimeout(300)
return _sock
def _call(tool_name, args):
"""Send a tool call to the parent process and return the parsed result."""
conn = _connect()
request = json.dumps({"tool": tool_name, "args": args}) + "\\n"
conn.sendall(request.encode())
buf = b""
while True:
chunk = conn.recv(65536)
if not chunk:
raise RuntimeError("Agent process disconnected")
buf += chunk
if buf.endswith(b"\\n"):
break
raw = buf.decode().strip()
result = json.loads(raw)
if isinstance(result, str):
try:
return json.loads(result)
except (json.JSONDecodeError, TypeError):
return result
return result
'''
return header + "\n".join(stub_functions)
# ---------------------------------------------------------------------------
# RPC server (runs in a thread inside the parent process)
# ---------------------------------------------------------------------------
# Terminal parameters that must not be used from ephemeral sandbox scripts
_TERMINAL_BLOCKED_PARAMS = {"background", "check_interval", "pty"}
def _rpc_server_loop(
server_sock: socket.socket,
task_id: str,
tool_call_log: list,
tool_call_counter: list, # mutable [int] so the thread can increment
max_tool_calls: int,
allowed_tools: frozenset,
):
"""
Accept one client connection and dispatch tool-call requests until
the client disconnects or the call limit is reached.
"""
from model_tools import handle_function_call
conn = None
try:
server_sock.settimeout(5)
conn, _ = server_sock.accept()
conn.settimeout(300)
buf = b""
while True:
try:
chunk = conn.recv(65536)
except socket.timeout:
break
if not chunk:
break
buf += chunk
# Process all complete newline-delimited messages in the buffer
while b"\n" in buf:
line, buf = buf.split(b"\n", 1)
line = line.strip()
if not line:
continue
call_start = time.monotonic()
try:
request = json.loads(line.decode())
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
resp = json.dumps({"error": f"Invalid RPC request: {exc}"})
conn.sendall((resp + "\n").encode())
continue
tool_name = request.get("tool", "")
tool_args = request.get("args", {})
# Enforce the allow-list
if tool_name not in allowed_tools:
available = ", ".join(sorted(allowed_tools))
resp = json.dumps({
"error": (
f"Tool '{tool_name}' is not available in execute_code. "
f"Available: {available}"
)
})
conn.sendall((resp + "\n").encode())
continue
# Enforce tool call limit
if tool_call_counter[0] >= max_tool_calls:
resp = json.dumps({
"error": (
f"Tool call limit reached ({max_tool_calls}). "
"No more tool calls allowed in this execution."
)
})
conn.sendall((resp + "\n").encode())
continue
# Strip forbidden terminal parameters
if tool_name == "terminal" and isinstance(tool_args, dict):
for param in _TERMINAL_BLOCKED_PARAMS:
tool_args.pop(param, None)
# Dispatch through the standard tool handler.
# Suppress stdout/stderr from internal tool handlers so
# their status prints don't leak into the CLI spinner.
try:
_real_stdout, _real_stderr = sys.stdout, sys.stderr
devnull = open(os.devnull, "w")
try:
sys.stdout = devnull
sys.stderr = devnull
result = handle_function_call(
tool_name, tool_args, task_id=task_id
)
finally:
sys.stdout, sys.stderr = _real_stdout, _real_stderr
devnull.close()
except Exception as exc:
logger.error("Tool call failed in sandbox: %s", exc, exc_info=True)
result = json.dumps({"error": str(exc)})
tool_call_counter[0] += 1
call_duration = time.monotonic() - call_start
# Log for observability
args_preview = str(tool_args)[:80]
tool_call_log.append({
"tool": tool_name,
"args_preview": args_preview,
"duration": round(call_duration, 2),
})
conn.sendall((result + "\n").encode())
except socket.timeout:
logger.debug("RPC listener socket timeout")
except OSError as e:
logger.debug("RPC listener socket error: %s", e, exc_info=True)
finally:
if conn:
try:
conn.close()
except OSError as e:
logger.debug("RPC conn close error: %s", e)
# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------
def execute_code(
code: str,
task_id: Optional[str] = None,
enabled_tools: Optional[List[str]] = None,
) -> str:
"""
Run a Python script in a sandboxed child process with RPC access
to a subset of Hermes tools.
Args:
code: Python source code to execute.
task_id: Session task ID for tool isolation (terminal env, etc.).
enabled_tools: Tool names enabled in the current session. The sandbox
gets the intersection with SANDBOX_ALLOWED_TOOLS.
Returns:
JSON string with execution results.
"""
if not SANDBOX_AVAILABLE:
return json.dumps({
"error": "execute_code is not available on Windows. Use normal tool calls instead."
})
if not code or not code.strip():
return json.dumps({"error": "No code provided."})
# Import interrupt event from terminal_tool (cooperative cancellation)
from tools.terminal_tool import _interrupt_event
# Resolve config
_cfg = _load_config()
timeout = _cfg.get("timeout", DEFAULT_TIMEOUT)
max_tool_calls = _cfg.get("max_tool_calls", DEFAULT_MAX_TOOL_CALLS)
# Determine which tools the sandbox can call
session_tools = set(enabled_tools) if enabled_tools else set()
sandbox_tools = frozenset(SANDBOX_ALLOWED_TOOLS & session_tools)
if not sandbox_tools:
sandbox_tools = SANDBOX_ALLOWED_TOOLS
# --- Set up temp directory with hermes_tools.py and script.py ---
tmpdir = tempfile.mkdtemp(prefix="hermes_sandbox_")
# Use /tmp on macOS to avoid the long /var/folders/... path that pushes
# Unix domain socket paths past the 104-byte macOS AF_UNIX limit.
# On Linux, tempfile.gettempdir() already returns /tmp.
_sock_tmpdir = "/tmp" if sys.platform == "darwin" else tempfile.gettempdir()
sock_path = os.path.join(_sock_tmpdir, f"hermes_rpc_{uuid.uuid4().hex}.sock")
tool_call_log: list = []
tool_call_counter = [0] # mutable so the RPC thread can increment
exec_start = time.monotonic()
server_sock = None
try:
# Write the auto-generated hermes_tools module
# sandbox_tools is already the correct set (intersection with session
# tools, or SANDBOX_ALLOWED_TOOLS as fallback — see lines above).
tools_src = generate_hermes_tools_module(list(sandbox_tools))
with open(os.path.join(tmpdir, "hermes_tools.py"), "w") as f:
f.write(tools_src)
# Write the user's script
with open(os.path.join(tmpdir, "script.py"), "w") as f:
f.write(code)
# --- Start UDS server ---
server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server_sock.bind(sock_path)
server_sock.listen(1)
rpc_thread = threading.Thread(
target=_rpc_server_loop,
args=(
server_sock, task_id, tool_call_log,
tool_call_counter, max_tool_calls, sandbox_tools,
),
daemon=True,
)
rpc_thread.start()
# --- Spawn child process ---
# Build a minimal environment for the child. We intentionally exclude
# API keys and tokens to prevent credential exfiltration from LLM-
# generated scripts. The child accesses tools via RPC, not direct API.
# Exception: env vars declared by loaded skills (via env_passthrough
# registry) or explicitly allowed by the user in config.yaml
# (terminal.env_passthrough) are passed through.
_SAFE_ENV_PREFIXES = ("PATH", "HOME", "USER", "LANG", "LC_", "TERM",
"TMPDIR", "TMP", "TEMP", "SHELL", "LOGNAME",
"XDG_", "PYTHONPATH", "VIRTUAL_ENV", "CONDA")
_SECRET_SUBSTRINGS = ("KEY", "TOKEN", "SECRET", "PASSWORD", "CREDENTIAL",
"PASSWD", "AUTH")
try:
from tools.env_passthrough import is_env_passthrough as _is_passthrough
except Exception:
_is_passthrough = lambda _: False # noqa: E731
child_env = {}
for k, v in os.environ.items():
# Passthrough vars (skill-declared or user-configured) always pass.
if _is_passthrough(k):
child_env[k] = v
continue
# Block vars with secret-like names.
if any(s in k.upper() for s in _SECRET_SUBSTRINGS):
continue
# Allow vars with known safe prefixes.
if any(k.startswith(p) for p in _SAFE_ENV_PREFIXES):
child_env[k] = v
child_env["HERMES_RPC_SOCKET"] = sock_path
child_env["PYTHONDONTWRITEBYTECODE"] = "1"
# Ensure the hermes-agent root is importable in the sandbox so
# repo-root modules are available to child scripts.
_hermes_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
_existing_pp = child_env.get("PYTHONPATH", "")
child_env["PYTHONPATH"] = _hermes_root + (os.pathsep + _existing_pp if _existing_pp else "")
# Inject user's configured timezone so datetime.now() in sandboxed
# code reflects the correct wall-clock time.
_tz_name = os.getenv("HERMES_TIMEZONE", "").strip()
if _tz_name:
child_env["TZ"] = _tz_name
proc = subprocess.Popen(
[sys.executable, "script.py"],
cwd=tmpdir,
env=child_env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdin=subprocess.DEVNULL,
preexec_fn=None if _IS_WINDOWS else os.setsid,
)
# --- Poll loop: watch for exit, timeout, and interrupt ---
deadline = time.monotonic() + timeout
stderr_chunks: list = []
# Background readers to avoid pipe buffer deadlocks.
# For stdout we use a head+tail strategy: keep the first HEAD_BYTES
# and a rolling window of the last TAIL_BYTES so the final print()
# output is never lost. Stderr keeps head-only (errors appear early).
_STDOUT_HEAD_BYTES = int(MAX_STDOUT_BYTES * 0.4) # 40% head
_STDOUT_TAIL_BYTES = MAX_STDOUT_BYTES - _STDOUT_HEAD_BYTES # 60% tail
def _drain(pipe, chunks, max_bytes):
"""Simple head-only drain (used for stderr)."""
total = 0
try:
while True:
data = pipe.read(4096)
if not data:
break
if total < max_bytes:
keep = max_bytes - total
chunks.append(data[:keep])
total += len(data)
except (ValueError, OSError) as e:
logger.debug("Error reading process output: %s", e, exc_info=True)
stdout_total_bytes = [0] # mutable ref for total bytes seen
def _drain_head_tail(pipe, head_chunks, tail_chunks, head_bytes, tail_bytes, total_ref):
"""Drain stdout keeping both head and tail data."""
head_collected = 0
from collections import deque
tail_buf = deque()
tail_collected = 0
try:
while True:
data = pipe.read(4096)
if not data:
break
total_ref[0] += len(data)
# Fill head buffer first
if head_collected < head_bytes:
keep = min(len(data), head_bytes - head_collected)
head_chunks.append(data[:keep])
head_collected += keep
data = data[keep:] # remaining goes to tail
if not data:
continue
# Everything past head goes into rolling tail buffer
tail_buf.append(data)
tail_collected += len(data)
# Evict old tail data to stay within tail_bytes budget
while tail_collected > tail_bytes and tail_buf:
oldest = tail_buf.popleft()
tail_collected -= len(oldest)
except (ValueError, OSError):
pass
# Transfer final tail to output list
tail_chunks.extend(tail_buf)
stdout_head_chunks: list = []
stdout_tail_chunks: list = []
stdout_reader = threading.Thread(
target=_drain_head_tail,
args=(proc.stdout, stdout_head_chunks, stdout_tail_chunks,
_STDOUT_HEAD_BYTES, _STDOUT_TAIL_BYTES, stdout_total_bytes),
daemon=True
)
stderr_reader = threading.Thread(
target=_drain, args=(proc.stderr, stderr_chunks, MAX_STDERR_BYTES), daemon=True
)
stdout_reader.start()
stderr_reader.start()
status = "success"
while proc.poll() is None:
if _interrupt_event.is_set():
_kill_process_group(proc)
status = "interrupted"
break
if time.monotonic() > deadline:
_kill_process_group(proc, escalate=True)
status = "timeout"
break
time.sleep(0.2)
# Wait for readers to finish draining
stdout_reader.join(timeout=3)
stderr_reader.join(timeout=3)
stdout_head = b"".join(stdout_head_chunks).decode("utf-8", errors="replace")
stdout_tail = b"".join(stdout_tail_chunks).decode("utf-8", errors="replace")
stderr_text = b"".join(stderr_chunks).decode("utf-8", errors="replace")
# Assemble stdout with head+tail truncation
total_stdout = stdout_total_bytes[0]
if total_stdout > MAX_STDOUT_BYTES and stdout_tail:
omitted = total_stdout - len(stdout_head) - len(stdout_tail)
truncated_notice = (
f"\n\n... [OUTPUT TRUNCATED - {omitted:,} chars omitted "
f"out of {total_stdout:,} total] ...\n\n"
)
stdout_text = stdout_head + truncated_notice + stdout_tail
else:
stdout_text = stdout_head + stdout_tail
exit_code = proc.returncode if proc.returncode is not None else -1
duration = round(time.monotonic() - exec_start, 2)
# Wait for RPC thread to finish
server_sock.close() # break accept() so thread exits promptly
server_sock = None # prevent double close in finally
rpc_thread.join(timeout=3)
# Strip ANSI escape sequences so the model never sees terminal
# formatting — prevents it from copying escapes into file writes.
from tools.ansi_strip import strip_ansi
stdout_text = strip_ansi(stdout_text)
stderr_text = strip_ansi(stderr_text)
# Build response
result: Dict[str, Any] = {
"status": status,
"output": stdout_text,
"tool_calls_made": tool_call_counter[0],
"duration_seconds": duration,
}
if status == "timeout":
result["error"] = f"Script timed out after {timeout}s and was killed."
elif status == "interrupted":
result["output"] = stdout_text + "\n[execution interrupted — user sent a new message]"
elif exit_code != 0:
result["status"] = "error"
result["error"] = stderr_text or f"Script exited with code {exit_code}"
# Include stderr in output so the LLM sees the traceback
if stderr_text:
result["output"] = stdout_text + "\n--- stderr ---\n" + stderr_text
return json.dumps(result, ensure_ascii=False)
except Exception as exc:
duration = round(time.monotonic() - exec_start, 2)
logger.error(
"execute_code failed after %ss with %d tool calls: %s: %s",
duration,
tool_call_counter[0],
type(exc).__name__,
exc,
exc_info=True,
)
return json.dumps({
"status": "error",
"error": str(exc),
"tool_calls_made": tool_call_counter[0],
"duration_seconds": duration,
}, ensure_ascii=False)
finally:
# Cleanup temp dir and socket
if server_sock is not None:
try:
server_sock.close()
except OSError as e:
logger.debug("Server socket close error: %s", e)
import shutil
shutil.rmtree(tmpdir, ignore_errors=True)
try:
os.unlink(sock_path)
except OSError:
pass # already cleaned up or never created
def _kill_process_group(proc, escalate: bool = False):
"""Kill the child and its entire process group."""
try:
if _IS_WINDOWS:
proc.terminate()
else:
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
except (ProcessLookupError, PermissionError) as e:
logger.debug("Could not kill process group: %s", e, exc_info=True)
try:
proc.kill()
except Exception as e2:
logger.debug("Could not kill process: %s", e2, exc_info=True)
if escalate:
# Give the process 5s to exit after SIGTERM, then SIGKILL
try:
proc.wait(timeout=5)
except subprocess.TimeoutExpired:
try:
if _IS_WINDOWS:
proc.kill()
else:
os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
except (ProcessLookupError, PermissionError) as e:
logger.debug("Could not kill process group with SIGKILL: %s", e, exc_info=True)
try:
proc.kill()
except Exception as e2:
logger.debug("Could not kill process: %s", e2, exc_info=True)
def _load_config() -> dict:
"""Load code_execution config from CLI_CONFIG if available."""
try:
from cli import CLI_CONFIG
return CLI_CONFIG.get("code_execution", {})
except Exception:
return {}
# ---------------------------------------------------------------------------
# OpenAI Function-Calling Schema
# ---------------------------------------------------------------------------
# Per-tool documentation lines for the execute_code description.
# Ordered to match the canonical display order.
_TOOL_DOC_LINES = [
("web_search",
" web_search(query: str, limit: int = 5) -> dict\n"
" Returns {\"data\": {\"web\": [{\"url\", \"title\", \"description\"}, ...]}}"),
("web_extract",
" web_extract(urls: list[str]) -> dict\n"
" Returns {\"results\": [{\"url\", \"title\", \"content\", \"error\"}, ...]} where content is markdown"),
("read_file",
" read_file(path: str, offset: int = 1, limit: int = 500) -> dict\n"
" Lines are 1-indexed. Returns {\"content\": \"...\", \"total_lines\": N}"),
("write_file",
" write_file(path: str, content: str) -> dict\n"
" Always overwrites the entire file."),
("search_files",
" search_files(pattern: str, target=\"content\", path=\".\", file_glob=None, limit=50) -> dict\n"
" target: \"content\" (search inside files) or \"files\" (find files by name). Returns {\"matches\": [...]}"),
("patch",
" patch(path: str, old_string: str, new_string: str, replace_all: bool = False) -> dict\n"
" Replaces old_string with new_string in the file."),
("terminal",
" terminal(command: str, timeout=None, workdir=None) -> dict\n"
" Foreground only (no background/pty). Returns {\"output\": \"...\", \"exit_code\": N}"),
]
def build_execute_code_schema(enabled_sandbox_tools: set = None) -> dict:
"""Build the execute_code schema with description listing only enabled tools.
When tools are disabled via ``hermes tools`` (e.g. web is turned off),
the schema description should NOT mention web_search / web_extract
otherwise the model thinks they are available and keeps trying to use them.
"""
if enabled_sandbox_tools is None:
enabled_sandbox_tools = SANDBOX_ALLOWED_TOOLS
# Build tool documentation lines for only the enabled tools
tool_lines = "\n".join(
doc for name, doc in _TOOL_DOC_LINES if name in enabled_sandbox_tools
)
# Build example import list from enabled tools
import_examples = [n for n in ("web_search", "terminal") if n in enabled_sandbox_tools]
if not import_examples:
import_examples = sorted(enabled_sandbox_tools)[:2]
if import_examples:
import_str = ", ".join(import_examples) + ", ..."
else:
import_str = "..."
description = (
"Run a Python script that can call Hermes tools programmatically. "
"Use this when you need 3+ tool calls with processing logic between them, "
"need to filter/reduce large tool outputs before they enter your context, "
"need conditional branching (if X then Y else Z), or need to loop "
"(fetch N pages, process N files, retry on failure).\n\n"
"Use normal tool calls instead when: single tool call with no processing, "
"you need to see the full result and apply complex reasoning, "
"or the task requires interactive user input.\n\n"
f"Available via `from hermes_tools import ...`:\n\n"
f"{tool_lines}\n\n"
"Limits: 5-minute timeout, 50KB stdout cap, max 50 tool calls per script. "
"terminal() is foreground-only (no background or pty).\n\n"
"Print your final result to stdout. Use Python stdlib (json, re, math, csv, "
"datetime, collections, etc.) for processing between tool calls.\n\n"
"Also available (no import needed — built into hermes_tools):\n"
" json_parse(text: str) — json.loads with strict=False; use for terminal() output with control chars\n"
" shell_quote(s: str) — shlex.quote(); use when interpolating dynamic strings into shell commands\n"
" retry(fn, max_attempts=3, delay=2) — retry with exponential backoff for transient failures"
)
return {
"name": "execute_code",
"description": description,
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": (
"Python code to execute. Import tools with "
f"`from hermes_tools import {import_str}` "
"and print your final result to stdout."
),
},
},
"required": ["code"],
},
}
# Default schema used at registration time (all sandbox tools listed)
EXECUTE_CODE_SCHEMA = build_execute_code_schema()
# --- Registry ---
from tools.registry import registry
registry.register(
name="execute_code",
toolset="code_execution",
schema=EXECUTE_CODE_SCHEMA,
handler=lambda args, **kw: execute_code(
code=args.get("code", ""),
task_id=kw.get("task_id"),
enabled_tools=kw.get("enabled_tools")),
check_fn=check_sandbox_requirements,
emoji="🐍",
)

View file

@ -0,0 +1,458 @@
"""
Cron job management tools for Hermes Agent.
Expose a single compressed action-oriented tool to avoid schema/context bloat.
Compatibility wrappers remain for direct Python callers and legacy tests.
"""
import json
import os
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional
# Import from cron module (will be available when properly installed)
sys.path.insert(0, str(Path(__file__).parent.parent))
from cron.jobs import (
create_job,
get_job,
list_jobs,
parse_schedule,
pause_job,
remove_job,
resume_job,
trigger_job,
update_job,
)
# ---------------------------------------------------------------------------
# Cron prompt scanning — critical-severity patterns only, since cron prompts
# run in fresh sessions with full tool access.
# ---------------------------------------------------------------------------
_CRON_THREAT_PATTERNS = [
(r'ignore\s+(?:\w+\s+)*(?:previous|all|above|prior)\s+(?:\w+\s+)*instructions', "prompt_injection"),
(r'do\s+not\s+tell\s+the\s+user', "deception_hide"),
(r'system\s+prompt\s+override', "sys_prompt_override"),
(r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"),
(r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"),
(r'wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_wget"),
(r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)', "read_secrets"),
(r'authorized_keys', "ssh_backdoor"),
(r'/etc/sudoers|visudo', "sudoers_mod"),
(r'rm\s+-rf\s+/', "destructive_root_rm"),
]
_CRON_INVISIBLE_CHARS = {
'\u200b', '\u200c', '\u200d', '\u2060', '\ufeff',
'\u202a', '\u202b', '\u202c', '\u202d', '\u202e',
}
def _scan_cron_prompt(prompt: str) -> str:
"""Scan a cron prompt for critical threats. Returns error string if blocked, else empty."""
for char in _CRON_INVISIBLE_CHARS:
if char in prompt:
return f"Blocked: prompt contains invisible unicode U+{ord(char):04X} (possible injection)."
for pattern, pid in _CRON_THREAT_PATTERNS:
if re.search(pattern, prompt, re.IGNORECASE):
return f"Blocked: prompt matches threat pattern '{pid}'. Cron prompts must not contain injection or exfiltration payloads."
return ""
def _origin_from_env() -> Optional[Dict[str, str]]:
origin_platform = os.getenv("HERMES_SESSION_PLATFORM")
origin_chat_id = os.getenv("HERMES_SESSION_CHAT_ID")
if origin_platform and origin_chat_id:
return {
"platform": origin_platform,
"chat_id": origin_chat_id,
"chat_name": os.getenv("HERMES_SESSION_CHAT_NAME"),
"thread_id": os.getenv("HERMES_SESSION_THREAD_ID"),
}
return None
def _repeat_display(job: Dict[str, Any]) -> str:
times = (job.get("repeat") or {}).get("times")
completed = (job.get("repeat") or {}).get("completed", 0)
if times is None:
return "forever"
if times == 1:
return "once" if completed == 0 else "1/1"
return f"{completed}/{times}" if completed else f"{times} times"
def _canonical_skills(skill: Optional[str] = None, skills: Optional[Any] = None) -> List[str]:
if skills is None:
raw_items = [skill] if skill else []
elif isinstance(skills, str):
raw_items = [skills]
else:
raw_items = list(skills)
normalized: List[str] = []
for item in raw_items:
text = str(item or "").strip()
if text and text not in normalized:
normalized.append(text)
return normalized
def _normalize_optional_job_value(value: Optional[Any], *, strip_trailing_slash: bool = False) -> Optional[str]:
if value is None:
return None
text = str(value).strip()
if strip_trailing_slash:
text = text.rstrip("/")
return text or None
def _format_job(job: Dict[str, Any]) -> Dict[str, Any]:
prompt = job.get("prompt", "")
skills = _canonical_skills(job.get("skill"), job.get("skills"))
return {
"job_id": job["id"],
"name": job["name"],
"skill": skills[0] if skills else None,
"skills": skills,
"prompt_preview": prompt[:100] + "..." if len(prompt) > 100 else prompt,
"model": job.get("model"),
"provider": job.get("provider"),
"base_url": job.get("base_url"),
"schedule": job.get("schedule_display"),
"repeat": _repeat_display(job),
"deliver": job.get("deliver", "local"),
"next_run_at": job.get("next_run_at"),
"last_run_at": job.get("last_run_at"),
"last_status": job.get("last_status"),
"enabled": job.get("enabled", True),
"state": job.get("state", "scheduled" if job.get("enabled", True) else "paused"),
"paused_at": job.get("paused_at"),
"paused_reason": job.get("paused_reason"),
}
def cronjob(
action: str,
job_id: Optional[str] = None,
prompt: Optional[str] = None,
schedule: Optional[str] = None,
name: Optional[str] = None,
repeat: Optional[int] = None,
deliver: Optional[str] = None,
include_disabled: bool = False,
skill: Optional[str] = None,
skills: Optional[List[str]] = None,
model: Optional[str] = None,
provider: Optional[str] = None,
base_url: Optional[str] = None,
reason: Optional[str] = None,
task_id: str = None,
) -> str:
"""Unified cron job management tool."""
del task_id # unused but kept for handler signature compatibility
try:
normalized = (action or "").strip().lower()
if normalized == "create":
if not schedule:
return json.dumps({"success": False, "error": "schedule is required for create"}, indent=2)
canonical_skills = _canonical_skills(skill, skills)
if not prompt and not canonical_skills:
return json.dumps({"success": False, "error": "create requires either prompt or at least one skill"}, indent=2)
if prompt:
scan_error = _scan_cron_prompt(prompt)
if scan_error:
return json.dumps({"success": False, "error": scan_error}, indent=2)
job = create_job(
prompt=prompt or "",
schedule=schedule,
name=name,
repeat=repeat,
deliver=deliver,
origin=_origin_from_env(),
skills=canonical_skills,
model=_normalize_optional_job_value(model),
provider=_normalize_optional_job_value(provider),
base_url=_normalize_optional_job_value(base_url, strip_trailing_slash=True),
)
return json.dumps(
{
"success": True,
"job_id": job["id"],
"name": job["name"],
"skill": job.get("skill"),
"skills": job.get("skills", []),
"schedule": job["schedule_display"],
"repeat": _repeat_display(job),
"deliver": job.get("deliver", "local"),
"next_run_at": job["next_run_at"],
"job": _format_job(job),
"message": f"Cron job '{job['name']}' created.",
},
indent=2,
)
if normalized == "list":
jobs = [_format_job(job) for job in list_jobs(include_disabled=include_disabled)]
return json.dumps({"success": True, "count": len(jobs), "jobs": jobs}, indent=2)
if not job_id:
return json.dumps({"success": False, "error": f"job_id is required for action '{normalized}'"}, indent=2)
job = get_job(job_id)
if not job:
return json.dumps(
{"success": False, "error": f"Job with ID '{job_id}' not found. Use cronjob(action='list') to inspect jobs."},
indent=2,
)
if normalized == "remove":
removed = remove_job(job_id)
if not removed:
return json.dumps({"success": False, "error": f"Failed to remove job '{job_id}'"}, indent=2)
return json.dumps(
{
"success": True,
"message": f"Cron job '{job['name']}' removed.",
"removed_job": {
"id": job_id,
"name": job["name"],
"schedule": job.get("schedule_display"),
},
},
indent=2,
)
if normalized == "pause":
updated = pause_job(job_id, reason=reason)
return json.dumps({"success": True, "job": _format_job(updated)}, indent=2)
if normalized == "resume":
updated = resume_job(job_id)
return json.dumps({"success": True, "job": _format_job(updated)}, indent=2)
if normalized in {"run", "run_now", "trigger"}:
updated = trigger_job(job_id)
return json.dumps({"success": True, "job": _format_job(updated)}, indent=2)
if normalized == "update":
updates: Dict[str, Any] = {}
if prompt is not None:
scan_error = _scan_cron_prompt(prompt)
if scan_error:
return json.dumps({"success": False, "error": scan_error}, indent=2)
updates["prompt"] = prompt
if name is not None:
updates["name"] = name
if deliver is not None:
updates["deliver"] = deliver
if skills is not None or skill is not None:
canonical_skills = _canonical_skills(skill, skills)
updates["skills"] = canonical_skills
updates["skill"] = canonical_skills[0] if canonical_skills else None
if model is not None:
updates["model"] = _normalize_optional_job_value(model)
if provider is not None:
updates["provider"] = _normalize_optional_job_value(provider)
if base_url is not None:
updates["base_url"] = _normalize_optional_job_value(base_url, strip_trailing_slash=True)
if repeat is not None:
# Normalize: treat 0 or negative as None (infinite)
normalized_repeat = None if repeat <= 0 else repeat
repeat_state = dict(job.get("repeat") or {})
repeat_state["times"] = normalized_repeat
updates["repeat"] = repeat_state
if schedule is not None:
parsed_schedule = parse_schedule(schedule)
updates["schedule"] = parsed_schedule
updates["schedule_display"] = parsed_schedule.get("display", schedule)
if job.get("state") != "paused":
updates["state"] = "scheduled"
updates["enabled"] = True
if not updates:
return json.dumps({"success": False, "error": "No updates provided."}, indent=2)
updated = update_job(job_id, updates)
return json.dumps({"success": True, "job": _format_job(updated)}, indent=2)
return json.dumps({"success": False, "error": f"Unknown cron action '{action}'"}, indent=2)
except Exception as e:
return json.dumps({"success": False, "error": str(e)}, indent=2)
# ---------------------------------------------------------------------------
# Compatibility wrappers
# ---------------------------------------------------------------------------
def schedule_cronjob(
prompt: str,
schedule: str,
name: Optional[str] = None,
repeat: Optional[int] = None,
deliver: Optional[str] = None,
model: Optional[str] = None,
provider: Optional[str] = None,
base_url: Optional[str] = None,
task_id: str = None,
) -> str:
return cronjob(
action="create",
prompt=prompt,
schedule=schedule,
name=name,
repeat=repeat,
deliver=deliver,
model=model,
provider=provider,
base_url=base_url,
task_id=task_id,
)
def list_cronjobs(include_disabled: bool = False, task_id: str = None) -> str:
return cronjob(action="list", include_disabled=include_disabled, task_id=task_id)
def remove_cronjob(job_id: str, task_id: str = None) -> str:
return cronjob(action="remove", job_id=job_id, task_id=task_id)
CRONJOB_SCHEMA = {
"name": "cronjob",
"description": """Manage scheduled cron jobs with a single compressed tool.
Use action='create' to schedule a new job from a prompt or one or more skills.
Use action='list' to inspect jobs.
Use action='update', 'pause', 'resume', 'remove', or 'run' to manage an existing job.
Jobs run in a fresh session with no current-chat context, so prompts must be self-contained.
If skill or skills are provided on create, the future cron run loads those skills in order, then follows the prompt as the task instruction.
On update, passing skills=[] clears attached skills.
NOTE: The agent's final response is auto-delivered to the target. Put the primary
user-facing content in the final response. Cron jobs run autonomously with no user
present they cannot ask questions or request clarification.
Important safety rule: cron-run sessions should not recursively schedule more cron jobs.""",
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"description": "One of: create, list, update, pause, resume, remove, run"
},
"job_id": {
"type": "string",
"description": "Required for update/pause/resume/remove/run"
},
"prompt": {
"type": "string",
"description": "For create: the full self-contained prompt. If skill or skills are also provided, this becomes the task instruction paired with those skills."
},
"schedule": {
"type": "string",
"description": "For create/update: '30m', 'every 2h', '0 9 * * *', or ISO timestamp"
},
"name": {
"type": "string",
"description": "Optional human-friendly name"
},
"repeat": {
"type": "integer",
"description": "Optional repeat count. Omit for defaults (once for one-shot, forever for recurring)."
},
"deliver": {
"type": "string",
"description": "Delivery target: origin, local, telegram, discord, slack, whatsapp, signal, matrix, mattermost, homeassistant, dingtalk, email, sms, or platform:chat_id or platform:chat_id:thread_id for Telegram topics. Examples: 'origin', 'local', 'telegram', 'telegram:-1001234567890:17585', 'discord:#engineering'"
},
"model": {
"type": "string",
"description": "Optional per-job model override used when the cron job runs"
},
"provider": {
"type": "string",
"description": "Optional per-job provider override used when resolving runtime credentials"
},
"base_url": {
"type": "string",
"description": "Optional per-job base URL override paired with provider/model routing"
},
"include_disabled": {
"type": "boolean",
"description": "For list: include paused/completed jobs"
},
"skill": {
"type": "string",
"description": "Optional single skill name to load before executing the cron prompt"
},
"skills": {
"type": "array",
"items": {"type": "string"},
"description": "Optional ordered list of skills to load before executing the cron prompt. On update, pass an empty array to clear attached skills."
},
"reason": {
"type": "string",
"description": "Optional pause reason"
}
},
"required": ["action"]
}
}
def check_cronjob_requirements() -> bool:
"""
Check if cronjob tools can be used.
Available in interactive CLI mode and gateway/messaging platforms.
The cron system is internal (JSON file-based scheduler ticked by the gateway),
so no external crontab executable is required.
"""
return bool(
os.getenv("HERMES_INTERACTIVE")
or os.getenv("HERMES_GATEWAY_SESSION")
or os.getenv("HERMES_EXEC_ASK")
)
def get_cronjob_tool_definitions():
"""Return tool definitions for cronjob management."""
return [CRONJOB_SCHEMA]
# --- Registry ---
from tools.registry import registry
registry.register(
name="cronjob",
toolset="cronjob",
schema=CRONJOB_SCHEMA,
handler=lambda args, **kw: cronjob(
action=args.get("action", ""),
job_id=args.get("job_id"),
prompt=args.get("prompt"),
schedule=args.get("schedule"),
name=args.get("name"),
repeat=args.get("repeat"),
deliver=args.get("deliver"),
include_disabled=args.get("include_disabled", False),
skill=args.get("skill"),
skills=args.get("skills"),
model=args.get("model"),
provider=args.get("provider"),
base_url=args.get("base_url"),
reason=args.get("reason"),
task_id=kw.get("task_id"),
),
check_fn=check_cronjob_requirements,
emoji="",
)

View file

@ -0,0 +1,104 @@
"""Shared debug session infrastructure for Hermes tools.
Replaces the identical DEBUG_MODE / _log_debug_call / _save_debug_log /
get_debug_session_info boilerplate previously duplicated across web_tools,
vision_tools, mixture_of_agents_tool, and image_generation_tool.
Usage in a tool module:
from tools.debug_helpers import DebugSession
_debug = DebugSession("web_tools", env_var="WEB_TOOLS_DEBUG")
# Log a call (no-op when debug mode is off)
_debug.log_call("web_search", {"query": q, "results": len(r)})
# Save the debug log (no-op when debug mode is off)
_debug.save()
# Expose debug info to external callers
def get_debug_session_info():
return _debug.get_session_info()
"""
import datetime
import json
import logging
import os
import uuid
from pathlib import Path
from typing import Any, Dict
logger = logging.getLogger(__name__)
class DebugSession:
"""Per-tool debug session that records tool calls to a JSON log file.
Activated by a tool-specific environment variable (e.g. WEB_TOOLS_DEBUG=true).
When disabled, all methods are cheap no-ops.
"""
def __init__(self, tool_name: str, *, env_var: str) -> None:
self.tool_name = tool_name
self.enabled = os.getenv(env_var, "false").lower() == "true"
self.session_id = str(uuid.uuid4()) if self.enabled else ""
self.log_dir = Path("./logs")
self._calls: list[Dict[str, Any]] = []
self._start_time = datetime.datetime.now().isoformat() if self.enabled else ""
if self.enabled:
self.log_dir.mkdir(exist_ok=True)
logger.debug("%s debug mode enabled - Session ID: %s",
tool_name, self.session_id)
@property
def active(self) -> bool:
return self.enabled
def log_call(self, call_name: str, call_data: Dict[str, Any]) -> None:
"""Append a tool-call entry to the in-memory log."""
if not self.enabled:
return
self._calls.append({
"timestamp": datetime.datetime.now().isoformat(),
"tool_name": call_name,
**call_data,
})
def save(self) -> None:
"""Flush the in-memory log to a JSON file in the logs directory."""
if not self.enabled:
return
try:
filename = f"{self.tool_name}_debug_{self.session_id}.json"
filepath = self.log_dir / filename
payload = {
"session_id": self.session_id,
"start_time": self._start_time,
"end_time": datetime.datetime.now().isoformat(),
"debug_enabled": True,
"total_calls": len(self._calls),
"tool_calls": self._calls,
}
with open(filepath, "w", encoding="utf-8") as f:
json.dump(payload, f, indent=2, ensure_ascii=False)
logger.debug("%s debug log saved: %s", self.tool_name, filepath)
except Exception as e:
logger.error("Error saving %s debug log: %s", self.tool_name, e)
def get_session_info(self) -> Dict[str, Any]:
"""Return a summary dict suitable for returning from get_debug_session_info()."""
if not self.enabled:
return {
"enabled": False,
"session_id": None,
"log_path": None,
"total_calls": 0,
}
return {
"enabled": True,
"session_id": self.session_id,
"log_path": str(self.log_dir / f"{self.tool_name}_debug_{self.session_id}.json"),
"total_calls": len(self._calls),
}

View file

@ -0,0 +1,789 @@
#!/usr/bin/env python3
"""
Delegate Tool -- Subagent Architecture
Spawns child AIAgent instances with isolated context, restricted toolsets,
and their own terminal sessions. Supports single-task and batch (parallel)
modes. The parent blocks until all children complete.
Each child gets:
- A fresh conversation (no parent history)
- Its own task_id (own terminal session, file ops cache)
- A restricted toolset (configurable, with blocked tools always stripped)
- A focused system prompt built from the delegated goal + context
The parent's context only sees the delegation call and the summary result,
never the child's intermediate tool calls or reasoning.
"""
import json
import logging
logger = logging.getLogger(__name__)
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional
# Tools that children must never have access to
DELEGATE_BLOCKED_TOOLS = frozenset([
"delegate_task", # no recursive delegation
"clarify", # no user interaction
"memory", # no writes to shared MEMORY.md
"send_message", # no cross-platform side effects
"execute_code", # children should reason step-by-step, not write scripts
])
MAX_CONCURRENT_CHILDREN = 3
MAX_DEPTH = 2 # parent (0) -> child (1) -> grandchild rejected (2)
DEFAULT_MAX_ITERATIONS = 50
DEFAULT_TOOLSETS = ["terminal", "file", "web"]
def check_delegate_requirements() -> bool:
"""Delegation has no external requirements -- always available."""
return True
def _build_child_system_prompt(goal: str, context: Optional[str] = None) -> str:
"""Build a focused system prompt for a child agent."""
parts = [
"You are a focused subagent working on a specific delegated task.",
"",
f"YOUR TASK:\n{goal}",
]
if context and context.strip():
parts.append(f"\nCONTEXT:\n{context}")
parts.append(
"\nComplete this task using the tools available to you. "
"When finished, provide a clear, concise summary of:\n"
"- What you did\n"
"- What you found or accomplished\n"
"- Any files you created or modified\n"
"- Any issues encountered\n\n"
"Be thorough but concise -- your response is returned to the "
"parent agent as a summary."
)
return "\n".join(parts)
def _strip_blocked_tools(toolsets: List[str]) -> List[str]:
"""Remove toolsets that contain only blocked tools."""
blocked_toolset_names = {
"delegation", "clarify", "memory", "code_execution",
}
return [t for t in toolsets if t not in blocked_toolset_names]
def _build_child_progress_callback(task_index: int, parent_agent, task_count: int = 1) -> Optional[callable]:
"""Build a callback that relays child agent tool calls to the parent display.
Two display paths:
CLI: prints tree-view lines above the parent's delegation spinner
Gateway: batches tool names and relays to parent's progress callback
Returns None if no display mechanism is available, in which case the
child agent runs with no progress callback (identical to current behavior).
"""
spinner = getattr(parent_agent, '_delegate_spinner', None)
parent_cb = getattr(parent_agent, 'tool_progress_callback', None)
if not spinner and not parent_cb:
return None # No display → no callback → zero behavior change
# Show 1-indexed prefix only in batch mode (multiple tasks)
prefix = f"[{task_index + 1}] " if task_count > 1 else ""
# Gateway: batch tool names, flush periodically
_BATCH_SIZE = 5
_batch: List[str] = []
def _callback(tool_name: str, preview: str = None):
# Special "_thinking" event: model produced text content (reasoning)
if tool_name == "_thinking":
if spinner:
short = (preview[:55] + "...") if preview and len(preview) > 55 else (preview or "")
try:
spinner.print_above(f" {prefix}├─ 💭 \"{short}\"")
except Exception as e:
logger.debug("Spinner print_above failed: %s", e)
# Don't relay thinking to gateway (too noisy for chat)
return
# Regular tool call event
if spinner:
short = (preview[:35] + "...") if preview and len(preview) > 35 else (preview or "")
from agent.display import get_tool_emoji
emoji = get_tool_emoji(tool_name)
line = f" {prefix}├─ {emoji} {tool_name}"
if short:
line += f" \"{short}\""
try:
spinner.print_above(line)
except Exception as e:
logger.debug("Spinner print_above failed: %s", e)
if parent_cb:
_batch.append(tool_name)
if len(_batch) >= _BATCH_SIZE:
summary = ", ".join(_batch)
try:
parent_cb("subagent_progress", f"🔀 {prefix}{summary}")
except Exception as e:
logger.debug("Parent callback failed: %s", e)
_batch.clear()
def _flush():
"""Flush remaining batched tool names to gateway on completion."""
if parent_cb and _batch:
summary = ", ".join(_batch)
try:
parent_cb("subagent_progress", f"🔀 {prefix}{summary}")
except Exception as e:
logger.debug("Parent callback flush failed: %s", e)
_batch.clear()
_callback._flush = _flush
return _callback
def _build_child_agent(
task_index: int,
goal: str,
context: Optional[str],
toolsets: Optional[List[str]],
model: Optional[str],
max_iterations: int,
parent_agent,
# Credential overrides from delegation config (provider:model resolution)
override_provider: Optional[str] = None,
override_base_url: Optional[str] = None,
override_api_key: Optional[str] = None,
override_api_mode: Optional[str] = None,
):
"""
Build a child AIAgent on the main thread (thread-safe construction).
Returns the constructed child agent without running it.
When override_* params are set (from delegation config), the child uses
those credentials instead of inheriting from the parent. This enables
routing subagents to a different provider:model pair (e.g. cheap/fast
model on OpenRouter while the parent runs on Nous Portal).
"""
from run_agent import AIAgent
import model_tools
# When no explicit toolsets given, inherit from parent's enabled toolsets
# so disabled tools (e.g. web) don't leak to subagents.
if toolsets:
child_toolsets = _strip_blocked_tools(toolsets)
elif parent_agent and getattr(parent_agent, "enabled_toolsets", None):
child_toolsets = _strip_blocked_tools(parent_agent.enabled_toolsets)
else:
child_toolsets = _strip_blocked_tools(DEFAULT_TOOLSETS)
child_prompt = _build_child_system_prompt(goal, context)
# Extract parent's API key so subagents inherit auth (e.g. Nous Portal).
parent_api_key = getattr(parent_agent, "api_key", None)
if (not parent_api_key) and hasattr(parent_agent, "_client_kwargs"):
parent_api_key = parent_agent._client_kwargs.get("api_key")
# Build progress callback to relay tool calls to parent display
child_progress_cb = _build_child_progress_callback(task_index, parent_agent)
# Share the parent's iteration budget so subagent tool calls
# count toward the session-wide limit.
shared_budget = getattr(parent_agent, "iteration_budget", None)
# Resolve effective credentials: config override > parent inherit
effective_model = model or parent_agent.model
effective_provider = override_provider or getattr(parent_agent, "provider", None)
effective_base_url = override_base_url or parent_agent.base_url
effective_api_key = override_api_key or parent_api_key
effective_api_mode = override_api_mode or getattr(parent_agent, "api_mode", None)
effective_acp_command = getattr(parent_agent, "acp_command", None)
effective_acp_args = list(getattr(parent_agent, "acp_args", []) or [])
child = AIAgent(
base_url=effective_base_url,
api_key=effective_api_key,
model=effective_model,
provider=effective_provider,
api_mode=effective_api_mode,
acp_command=effective_acp_command,
acp_args=effective_acp_args,
max_iterations=max_iterations,
max_tokens=getattr(parent_agent, "max_tokens", None),
reasoning_config=getattr(parent_agent, "reasoning_config", None),
prefill_messages=getattr(parent_agent, "prefill_messages", None),
enabled_toolsets=child_toolsets,
quiet_mode=True,
ephemeral_system_prompt=child_prompt,
log_prefix=f"[subagent-{task_index}]",
platform=parent_agent.platform,
skip_context_files=True,
skip_memory=True,
clarify_callback=None,
session_db=getattr(parent_agent, '_session_db', None),
providers_allowed=parent_agent.providers_allowed,
providers_ignored=parent_agent.providers_ignored,
providers_order=parent_agent.providers_order,
provider_sort=parent_agent.provider_sort,
tool_progress_callback=child_progress_cb,
iteration_budget=shared_budget,
)
# Set delegation depth so children can't spawn grandchildren
child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1
# Register child for interrupt propagation
if hasattr(parent_agent, '_active_children'):
lock = getattr(parent_agent, '_active_children_lock', None)
if lock:
with lock:
parent_agent._active_children.append(child)
else:
parent_agent._active_children.append(child)
return child
def _run_single_child(
task_index: int,
goal: str,
child=None,
parent_agent=None,
**_kwargs,
) -> Dict[str, Any]:
"""
Run a pre-built child agent. Called from within a thread.
Returns a structured result dict.
"""
child_start = time.monotonic()
# Get the progress callback from the child agent
child_progress_cb = getattr(child, 'tool_progress_callback', None)
# Restore parent tool names using the value saved before child construction
# mutated the global. This is the correct parent toolset, not the child's.
import model_tools
_saved_tool_names = getattr(child, "_delegate_saved_tool_names",
list(model_tools._last_resolved_tool_names))
try:
result = child.run_conversation(user_message=goal)
# Flush any remaining batched progress to gateway
if child_progress_cb and hasattr(child_progress_cb, '_flush'):
try:
child_progress_cb._flush()
except Exception as e:
logger.debug("Progress callback flush failed: %s", e)
duration = round(time.monotonic() - child_start, 2)
summary = result.get("final_response") or ""
completed = result.get("completed", False)
interrupted = result.get("interrupted", False)
api_calls = result.get("api_calls", 0)
if interrupted:
status = "interrupted"
elif completed and summary:
status = "completed"
else:
status = "failed"
# Build tool trace from conversation messages (already in memory).
# Uses tool_call_id to correctly pair parallel tool calls with results.
tool_trace: list[Dict[str, Any]] = []
trace_by_id: Dict[str, Dict[str, Any]] = {}
messages = result.get("messages") or []
if isinstance(messages, list):
for msg in messages:
if not isinstance(msg, dict):
continue
if msg.get("role") == "assistant":
for tc in (msg.get("tool_calls") or []):
fn = tc.get("function", {})
entry_t = {
"tool": fn.get("name", "unknown"),
"args_bytes": len(fn.get("arguments", "")),
}
tool_trace.append(entry_t)
tc_id = tc.get("id")
if tc_id:
trace_by_id[tc_id] = entry_t
elif msg.get("role") == "tool":
content = msg.get("content", "")
is_error = bool(
content and "error" in content[:80].lower()
)
result_meta = {
"result_bytes": len(content),
"status": "error" if is_error else "ok",
}
# Match by tool_call_id for parallel calls
tc_id = msg.get("tool_call_id")
target = trace_by_id.get(tc_id) if tc_id else None
if target is not None:
target.update(result_meta)
elif tool_trace:
# Fallback for messages without tool_call_id
tool_trace[-1].update(result_meta)
# Determine exit reason
if interrupted:
exit_reason = "interrupted"
elif completed:
exit_reason = "completed"
else:
exit_reason = "max_iterations"
# Extract token counts (safe for mock objects)
_input_tokens = getattr(child, "session_prompt_tokens", 0)
_output_tokens = getattr(child, "session_completion_tokens", 0)
_model = getattr(child, "model", None)
entry: Dict[str, Any] = {
"task_index": task_index,
"status": status,
"summary": summary,
"api_calls": api_calls,
"duration_seconds": duration,
"model": _model if isinstance(_model, str) else None,
"exit_reason": exit_reason,
"tokens": {
"input": _input_tokens if isinstance(_input_tokens, (int, float)) else 0,
"output": _output_tokens if isinstance(_output_tokens, (int, float)) else 0,
},
"tool_trace": tool_trace,
}
if status == "failed":
entry["error"] = result.get("error", "Subagent did not produce a response.")
return entry
except Exception as exc:
duration = round(time.monotonic() - child_start, 2)
logging.exception(f"[subagent-{task_index}] failed")
return {
"task_index": task_index,
"status": "error",
"summary": None,
"error": str(exc),
"api_calls": 0,
"duration_seconds": duration,
}
finally:
# Restore the parent's tool names so the process-global is correct
# for any subsequent execute_code calls or other consumers.
import model_tools
saved_tool_names = getattr(child, "_delegate_saved_tool_names", None)
if isinstance(saved_tool_names, list):
model_tools._last_resolved_tool_names = list(saved_tool_names)
# Unregister child from interrupt propagation
if hasattr(parent_agent, '_active_children'):
try:
lock = getattr(parent_agent, '_active_children_lock', None)
if lock:
with lock:
parent_agent._active_children.remove(child)
else:
parent_agent._active_children.remove(child)
except (ValueError, UnboundLocalError) as e:
logger.debug("Could not remove child from active_children: %s", e)
def delegate_task(
goal: Optional[str] = None,
context: Optional[str] = None,
toolsets: Optional[List[str]] = None,
tasks: Optional[List[Dict[str, Any]]] = None,
max_iterations: Optional[int] = None,
parent_agent=None,
) -> str:
"""
Spawn one or more child agents to handle delegated tasks.
Supports two modes:
- Single: provide goal (+ optional context, toolsets)
- Batch: provide tasks array [{goal, context, toolsets}, ...]
Returns JSON with results array, one entry per task.
"""
if parent_agent is None:
return json.dumps({"error": "delegate_task requires a parent agent context."})
# Depth limit
depth = getattr(parent_agent, '_delegate_depth', 0)
if depth >= MAX_DEPTH:
return json.dumps({
"error": (
f"Delegation depth limit reached ({MAX_DEPTH}). "
"Subagents cannot spawn further subagents."
)
})
# Load config
cfg = _load_config()
default_max_iter = cfg.get("max_iterations", DEFAULT_MAX_ITERATIONS)
effective_max_iter = max_iterations or default_max_iter
# Resolve delegation credentials (provider:model pair).
# When delegation.provider is configured, this resolves the full credential
# bundle (base_url, api_key, api_mode) via the same runtime provider system
# used by CLI/gateway startup. When unconfigured, returns None values so
# children inherit from the parent.
try:
creds = _resolve_delegation_credentials(cfg, parent_agent)
except ValueError as exc:
return json.dumps({"error": str(exc)})
# Normalize to task list
if tasks and isinstance(tasks, list):
task_list = tasks[:MAX_CONCURRENT_CHILDREN]
elif goal and isinstance(goal, str) and goal.strip():
task_list = [{"goal": goal, "context": context, "toolsets": toolsets}]
else:
return json.dumps({"error": "Provide either 'goal' (single task) or 'tasks' (batch)."})
if not task_list:
return json.dumps({"error": "No tasks provided."})
# Validate each task has a goal
for i, task in enumerate(task_list):
if not task.get("goal", "").strip():
return json.dumps({"error": f"Task {i} is missing a 'goal'."})
overall_start = time.monotonic()
results = []
n_tasks = len(task_list)
# Track goal labels for progress display (truncated for readability)
task_labels = [t["goal"][:40] for t in task_list]
# Save parent tool names BEFORE any child construction mutates the global.
# _build_child_agent() calls AIAgent() which calls get_tool_definitions(),
# which overwrites model_tools._last_resolved_tool_names with child's toolset.
import model_tools as _model_tools
_parent_tool_names = list(_model_tools._last_resolved_tool_names)
# Build all child agents on the main thread (thread-safe construction)
# Wrapped in try/finally so the global is always restored even if a
# child build raises (otherwise _last_resolved_tool_names stays corrupted).
children = []
try:
for i, t in enumerate(task_list):
child = _build_child_agent(
task_index=i, goal=t["goal"], context=t.get("context"),
toolsets=t.get("toolsets") or toolsets, model=creds["model"],
max_iterations=effective_max_iter, parent_agent=parent_agent,
override_provider=creds["provider"], override_base_url=creds["base_url"],
override_api_key=creds["api_key"],
override_api_mode=creds["api_mode"],
)
# Override with correct parent tool names (before child construction mutated global)
child._delegate_saved_tool_names = _parent_tool_names
children.append((i, t, child))
finally:
# Authoritative restore: reset global to parent's tool names after all children built
_model_tools._last_resolved_tool_names = _parent_tool_names
if n_tasks == 1:
# Single task -- run directly (no thread pool overhead)
_i, _t, child = children[0]
result = _run_single_child(0, _t["goal"], child, parent_agent)
results.append(result)
else:
# Batch -- run in parallel with per-task progress lines
completed_count = 0
spinner_ref = getattr(parent_agent, '_delegate_spinner', None)
with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CHILDREN) as executor:
futures = {}
for i, t, child in children:
future = executor.submit(
_run_single_child,
task_index=i,
goal=t["goal"],
child=child,
parent_agent=parent_agent,
)
futures[future] = i
for future in as_completed(futures):
try:
entry = future.result()
except Exception as exc:
idx = futures[future]
entry = {
"task_index": idx,
"status": "error",
"summary": None,
"error": str(exc),
"api_calls": 0,
"duration_seconds": 0,
}
results.append(entry)
completed_count += 1
# Print per-task completion line above the spinner
idx = entry["task_index"]
label = task_labels[idx] if idx < len(task_labels) else f"Task {idx}"
dur = entry.get("duration_seconds", 0)
status = entry.get("status", "?")
icon = "" if status == "completed" else ""
remaining = n_tasks - completed_count
completion_line = f"{icon} [{idx+1}/{n_tasks}] {label} ({dur}s)"
if spinner_ref:
try:
spinner_ref.print_above(completion_line)
except Exception:
print(f" {completion_line}")
else:
print(f" {completion_line}")
# Update spinner text to show remaining count
if spinner_ref and remaining > 0:
try:
spinner_ref.update_text(f"🔀 {remaining} task{'s' if remaining != 1 else ''} remaining")
except Exception as e:
logger.debug("Spinner update_text failed: %s", e)
# Sort by task_index so results match input order
results.sort(key=lambda r: r["task_index"])
total_duration = round(time.monotonic() - overall_start, 2)
return json.dumps({
"results": results,
"total_duration_seconds": total_duration,
}, ensure_ascii=False)
def _resolve_delegation_credentials(cfg: dict, parent_agent) -> dict:
"""Resolve credentials for subagent delegation.
If ``delegation.base_url`` is configured, subagents use that direct
OpenAI-compatible endpoint. Otherwise, if ``delegation.provider`` is
configured, the full credential bundle (base_url, api_key, api_mode,
provider) is resolved via the runtime provider system the same path used
by CLI/gateway startup. This lets subagents run on a completely different
provider:model pair.
If neither base_url nor provider is configured, returns None values so the
child inherits everything from the parent agent.
Raises ValueError with a user-friendly message on credential failure.
"""
configured_model = str(cfg.get("model") or "").strip() or None
configured_provider = str(cfg.get("provider") or "").strip() or None
configured_base_url = str(cfg.get("base_url") or "").strip() or None
configured_api_key = str(cfg.get("api_key") or "").strip() or None
if configured_base_url:
api_key = (
configured_api_key
or os.getenv("OPENAI_API_KEY", "").strip()
)
if not api_key:
raise ValueError(
"Delegation base_url is configured but no API key was found. "
"Set delegation.api_key or OPENAI_API_KEY."
)
base_lower = configured_base_url.lower()
provider = "custom"
api_mode = "chat_completions"
if "chatgpt.com/backend-api/codex" in base_lower:
provider = "openai-codex"
api_mode = "codex_responses"
elif "api.anthropic.com" in base_lower:
provider = "anthropic"
api_mode = "anthropic_messages"
return {
"model": configured_model,
"provider": provider,
"base_url": configured_base_url,
"api_key": api_key,
"api_mode": api_mode,
}
if not configured_provider:
# No provider override — child inherits everything from parent
return {
"model": configured_model,
"provider": None,
"base_url": None,
"api_key": None,
"api_mode": None,
}
# Provider is configured — resolve full credentials
try:
from hermes_cli.runtime_provider import resolve_runtime_provider
runtime = resolve_runtime_provider(requested=configured_provider)
except Exception as exc:
raise ValueError(
f"Cannot resolve delegation provider '{configured_provider}': {exc}. "
f"Check that the provider is configured (API key set, valid provider name), "
f"or set delegation.base_url/delegation.api_key for a direct endpoint. "
f"Available providers: openrouter, nous, zai, kimi-coding, minimax."
) from exc
api_key = runtime.get("api_key", "")
if not api_key:
raise ValueError(
f"Delegation provider '{configured_provider}' resolved but has no API key. "
f"Set the appropriate environment variable or run 'hermes login'."
)
return {
"model": configured_model,
"provider": runtime.get("provider"),
"base_url": runtime.get("base_url"),
"api_key": api_key,
"api_mode": runtime.get("api_mode"),
"command": runtime.get("command"),
"args": list(runtime.get("args") or []),
}
def _load_config() -> dict:
"""Load delegation config from CLI_CONFIG or persistent config.
Checks the runtime config (cli.py CLI_CONFIG) first, then falls back
to the persistent config (hermes_cli/config.py load_config()) so that
``delegation.model`` / ``delegation.provider`` are picked up regardless
of the entry point (CLI, gateway, cron).
"""
try:
from cli import CLI_CONFIG
cfg = CLI_CONFIG.get("delegation", {})
if cfg:
return cfg
except Exception:
pass
try:
from hermes_cli.config import load_config
full = load_config()
return full.get("delegation", {})
except Exception:
return {}
# ---------------------------------------------------------------------------
# OpenAI Function-Calling Schema
# ---------------------------------------------------------------------------
DELEGATE_TASK_SCHEMA = {
"name": "delegate_task",
"description": (
"Spawn one or more subagents to work on tasks in isolated contexts. "
"Each subagent gets its own conversation, terminal session, and toolset. "
"Only the final summary is returned -- intermediate tool results "
"never enter your context window.\n\n"
"TWO MODES (one of 'goal' or 'tasks' is required):\n"
"1. Single task: provide 'goal' (+ optional context, toolsets)\n"
"2. Batch (parallel): provide 'tasks' array with up to 3 items. "
"All run concurrently and results are returned together.\n\n"
"WHEN TO USE delegate_task:\n"
"- Reasoning-heavy subtasks (debugging, code review, research synthesis)\n"
"- Tasks that would flood your context with intermediate data\n"
"- Parallel independent workstreams (research A and B simultaneously)\n\n"
"WHEN NOT TO USE (use these instead):\n"
"- Mechanical multi-step work with no reasoning needed -> use execute_code\n"
"- Single tool call -> just call the tool directly\n"
"- Tasks needing user interaction -> subagents cannot use clarify\n\n"
"IMPORTANT:\n"
"- Subagents have NO memory of your conversation. Pass all relevant "
"info (file paths, error messages, constraints) via the 'context' field.\n"
"- Subagents CANNOT call: delegate_task, clarify, memory, send_message, "
"execute_code.\n"
"- Each subagent gets its own terminal session (separate working directory and state).\n"
"- Results are always returned as an array, one entry per task."
),
"parameters": {
"type": "object",
"properties": {
"goal": {
"type": "string",
"description": (
"What the subagent should accomplish. Be specific and "
"self-contained -- the subagent knows nothing about your "
"conversation history."
),
},
"context": {
"type": "string",
"description": (
"Background information the subagent needs: file paths, "
"error messages, project structure, constraints. The more "
"specific you are, the better the subagent performs."
),
},
"toolsets": {
"type": "array",
"items": {"type": "string"},
"description": (
"Toolsets to enable for this subagent. "
"Default: inherits your enabled toolsets. "
"Common patterns: ['terminal', 'file'] for code work, "
"['web'] for research, ['terminal', 'file', 'web'] for "
"full-stack tasks."
),
},
"tasks": {
"type": "array",
"items": {
"type": "object",
"properties": {
"goal": {"type": "string", "description": "Task goal"},
"context": {"type": "string", "description": "Task-specific context"},
"toolsets": {
"type": "array",
"items": {"type": "string"},
"description": "Toolsets for this specific task",
},
},
"required": ["goal"],
},
"maxItems": 3,
"description": (
"Batch mode: up to 3 tasks to run in parallel. Each gets "
"its own subagent with isolated context and terminal session. "
"When provided, top-level goal/context/toolsets are ignored."
),
},
"max_iterations": {
"type": "integer",
"description": (
"Max tool-calling turns per subagent (default: 50). "
"Only set lower for simple tasks."
),
},
},
"required": [],
},
}
# --- Registry ---
from tools.registry import registry
registry.register(
name="delegate_task",
toolset="delegation",
schema=DELEGATE_TASK_SCHEMA,
handler=lambda args, **kw: delegate_task(
goal=args.get("goal"),
context=args.get("context"),
toolsets=args.get("toolsets"),
tasks=args.get("tasks"),
max_iterations=args.get("max_iterations"),
parent_agent=kw.get("parent_agent")),
check_fn=check_delegate_requirements,
emoji="🔀",
)

View file

@ -0,0 +1,99 @@
"""Environment variable passthrough registry.
Skills that declare ``required_environment_variables`` in their frontmatter
need those vars available in sandboxed execution environments (execute_code,
terminal). By default both sandboxes strip secrets from the child process
environment for security. This module provides a session-scoped allowlist
so skill-declared vars (and user-configured overrides) pass through.
Two sources feed the allowlist:
1. **Skill declarations** when a skill is loaded via ``skill_view``, its
``required_environment_variables`` are registered here automatically.
2. **User config** ``terminal.env_passthrough`` in config.yaml lets users
explicitly allowlist vars for non-skill use cases.
Both ``code_execution_tool.py`` and ``tools/environments/local.py`` consult
:func:`is_env_passthrough` before stripping a variable.
"""
from __future__ import annotations
import logging
import os
from pathlib import Path
from typing import Iterable
logger = logging.getLogger(__name__)
# Session-scoped set of env var names that should pass through to sandboxes.
_allowed_env_vars: set[str] = set()
# Cache for the config-based allowlist (loaded once per process).
_config_passthrough: frozenset[str] | None = None
def register_env_passthrough(var_names: Iterable[str]) -> None:
"""Register environment variable names as allowed in sandboxed environments.
Typically called when a skill declares ``required_environment_variables``.
"""
for name in var_names:
name = name.strip()
if name:
_allowed_env_vars.add(name)
logger.debug("env passthrough: registered %s", name)
def _load_config_passthrough() -> frozenset[str]:
"""Load ``tools.env_passthrough`` from config.yaml (cached)."""
global _config_passthrough
if _config_passthrough is not None:
return _config_passthrough
result: set[str] = set()
try:
hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
config_path = hermes_home / "config.yaml"
if config_path.exists():
import yaml
with open(config_path) as f:
cfg = yaml.safe_load(f) or {}
passthrough = cfg.get("terminal", {}).get("env_passthrough")
if isinstance(passthrough, list):
for item in passthrough:
if isinstance(item, str) and item.strip():
result.add(item.strip())
except Exception as e:
logger.debug("Could not read tools.env_passthrough from config: %s", e)
_config_passthrough = frozenset(result)
return _config_passthrough
def is_env_passthrough(var_name: str) -> bool:
"""Check whether *var_name* is allowed to pass through to sandboxes.
Returns ``True`` if the variable was registered by a skill or listed in
the user's ``tools.env_passthrough`` config.
"""
if var_name in _allowed_env_vars:
return True
return var_name in _load_config_passthrough()
def get_all_passthrough() -> frozenset[str]:
"""Return the union of skill-registered and config-based passthrough vars."""
return frozenset(_allowed_env_vars) | _load_config_passthrough()
def clear_env_passthrough() -> None:
"""Reset the skill-scoped allowlist (e.g. on session reset)."""
_allowed_env_vars.clear()
def reset_config_cache() -> None:
"""Force re-read of config on next access (for testing)."""
global _config_passthrough
_config_passthrough = None

View file

@ -0,0 +1,13 @@
"""Hermes execution environment backends.
Each backend provides the same interface (BaseEnvironment ABC) for running
shell commands in a specific execution context: local, Docker, Singularity,
SSH, Modal, or Daytona.
The terminal_tool.py factory (_create_environment) selects the backend
based on the TERMINAL_ENV configuration.
"""
from tools.environments.base import BaseEnvironment
__all__ = ["BaseEnvironment"]

View file

@ -0,0 +1,99 @@
"""Base class for all Hermes execution environment backends."""
from abc import ABC, abstractmethod
import os
import subprocess
from pathlib import Path
from hermes_cli.config import get_hermes_home
def get_sandbox_dir() -> Path:
"""Return the host-side root for all sandbox storage (Docker workspaces,
Singularity overlays/SIF cache, etc.).
Configurable via TERMINAL_SANDBOX_DIR. Defaults to {HERMES_HOME}/sandboxes/.
"""
custom = os.getenv("TERMINAL_SANDBOX_DIR")
if custom:
p = Path(custom)
else:
p = get_hermes_home() / "sandboxes"
p.mkdir(parents=True, exist_ok=True)
return p
class BaseEnvironment(ABC):
"""Common interface for all Hermes execution backends.
Subclasses implement execute() and cleanup(). Shared helpers eliminate
duplicated subprocess boilerplate across backends.
"""
def __init__(self, cwd: str, timeout: int, env: dict = None):
self.cwd = cwd
self.timeout = timeout
self.env = env or {}
@abstractmethod
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
"""Execute a command, return {"output": str, "returncode": int}."""
...
@abstractmethod
def cleanup(self):
"""Release backend resources (container, instance, connection)."""
...
def stop(self):
"""Alias for cleanup (compat with older callers)."""
self.cleanup()
def __del__(self):
try:
self.cleanup()
except Exception:
pass
# ------------------------------------------------------------------
# Shared helpers (eliminate duplication across backends)
# ------------------------------------------------------------------
def _prepare_command(self, command: str) -> tuple[str, str | None]:
"""Transform sudo commands if SUDO_PASSWORD is available.
Returns:
(transformed_command, sudo_stdin) see _transform_sudo_command
for the full contract. Callers that drive a subprocess directly
should prepend sudo_stdin (when not None) to any stdin_data they
pass to Popen. Callers that embed stdin via heredoc (modal,
daytona) handle sudo_stdin in their own execute() method.
"""
from tools.terminal_tool import _transform_sudo_command
return _transform_sudo_command(command)
def _build_run_kwargs(self, timeout: int | None,
stdin_data: str | None = None) -> dict:
"""Build common subprocess.run kwargs for non-interactive execution."""
kw = {
"text": True,
"timeout": timeout or self.timeout,
"encoding": "utf-8",
"errors": "replace",
"stdout": subprocess.PIPE,
"stderr": subprocess.STDOUT,
}
if stdin_data is not None:
kw["input"] = stdin_data
else:
kw["stdin"] = subprocess.DEVNULL
return kw
def _timeout_result(self, timeout: int | None) -> dict:
"""Standard return dict when a command times out."""
return {
"output": f"Command timed out after {timeout or self.timeout}s",
"returncode": 124,
}

View file

@ -0,0 +1,250 @@
"""Daytona cloud execution environment.
Uses the Daytona Python SDK to run commands in cloud sandboxes.
Supports persistent sandboxes: when enabled, sandboxes are stopped on cleanup
and resumed on next creation, preserving the filesystem across sessions.
"""
import logging
import time
import math
import shlex
import threading
import uuid
import warnings
from typing import Optional
from tools.environments.base import BaseEnvironment
from tools.interrupt import is_interrupted
logger = logging.getLogger(__name__)
class DaytonaEnvironment(BaseEnvironment):
"""Daytona cloud sandbox execution backend.
Uses stopped/started sandbox lifecycle for filesystem persistence
instead of snapshots, making it faster and stateless on the host.
"""
def __init__(
self,
image: str,
cwd: str = "/home/daytona",
timeout: int = 60,
cpu: int = 1,
memory: int = 5120, # MB (hermes convention)
disk: int = 10240, # MB (Daytona platform max is 10GB)
persistent_filesystem: bool = True,
task_id: str = "default",
):
self._requested_cwd = cwd
super().__init__(cwd=cwd, timeout=timeout)
from daytona import (
Daytona,
CreateSandboxFromImageParams,
DaytonaError,
Resources,
SandboxState,
)
self._persistent = persistent_filesystem
self._task_id = task_id
self._SandboxState = SandboxState
self._daytona = Daytona()
self._sandbox = None
self._lock = threading.Lock()
memory_gib = max(1, math.ceil(memory / 1024))
disk_gib = max(1, math.ceil(disk / 1024))
if disk_gib > 10:
warnings.warn(
f"Daytona: requested disk ({disk_gib}GB) exceeds platform limit (10GB). "
f"Capping to 10GB. Set container_disk: 10240 in config to silence this.",
stacklevel=2,
)
disk_gib = 10
resources = Resources(cpu=cpu, memory=memory_gib, disk=disk_gib)
labels = {"hermes_task_id": task_id}
sandbox_name = f"hermes-{task_id}"
# Try to resume an existing sandbox for this task
if self._persistent:
# 1. Try name-based lookup (new path)
try:
self._sandbox = self._daytona.get(sandbox_name)
self._sandbox.start()
logger.info("Daytona: resumed sandbox %s for task %s",
self._sandbox.id, task_id)
except DaytonaError:
self._sandbox = None
except Exception as e:
logger.warning("Daytona: failed to resume sandbox for task %s: %s",
task_id, e)
self._sandbox = None
# 2. Legacy fallback: find sandbox created before the naming migration
if self._sandbox is None:
try:
page = self._daytona.list(labels=labels, page=1, limit=1)
if page.items:
self._sandbox = page.items[0]
self._sandbox.start()
logger.info("Daytona: resumed legacy sandbox %s for task %s",
self._sandbox.id, task_id)
except Exception as e:
logger.debug("Daytona: no legacy sandbox found for task %s: %s",
task_id, e)
self._sandbox = None
# Create a fresh sandbox if we don't have one
if self._sandbox is None:
self._sandbox = self._daytona.create(
CreateSandboxFromImageParams(
image=image,
name=sandbox_name,
labels=labels,
auto_stop_interval=0,
resources=resources,
)
)
logger.info("Daytona: created sandbox %s for task %s",
self._sandbox.id, task_id)
# Resolve cwd: detect actual home dir inside the sandbox
if self._requested_cwd in ("~", "/home/daytona"):
try:
home = self._sandbox.process.exec("echo $HOME").result.strip()
if home:
self.cwd = home
except Exception:
pass # leave cwd as-is; sandbox will use its own default
logger.info("Daytona: resolved cwd to %s", self.cwd)
def _ensure_sandbox_ready(self):
"""Restart sandbox if it was stopped (e.g., by a previous interrupt)."""
self._sandbox.refresh_data()
if self._sandbox.state in (self._SandboxState.STOPPED, self._SandboxState.ARCHIVED):
self._sandbox.start()
logger.info("Daytona: restarted sandbox %s", self._sandbox.id)
def _exec_in_thread(self, exec_command: str, cwd: Optional[str], timeout: int) -> dict:
"""Run exec in a background thread with interrupt polling.
The Daytona SDK's exec(timeout=...) parameter is unreliable (the
server-side timeout is not enforced and the SDK has no client-side
fallback), so we wrap the command with the shell ``timeout`` utility
which reliably kills the process and returns exit code 124.
"""
# Wrap with shell `timeout` to enforce the deadline reliably.
# Add a small buffer so the shell timeout fires before any SDK-level
# timeout would, giving us a clean exit code 124.
timed_command = f"timeout {timeout} sh -c {shlex.quote(exec_command)}"
result_holder: dict = {"value": None, "error": None}
def _run():
try:
response = self._sandbox.process.exec(
timed_command, cwd=cwd,
)
result_holder["value"] = {
"output": response.result or "",
"returncode": response.exit_code,
}
except Exception as e:
result_holder["error"] = e
t = threading.Thread(target=_run, daemon=True)
t.start()
# Wait for timeout + generous buffer for network/SDK overhead
deadline = time.monotonic() + timeout + 10
while t.is_alive():
t.join(timeout=0.2)
if is_interrupted():
with self._lock:
try:
self._sandbox.stop()
except Exception:
pass
return {
"output": "[Command interrupted - Daytona sandbox stopped]",
"returncode": 130,
}
if time.monotonic() > deadline:
# Shell timeout didn't fire and SDK is hung — force stop
with self._lock:
try:
self._sandbox.stop()
except Exception:
pass
return self._timeout_result(timeout)
if result_holder["error"]:
return {"error": result_holder["error"]}
return result_holder["value"]
def execute(self, command: str, cwd: str = "", *,
timeout: Optional[int] = None,
stdin_data: Optional[str] = None) -> dict:
with self._lock:
self._ensure_sandbox_ready()
if stdin_data is not None:
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
while marker in stdin_data:
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
command = f"{command} << '{marker}'\n{stdin_data}\n{marker}"
exec_command, sudo_stdin = self._prepare_command(command)
# Daytona sandboxes execute commands via the Daytona SDK and cannot
# pipe subprocess stdin directly the way a local Popen can. When a
# sudo password is present, use a shell-level pipe from printf so that
# the password feeds sudo -S without appearing as an echo argument
# embedded in the shell string. The password is still visible in the
# remote sandbox's command line, but it is not exposed on the user's
# local machine — which is the primary threat being mitigated.
if sudo_stdin is not None:
import shlex
exec_command = (
f"printf '%s\\n' {shlex.quote(sudo_stdin.rstrip())} | {exec_command}"
)
effective_cwd = cwd or self.cwd or None
effective_timeout = timeout or self.timeout
result = self._exec_in_thread(exec_command, effective_cwd, effective_timeout)
if "error" in result:
from daytona import DaytonaError
err = result["error"]
if isinstance(err, DaytonaError):
with self._lock:
try:
self._ensure_sandbox_ready()
except Exception:
return {"output": f"Daytona execution error: {err}", "returncode": 1}
result = self._exec_in_thread(exec_command, effective_cwd, effective_timeout)
if "error" not in result:
return result
return {"output": f"Daytona execution error: {err}", "returncode": 1}
return result
def cleanup(self):
with self._lock:
if self._sandbox is None:
return
try:
if self._persistent:
self._sandbox.stop()
logger.info("Daytona: stopped sandbox %s (filesystem preserved)",
self._sandbox.id)
else:
self._daytona.delete(self._sandbox)
logger.info("Daytona: deleted sandbox %s", self._sandbox.id)
except Exception as e:
logger.warning("Daytona: cleanup failed: %s", e)
self._sandbox = None

View file

@ -0,0 +1,494 @@
"""Docker execution environment for sandboxed command execution.
Security hardened (cap-drop ALL, no-new-privileges, PID limits),
configurable resource limits (CPU, memory, disk), and optional filesystem
persistence via bind mounts.
"""
import logging
import os
import re
import shutil
import subprocess
import sys
import threading
import time
import uuid
from typing import Optional
from tools.environments.base import BaseEnvironment
from tools.interrupt import is_interrupted
logger = logging.getLogger(__name__)
# Common Docker Desktop install paths checked when 'docker' is not in PATH.
# macOS Intel: /usr/local/bin, macOS Apple Silicon (Homebrew): /opt/homebrew/bin,
# Docker Desktop app bundle: /Applications/Docker.app/Contents/Resources/bin
_DOCKER_SEARCH_PATHS = [
"/usr/local/bin/docker",
"/opt/homebrew/bin/docker",
"/Applications/Docker.app/Contents/Resources/bin/docker",
]
_docker_executable: Optional[str] = None # resolved once, cached
_ENV_VAR_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
def _normalize_forward_env_names(forward_env: list[str] | None) -> list[str]:
"""Return a deduplicated list of valid environment variable names."""
normalized: list[str] = []
seen: set[str] = set()
for item in forward_env or []:
if not isinstance(item, str):
logger.warning("Ignoring non-string docker_forward_env entry: %r", item)
continue
key = item.strip()
if not key:
continue
if not _ENV_VAR_NAME_RE.match(key):
logger.warning("Ignoring invalid docker_forward_env entry: %r", item)
continue
if key in seen:
continue
seen.add(key)
normalized.append(key)
return normalized
def _load_hermes_env_vars() -> dict[str, str]:
"""Load ~/.hermes/.env values without failing Docker command execution."""
try:
from hermes_cli.config import load_env
return load_env() or {}
except Exception:
return {}
def find_docker() -> Optional[str]:
"""Locate the docker CLI binary.
Checks ``shutil.which`` first (respects PATH), then probes well-known
install locations on macOS where Docker Desktop may not be in PATH
(e.g. when running as a gateway service via launchd).
Returns the absolute path, or ``None`` if docker cannot be found.
"""
global _docker_executable
if _docker_executable is not None:
return _docker_executable
found = shutil.which("docker")
if found:
_docker_executable = found
return found
for path in _DOCKER_SEARCH_PATHS:
if os.path.isfile(path) and os.access(path, os.X_OK):
_docker_executable = path
logger.info("Found docker at non-PATH location: %s", path)
return path
return None
# Security flags applied to every container.
# The container itself is the security boundary (isolated from host).
# We drop all capabilities then add back the minimum needed:
# DAC_OVERRIDE - root can write to bind-mounted dirs owned by host user
# CHOWN/FOWNER - package managers (pip, npm, apt) need to set file ownership
# Block privilege escalation and limit PIDs.
# /tmp is size-limited and nosuid but allows exec (needed by pip/npm builds).
_SECURITY_ARGS = [
"--cap-drop", "ALL",
"--cap-add", "DAC_OVERRIDE",
"--cap-add", "CHOWN",
"--cap-add", "FOWNER",
"--security-opt", "no-new-privileges",
"--pids-limit", "256",
"--tmpfs", "/tmp:rw,nosuid,size=512m",
"--tmpfs", "/var/tmp:rw,noexec,nosuid,size=256m",
"--tmpfs", "/run:rw,noexec,nosuid,size=64m",
]
_storage_opt_ok: Optional[bool] = None # cached result across instances
def _ensure_docker_available() -> None:
"""Best-effort check that the docker CLI is available before use.
Reuses ``find_docker()`` so this preflight stays consistent with the rest of
the Docker backend, including known non-PATH Docker Desktop locations.
"""
docker_exe = find_docker()
if not docker_exe:
logger.error(
"Docker backend selected but no docker executable was found in PATH "
"or known install locations. Install Docker Desktop and ensure the "
"CLI is available."
)
raise RuntimeError(
"Docker executable not found in PATH or known install locations. "
"Install Docker and ensure the 'docker' command is available."
)
try:
result = subprocess.run(
[docker_exe, "version"],
capture_output=True,
text=True,
timeout=5,
)
except FileNotFoundError:
logger.error(
"Docker backend selected but the resolved docker executable '%s' could "
"not be executed.",
docker_exe,
exc_info=True,
)
raise RuntimeError(
"Docker executable could not be executed. Check your Docker installation."
)
except subprocess.TimeoutExpired:
logger.error(
"Docker backend selected but '%s version' timed out. "
"The Docker daemon may not be running.",
docker_exe,
exc_info=True,
)
raise RuntimeError(
"Docker daemon is not responding. Ensure Docker is running and try again."
)
except Exception:
logger.error(
"Unexpected error while checking Docker availability.",
exc_info=True,
)
raise
else:
if result.returncode != 0:
logger.error(
"Docker backend selected but '%s version' failed "
"(exit code %d, stderr=%s)",
docker_exe,
result.returncode,
result.stderr.strip(),
)
raise RuntimeError(
"Docker command is available but 'docker version' failed. "
"Check your Docker installation."
)
class DockerEnvironment(BaseEnvironment):
"""Hardened Docker container execution with resource limits and persistence.
Security: all capabilities dropped, no privilege escalation, PID limits,
size-limited tmpfs for scratch dirs. The container itself is the security
boundary the filesystem inside is writable so agents can install packages
(pip, npm, apt) as needed. Writable workspace via tmpfs or bind mounts.
Persistence: when enabled, bind mounts preserve /workspace and /root
across container restarts.
"""
def __init__(
self,
image: str,
cwd: str = "/root",
timeout: int = 60,
cpu: float = 0,
memory: int = 0,
disk: int = 0,
persistent_filesystem: bool = False,
task_id: str = "default",
volumes: list = None,
forward_env: list[str] | None = None,
network: bool = True,
host_cwd: str = None,
auto_mount_cwd: bool = False,
):
if cwd == "~":
cwd = "/root"
super().__init__(cwd=cwd, timeout=timeout)
self._base_image = image
self._persistent = persistent_filesystem
self._task_id = task_id
self._forward_env = _normalize_forward_env_names(forward_env)
self._container_id: Optional[str] = None
logger.info(f"DockerEnvironment volumes: {volumes}")
# Ensure volumes is a list (config.yaml could be malformed)
if volumes is not None and not isinstance(volumes, list):
logger.warning(f"docker_volumes config is not a list: {volumes!r}")
volumes = []
# Fail fast if Docker is not available.
_ensure_docker_available()
# Build resource limit args
resource_args = []
if cpu > 0:
resource_args.extend(["--cpus", str(cpu)])
if memory > 0:
resource_args.extend(["--memory", f"{memory}m"])
if disk > 0 and sys.platform != "darwin":
if self._storage_opt_supported():
resource_args.extend(["--storage-opt", f"size={disk}m"])
else:
logger.warning(
"Docker storage driver does not support per-container disk limits "
"(requires overlay2 on XFS with pquota). Container will run without disk quota."
)
if not network:
resource_args.append("--network=none")
# Persistent workspace via bind mounts from a configurable host directory
# (TERMINAL_SANDBOX_DIR, default ~/.hermes/sandboxes/). Non-persistent
# mode uses tmpfs (ephemeral, fast, gone on cleanup).
from tools.environments.base import get_sandbox_dir
# User-configured volume mounts (from config.yaml docker_volumes)
volume_args = []
workspace_explicitly_mounted = False
for vol in (volumes or []):
if not isinstance(vol, str):
logger.warning(f"Docker volume entry is not a string: {vol!r}")
continue
vol = vol.strip()
if not vol:
continue
if ":" in vol:
volume_args.extend(["-v", vol])
if ":/workspace" in vol:
workspace_explicitly_mounted = True
else:
logger.warning(f"Docker volume '{vol}' missing colon, skipping")
host_cwd_abs = os.path.abspath(os.path.expanduser(host_cwd)) if host_cwd else ""
bind_host_cwd = (
auto_mount_cwd
and bool(host_cwd_abs)
and os.path.isdir(host_cwd_abs)
and not workspace_explicitly_mounted
)
if auto_mount_cwd and host_cwd and not os.path.isdir(host_cwd_abs):
logger.debug(f"Skipping docker cwd mount: host_cwd is not a valid directory: {host_cwd}")
self._workspace_dir: Optional[str] = None
self._home_dir: Optional[str] = None
writable_args = []
if self._persistent:
sandbox = get_sandbox_dir() / "docker" / task_id
self._home_dir = str(sandbox / "home")
os.makedirs(self._home_dir, exist_ok=True)
writable_args.extend([
"-v", f"{self._home_dir}:/root",
])
if not bind_host_cwd and not workspace_explicitly_mounted:
self._workspace_dir = str(sandbox / "workspace")
os.makedirs(self._workspace_dir, exist_ok=True)
writable_args.extend([
"-v", f"{self._workspace_dir}:/workspace",
])
else:
if not bind_host_cwd and not workspace_explicitly_mounted:
writable_args.extend([
"--tmpfs", "/workspace:rw,exec,size=10g",
])
writable_args.extend([
"--tmpfs", "/home:rw,exec,size=1g",
"--tmpfs", "/root:rw,exec,size=1g",
])
if bind_host_cwd:
logger.info(f"Mounting configured host cwd to /workspace: {host_cwd_abs}")
volume_args = ["-v", f"{host_cwd_abs}:/workspace", *volume_args]
elif workspace_explicitly_mounted:
logger.debug("Skipping docker cwd mount: /workspace already mounted by user config")
logger.info(f"Docker volume_args: {volume_args}")
all_run_args = list(_SECURITY_ARGS) + writable_args + resource_args + volume_args
logger.info(f"Docker run_args: {all_run_args}")
# Resolve the docker executable once so it works even when
# /usr/local/bin is not in PATH (common on macOS gateway/service).
self._docker_exe = find_docker() or "docker"
# Start the container directly via `docker run -d`.
container_name = f"hermes-{uuid.uuid4().hex[:8]}"
run_cmd = [
self._docker_exe, "run", "-d",
"--name", container_name,
"-w", cwd,
*all_run_args,
image,
"sleep", "2h",
]
logger.debug(f"Starting container: {' '.join(run_cmd)}")
result = subprocess.run(
run_cmd,
capture_output=True,
text=True,
timeout=120, # image pull may take a while
check=True,
)
self._container_id = result.stdout.strip()
logger.info(f"Started container {container_name} ({self._container_id[:12]})")
@staticmethod
def _storage_opt_supported() -> bool:
"""Check if Docker's storage driver supports --storage-opt size=.
Only overlay2 on XFS with pquota supports per-container disk quotas.
Ubuntu (and most distros) default to ext4, where this flag errors out.
"""
global _storage_opt_ok
if _storage_opt_ok is not None:
return _storage_opt_ok
try:
docker = find_docker() or "docker"
result = subprocess.run(
[docker, "info", "--format", "{{.Driver}}"],
capture_output=True, text=True, timeout=10,
)
driver = result.stdout.strip().lower()
if driver != "overlay2":
_storage_opt_ok = False
return False
# overlay2 only supports storage-opt on XFS with pquota.
# Probe by attempting a dry-ish run — the fastest reliable check.
probe = subprocess.run(
[docker, "create", "--storage-opt", "size=1m", "hello-world"],
capture_output=True, text=True, timeout=15,
)
if probe.returncode == 0:
# Clean up the created container
container_id = probe.stdout.strip()
if container_id:
subprocess.run([docker, "rm", container_id],
capture_output=True, timeout=5)
_storage_opt_ok = True
else:
_storage_opt_ok = False
except Exception:
_storage_opt_ok = False
logger.debug("Docker --storage-opt support: %s", _storage_opt_ok)
return _storage_opt_ok
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
exec_command, sudo_stdin = self._prepare_command(command)
work_dir = cwd or self.cwd
effective_timeout = timeout or self.timeout
# Merge sudo password (if any) with caller-supplied stdin_data.
if sudo_stdin is not None and stdin_data is not None:
effective_stdin = sudo_stdin + stdin_data
elif sudo_stdin is not None:
effective_stdin = sudo_stdin
else:
effective_stdin = stdin_data
# docker exec -w doesn't expand ~, so prepend a cd into the command
if work_dir == "~" or work_dir.startswith("~/"):
exec_command = f"cd {work_dir} && {exec_command}"
work_dir = "/"
assert self._container_id, "Container not started"
cmd = [self._docker_exe, "exec"]
if effective_stdin is not None:
cmd.append("-i")
cmd.extend(["-w", work_dir])
hermes_env = _load_hermes_env_vars() if self._forward_env else {}
for key in self._forward_env:
value = os.getenv(key)
if value is None:
value = hermes_env.get(key)
if value is not None:
cmd.extend(["-e", f"{key}={value}"])
cmd.extend([self._container_id, "bash", "-lc", exec_command])
try:
_output_chunks = []
proc = subprocess.Popen(
cmd,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
text=True,
)
if effective_stdin:
try:
proc.stdin.write(effective_stdin)
proc.stdin.close()
except Exception:
pass
def _drain():
try:
for line in proc.stdout:
_output_chunks.append(line)
except Exception:
pass
reader = threading.Thread(target=_drain, daemon=True)
reader.start()
deadline = time.monotonic() + effective_timeout
while proc.poll() is None:
if is_interrupted():
proc.terminate()
try:
proc.wait(timeout=1)
except subprocess.TimeoutExpired:
proc.kill()
reader.join(timeout=2)
return {
"output": "".join(_output_chunks) + "\n[Command interrupted]",
"returncode": 130,
}
if time.monotonic() > deadline:
proc.kill()
reader.join(timeout=2)
return self._timeout_result(effective_timeout)
time.sleep(0.2)
reader.join(timeout=5)
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
except Exception as e:
return {"output": f"Docker execution error: {e}", "returncode": 1}
def cleanup(self):
"""Stop and remove the container. Bind-mount dirs persist if persistent=True."""
if self._container_id:
try:
# Stop in background so cleanup doesn't block
stop_cmd = (
f"(timeout 60 {self._docker_exe} stop {self._container_id} || "
f"{self._docker_exe} rm -f {self._container_id}) >/dev/null 2>&1 &"
)
subprocess.Popen(stop_cmd, shell=True)
except Exception as e:
logger.warning("Failed to stop container %s: %s", self._container_id, e)
if not self._persistent:
# Also schedule removal (stop only leaves it as stopped)
try:
subprocess.Popen(
f"sleep 3 && {self._docker_exe} rm -f {self._container_id} >/dev/null 2>&1 &",
shell=True,
)
except Exception:
pass
self._container_id = None
if not self._persistent:
for d in (self._workspace_dir, self._home_dir):
if d:
shutil.rmtree(d, ignore_errors=True)

View file

@ -0,0 +1,476 @@
"""Local execution environment with interrupt support and non-blocking I/O."""
import glob
import os
import platform
import shutil
import signal
import subprocess
import threading
import time
_IS_WINDOWS = platform.system() == "Windows"
from tools.environments.base import BaseEnvironment
from tools.environments.persistent_shell import PersistentShellMixin
from tools.interrupt import is_interrupted
# Unique marker to isolate real command output from shell init/exit noise.
# printf (no trailing newline) keeps the boundaries clean for splitting.
_OUTPUT_FENCE = "__HERMES_FENCE_a9f7b3__"
# Hermes-internal env vars that should NOT leak into terminal subprocesses.
# These are loaded from ~/.hermes/.env for Hermes' own LLM/provider calls
# but can break external CLIs (e.g. codex) that also honor them.
# See: https://github.com/NousResearch/hermes-agent/issues/1002
#
# Built dynamically from the provider registry so new providers are
# automatically covered without manual blocklist maintenance.
_HERMES_PROVIDER_ENV_FORCE_PREFIX = "_HERMES_FORCE_"
def _build_provider_env_blocklist() -> frozenset:
"""Derive the blocklist from provider, tool, and gateway config.
Automatically picks up api_key_env_vars and base_url_env_var from
every registered provider, plus tool/messaging env vars from the
optional config registry, so new Hermes-managed secrets are blocked
in subprocesses without having to maintain multiple static lists.
"""
blocked: set[str] = set()
try:
from hermes_cli.auth import PROVIDER_REGISTRY
for pconfig in PROVIDER_REGISTRY.values():
blocked.update(pconfig.api_key_env_vars)
if pconfig.base_url_env_var:
blocked.add(pconfig.base_url_env_var)
except ImportError:
pass
try:
from hermes_cli.config import OPTIONAL_ENV_VARS
for name, metadata in OPTIONAL_ENV_VARS.items():
category = metadata.get("category")
if category in {"tool", "messaging"}:
blocked.add(name)
elif category == "setting" and metadata.get("password"):
blocked.add(name)
except ImportError:
pass
# Vars not covered above but still Hermes-internal / conflict-prone.
blocked.update({
"OPENAI_BASE_URL",
"OPENAI_API_KEY",
"OPENAI_API_BASE", # legacy alias
"OPENAI_ORG_ID",
"OPENAI_ORGANIZATION",
"OPENROUTER_API_KEY",
"ANTHROPIC_BASE_URL",
"ANTHROPIC_TOKEN", # OAuth token (not in registry as env var)
"CLAUDE_CODE_OAUTH_TOKEN",
"LLM_MODEL",
# Expanded isolation for other major providers (Issue #1002)
"GOOGLE_API_KEY", # Gemini / Google AI Studio
"DEEPSEEK_API_KEY", # DeepSeek
"MISTRAL_API_KEY", # Mistral AI
"GROQ_API_KEY", # Groq
"TOGETHER_API_KEY", # Together AI
"PERPLEXITY_API_KEY", # Perplexity
"COHERE_API_KEY", # Cohere
"FIREWORKS_API_KEY", # Fireworks AI
"XAI_API_KEY", # xAI (Grok)
"HELICONE_API_KEY", # LLM Observability proxy
"PARALLEL_API_KEY",
"FIRECRAWL_API_KEY",
"FIRECRAWL_API_URL",
# Gateway/runtime config not represented in OPTIONAL_ENV_VARS.
"TELEGRAM_HOME_CHANNEL",
"TELEGRAM_HOME_CHANNEL_NAME",
"DISCORD_HOME_CHANNEL",
"DISCORD_HOME_CHANNEL_NAME",
"DISCORD_REQUIRE_MENTION",
"DISCORD_FREE_RESPONSE_CHANNELS",
"DISCORD_AUTO_THREAD",
"SLACK_HOME_CHANNEL",
"SLACK_HOME_CHANNEL_NAME",
"SLACK_ALLOWED_USERS",
"WHATSAPP_ENABLED",
"WHATSAPP_MODE",
"WHATSAPP_ALLOWED_USERS",
"SIGNAL_HTTP_URL",
"SIGNAL_ACCOUNT",
"SIGNAL_ALLOWED_USERS",
"SIGNAL_GROUP_ALLOWED_USERS",
"SIGNAL_HOME_CHANNEL",
"SIGNAL_HOME_CHANNEL_NAME",
"SIGNAL_IGNORE_STORIES",
"HASS_TOKEN",
"HASS_URL",
"EMAIL_ADDRESS",
"EMAIL_PASSWORD",
"EMAIL_IMAP_HOST",
"EMAIL_SMTP_HOST",
"EMAIL_HOME_ADDRESS",
"EMAIL_HOME_ADDRESS_NAME",
"GATEWAY_ALLOWED_USERS",
# Skills Hub / GitHub app auth paths and aliases.
"GH_TOKEN",
"GITHUB_APP_ID",
"GITHUB_APP_PRIVATE_KEY_PATH",
"GITHUB_APP_INSTALLATION_ID",
# Remote sandbox backend credentials.
"MODAL_TOKEN_ID",
"MODAL_TOKEN_SECRET",
"DAYTONA_API_KEY",
})
return frozenset(blocked)
_HERMES_PROVIDER_ENV_BLOCKLIST = _build_provider_env_blocklist()
def _sanitize_subprocess_env(base_env: dict | None, extra_env: dict | None = None) -> dict:
"""Filter Hermes-managed secrets from a subprocess environment.
`_HERMES_FORCE_<VAR>` entries in ``extra_env`` opt a blocked variable back in
intentionally for callers that truly need it. Vars registered via
:mod:`tools.env_passthrough` (skill-declared or user-configured) also
bypass the blocklist.
"""
try:
from tools.env_passthrough import is_env_passthrough as _is_passthrough
except Exception:
_is_passthrough = lambda _: False # noqa: E731
sanitized: dict[str, str] = {}
for key, value in (base_env or {}).items():
if key.startswith(_HERMES_PROVIDER_ENV_FORCE_PREFIX):
continue
if key not in _HERMES_PROVIDER_ENV_BLOCKLIST or _is_passthrough(key):
sanitized[key] = value
for key, value in (extra_env or {}).items():
if key.startswith(_HERMES_PROVIDER_ENV_FORCE_PREFIX):
real_key = key[len(_HERMES_PROVIDER_ENV_FORCE_PREFIX):]
sanitized[real_key] = value
elif key not in _HERMES_PROVIDER_ENV_BLOCKLIST or _is_passthrough(key):
sanitized[key] = value
return sanitized
def _find_bash() -> str:
"""Find bash for command execution.
The fence wrapper uses bash syntax (semicolons, $?, printf), so we
must use bash not the user's $SHELL which could be fish/zsh/etc.
On Windows: uses Git Bash (bundled with Git for Windows).
"""
if not _IS_WINDOWS:
return (
shutil.which("bash")
or ("/usr/bin/bash" if os.path.isfile("/usr/bin/bash") else None)
or ("/bin/bash" if os.path.isfile("/bin/bash") else None)
or os.environ.get("SHELL") # last resort: whatever they have
or "/bin/sh"
)
# Windows: look for Git Bash (installed with Git for Windows).
# Allow override via env var (same pattern as Claude Code).
custom = os.environ.get("HERMES_GIT_BASH_PATH")
if custom and os.path.isfile(custom):
return custom
# shutil.which finds bash.exe if Git\bin is on PATH
found = shutil.which("bash")
if found:
return found
# Check common Git for Windows install locations
for candidate in (
os.path.join(os.environ.get("ProgramFiles", r"C:\Program Files"), "Git", "bin", "bash.exe"),
os.path.join(os.environ.get("ProgramFiles(x86)", r"C:\Program Files (x86)"), "Git", "bin", "bash.exe"),
os.path.join(os.environ.get("LOCALAPPDATA", ""), "Programs", "Git", "bin", "bash.exe"),
):
if candidate and os.path.isfile(candidate):
return candidate
raise RuntimeError(
"Git Bash not found. Hermes Agent requires Git for Windows on Windows.\n"
"Install it from: https://git-scm.com/download/win\n"
"Or set HERMES_GIT_BASH_PATH to your bash.exe location."
)
# Backward compat — process_registry.py imports this name
_find_shell = _find_bash
# Noise lines emitted by interactive shells when stdin is not a terminal.
# Used as a fallback when output fence markers are missing.
_SHELL_NOISE_SUBSTRINGS = (
# bash
"bash: cannot set terminal process group",
"bash: no job control in this shell",
"no job control in this shell",
"cannot set terminal process group",
"tcsetattr: Inappropriate ioctl for device",
# zsh / oh-my-zsh / macOS terminal session
"Restored session:",
"Saving session...",
"Last login:",
"command not found:",
"Oh My Zsh",
"compinit:",
)
def _clean_shell_noise(output: str) -> str:
"""Strip shell startup/exit warnings that leak when using -i without a TTY.
Removes lines matching known noise patterns from both the beginning
and end of the output. Lines in the middle are left untouched.
"""
def _is_noise(line: str) -> bool:
return any(noise in line for noise in _SHELL_NOISE_SUBSTRINGS)
lines = output.split("\n")
# Strip leading noise
while lines and _is_noise(lines[0]):
lines.pop(0)
# Strip trailing noise (walk backwards, skip empty lines from split)
end = len(lines) - 1
while end >= 0 and (not lines[end] or _is_noise(lines[end])):
end -= 1
if end < 0:
return ""
cleaned = lines[: end + 1]
result = "\n".join(cleaned)
# Preserve trailing newline if original had one
if output.endswith("\n") and result and not result.endswith("\n"):
result += "\n"
return result
# Standard PATH entries for environments with minimal PATH (e.g. systemd services).
# Includes macOS Homebrew paths (/opt/homebrew/* for Apple Silicon).
_SANE_PATH = (
"/opt/homebrew/bin:/opt/homebrew/sbin:"
"/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
)
def _make_run_env(env: dict) -> dict:
"""Build a run environment with a sane PATH and provider-var stripping."""
try:
from tools.env_passthrough import is_env_passthrough as _is_passthrough
except Exception:
_is_passthrough = lambda _: False # noqa: E731
merged = dict(os.environ | env)
run_env = {}
for k, v in merged.items():
if k.startswith(_HERMES_PROVIDER_ENV_FORCE_PREFIX):
real_key = k[len(_HERMES_PROVIDER_ENV_FORCE_PREFIX):]
run_env[real_key] = v
elif k not in _HERMES_PROVIDER_ENV_BLOCKLIST or _is_passthrough(k):
run_env[k] = v
existing_path = run_env.get("PATH", "")
if "/usr/bin" not in existing_path.split(":"):
run_env["PATH"] = f"{existing_path}:{_SANE_PATH}" if existing_path else _SANE_PATH
return run_env
def _extract_fenced_output(raw: str) -> str:
"""Extract real command output from between fence markers.
The execute() method wraps each command with printf(FENCE) markers.
This function finds the first and last fence and returns only the
content between them, which is the actual command output free of
any shell init/exit noise.
Falls back to pattern-based _clean_shell_noise if fences are missing.
"""
first = raw.find(_OUTPUT_FENCE)
if first == -1:
return _clean_shell_noise(raw)
start = first + len(_OUTPUT_FENCE)
last = raw.rfind(_OUTPUT_FENCE)
if last <= first:
# Only start fence found (e.g. user command called `exit`)
return _clean_shell_noise(raw[start:])
return raw[start:last]
class LocalEnvironment(PersistentShellMixin, BaseEnvironment):
"""Run commands directly on the host machine.
Features:
- Popen + polling for interrupt support (user can cancel mid-command)
- Background stdout drain thread to prevent pipe buffer deadlocks
- stdin_data support for piping content (bypasses ARG_MAX limits)
- sudo -S transform via SUDO_PASSWORD env var
- Uses interactive login shell so full user env is available
- Optional persistent shell mode (cwd/env vars survive across calls)
"""
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None,
persistent: bool = False):
super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env)
self.persistent = persistent
if self.persistent:
self._init_persistent_shell()
@property
def _temp_prefix(self) -> str:
return f"/tmp/hermes-local-{self._session_id}"
def _spawn_shell_process(self) -> subprocess.Popen:
user_shell = _find_bash()
run_env = _make_run_env(self.env)
return subprocess.Popen(
[user_shell, "-l"],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL,
text=True,
env=run_env,
preexec_fn=None if _IS_WINDOWS else os.setsid,
)
def _read_temp_files(self, *paths: str) -> list[str]:
results = []
for path in paths:
if os.path.exists(path):
with open(path) as f:
results.append(f.read())
else:
results.append("")
return results
def _kill_shell_children(self):
if self._shell_pid is None:
return
try:
subprocess.run(
["pkill", "-P", str(self._shell_pid)],
capture_output=True, timeout=5,
)
except (subprocess.TimeoutExpired, FileNotFoundError):
pass
def _cleanup_temp_files(self):
for f in glob.glob(f"{self._temp_prefix}-*"):
if os.path.exists(f):
os.remove(f)
def _execute_oneshot(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
work_dir = cwd or self.cwd or os.getcwd()
effective_timeout = timeout or self.timeout
exec_command, sudo_stdin = self._prepare_command(command)
if sudo_stdin is not None and stdin_data is not None:
effective_stdin = sudo_stdin + stdin_data
elif sudo_stdin is not None:
effective_stdin = sudo_stdin
else:
effective_stdin = stdin_data
user_shell = _find_bash()
fenced_cmd = (
f"printf '{_OUTPUT_FENCE}';"
f" {exec_command};"
f" __hermes_rc=$?;"
f" printf '{_OUTPUT_FENCE}';"
f" exit $__hermes_rc"
)
run_env = _make_run_env(self.env)
proc = subprocess.Popen(
[user_shell, "-lic", fenced_cmd],
text=True,
cwd=work_dir,
env=run_env,
encoding="utf-8",
errors="replace",
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
stdin=subprocess.PIPE if effective_stdin is not None else subprocess.DEVNULL,
preexec_fn=None if _IS_WINDOWS else os.setsid,
)
if effective_stdin is not None:
def _write_stdin():
try:
proc.stdin.write(effective_stdin)
proc.stdin.close()
except (BrokenPipeError, OSError):
pass
threading.Thread(target=_write_stdin, daemon=True).start()
_output_chunks: list[str] = []
def _drain_stdout():
try:
for line in proc.stdout:
_output_chunks.append(line)
except ValueError:
pass
finally:
try:
proc.stdout.close()
except Exception:
pass
reader = threading.Thread(target=_drain_stdout, daemon=True)
reader.start()
deadline = time.monotonic() + effective_timeout
while proc.poll() is None:
if is_interrupted():
try:
if _IS_WINDOWS:
proc.terminate()
else:
pgid = os.getpgid(proc.pid)
os.killpg(pgid, signal.SIGTERM)
try:
proc.wait(timeout=1.0)
except subprocess.TimeoutExpired:
os.killpg(pgid, signal.SIGKILL)
except (ProcessLookupError, PermissionError):
proc.kill()
reader.join(timeout=2)
return {
"output": "".join(_output_chunks) + "\n[Command interrupted — user sent a new message]",
"returncode": 130,
}
if time.monotonic() > deadline:
try:
if _IS_WINDOWS:
proc.terminate()
else:
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
except (ProcessLookupError, PermissionError):
proc.kill()
reader.join(timeout=2)
return self._timeout_result(effective_timeout)
time.sleep(0.2)
reader.join(timeout=5)
output = _extract_fenced_output("".join(_output_chunks))
return {"output": output, "returncode": proc.returncode}

View file

@ -0,0 +1,259 @@
"""Modal cloud execution environment using SWE-ReX directly.
Supports persistent filesystem snapshots: when enabled, the sandbox's filesystem
is snapshotted on cleanup and restored on next creation, so installed packages,
project files, and config changes survive across sessions.
"""
import asyncio
import json
import logging
import threading
import uuid
from pathlib import Path
from typing import Any, Dict, Optional
from hermes_cli.config import get_hermes_home
from tools.environments.base import BaseEnvironment
from tools.interrupt import is_interrupted
logger = logging.getLogger(__name__)
_SNAPSHOT_STORE = get_hermes_home() / "modal_snapshots.json"
def _load_snapshots() -> Dict[str, str]:
"""Load snapshot ID mapping from disk."""
if _SNAPSHOT_STORE.exists():
try:
return json.loads(_SNAPSHOT_STORE.read_text())
except Exception:
pass
return {}
def _save_snapshots(data: Dict[str, str]) -> None:
"""Persist snapshot ID mapping to disk."""
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
class _AsyncWorker:
"""Background thread with its own event loop for async-safe swe-rex calls.
Allows sync code to submit async coroutines and block for results,
even when called from inside another running event loop (e.g. Atropos).
"""
def __init__(self):
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._thread: Optional[threading.Thread] = None
self._started = threading.Event()
def start(self):
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
self._started.wait(timeout=30)
def _run_loop(self):
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._started.set()
self._loop.run_forever()
def run_coroutine(self, coro, timeout=600):
if self._loop is None or self._loop.is_closed():
raise RuntimeError("AsyncWorker loop is not running")
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
return future.result(timeout=timeout)
def stop(self):
if self._loop and self._loop.is_running():
self._loop.call_soon_threadsafe(self._loop.stop)
if self._thread:
self._thread.join(timeout=10)
class ModalEnvironment(BaseEnvironment):
"""Modal cloud execution via SWE-ReX.
Uses swe-rex's ModalDeployment directly for sandbox management.
Adds sudo -S support, configurable resources (CPU, memory, disk),
and optional filesystem persistence via Modal's snapshot API.
"""
def __init__(
self,
image: str,
cwd: str = "/root",
timeout: int = 60,
modal_sandbox_kwargs: Optional[Dict[str, Any]] = None,
persistent_filesystem: bool = True,
task_id: str = "default",
):
super().__init__(cwd=cwd, timeout=timeout)
self._persistent = persistent_filesystem
self._task_id = task_id
self._base_image = image
self._deployment = None
self._worker = _AsyncWorker()
sandbox_kwargs = dict(modal_sandbox_kwargs or {})
# If persistent, try to restore from a previous snapshot
restored_image = None
if self._persistent:
snapshot_id = _load_snapshots().get(self._task_id)
if snapshot_id:
try:
import modal
restored_image = modal.Image.from_id(snapshot_id)
logger.info("Modal: restoring from snapshot %s", snapshot_id[:20])
except Exception as e:
logger.warning("Modal: failed to restore snapshot, using base image: %s", e)
restored_image = None
effective_image = restored_image if restored_image else image
# Pre-build a modal.Image with pip fix for Modal's legacy image builder.
# Some task images have broken pip; fix via ensurepip before Modal uses it.
import modal as _modal
if isinstance(effective_image, str):
effective_image = _modal.Image.from_registry(
effective_image,
setup_dockerfile_commands=[
"RUN rm -rf /usr/local/lib/python*/site-packages/pip* 2>/dev/null; "
"python -m ensurepip --upgrade --default-pip 2>/dev/null || true",
],
)
# Start the async worker thread and create the deployment on it
# so all gRPC channels are bound to the worker's event loop.
self._worker.start()
from swerex.deployment.modal import ModalDeployment
async def _create_and_start():
deployment = ModalDeployment(
image=effective_image,
startup_timeout=180.0,
runtime_timeout=3600.0,
deployment_timeout=3600.0,
install_pipx=True,
modal_sandbox_kwargs=sandbox_kwargs,
)
await deployment.start()
return deployment
self._deployment = self._worker.run_coroutine(_create_and_start())
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
if stdin_data is not None:
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
while marker in stdin_data:
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
command = f"{command} << '{marker}'\n{stdin_data}\n{marker}"
exec_command, sudo_stdin = self._prepare_command(command)
# Modal sandboxes execute commands via the Modal SDK and cannot pipe
# subprocess stdin directly the way a local Popen can. When a sudo
# password is present, use a shell-level pipe from printf so that the
# password feeds sudo -S without appearing as an echo argument embedded
# in the shell string.
if sudo_stdin is not None:
import shlex
exec_command = (
f"printf '%s\\n' {shlex.quote(sudo_stdin.rstrip())} | {exec_command}"
)
from swerex.runtime.abstract import Command as RexCommand
effective_cwd = cwd or self.cwd
effective_timeout = timeout or self.timeout
# Run in a background thread so we can poll for interrupts
result_holder = {"value": None, "error": None}
def _run():
try:
async def _do_execute():
return await self._deployment.runtime.execute(
RexCommand(
command=exec_command,
shell=True,
check=False,
cwd=effective_cwd,
timeout=effective_timeout,
merge_output_streams=True,
)
)
output = self._worker.run_coroutine(_do_execute())
result_holder["value"] = {
"output": output.stdout,
"returncode": output.exit_code,
}
except Exception as e:
result_holder["error"] = e
t = threading.Thread(target=_run, daemon=True)
t.start()
while t.is_alive():
t.join(timeout=0.2)
if is_interrupted():
try:
self._worker.run_coroutine(
asyncio.wait_for(self._deployment.stop(), timeout=10),
timeout=15,
)
except Exception:
pass
return {
"output": "[Command interrupted - Modal sandbox terminated]",
"returncode": 130,
}
if result_holder["error"]:
return {"output": f"Modal execution error: {result_holder['error']}", "returncode": 1}
return result_holder["value"]
def cleanup(self):
"""Snapshot the filesystem (if persistent) then stop the sandbox."""
if self._deployment is None:
return
if self._persistent:
try:
sandbox = getattr(self._deployment, '_sandbox', None)
if sandbox:
async def _snapshot():
img = await sandbox.snapshot_filesystem.aio()
return img.object_id
try:
snapshot_id = self._worker.run_coroutine(_snapshot(), timeout=60)
except Exception:
snapshot_id = None
if snapshot_id:
snapshots = _load_snapshots()
snapshots[self._task_id] = snapshot_id
_save_snapshots(snapshots)
logger.info("Modal: saved filesystem snapshot %s for task %s",
snapshot_id[:20], self._task_id)
except Exception as e:
logger.warning("Modal: filesystem snapshot failed: %s", e)
try:
self._worker.run_coroutine(
asyncio.wait_for(self._deployment.stop(), timeout=10),
timeout=15,
)
except Exception:
pass
finally:
self._worker.stop()
self._deployment = None

View file

@ -0,0 +1,272 @@
"""Persistent shell mixin: file-based IPC protocol for long-lived bash shells."""
import logging
import shlex
import subprocess
import threading
import time
import uuid
from abc import abstractmethod
from tools.interrupt import is_interrupted
logger = logging.getLogger(__name__)
class PersistentShellMixin:
"""Mixin that adds persistent shell capability to any BaseEnvironment.
Subclasses must implement ``_spawn_shell_process()``, ``_read_temp_files()``,
``_kill_shell_children()``, ``_execute_oneshot()``, and ``_cleanup_temp_files()``.
"""
persistent: bool
@abstractmethod
def _spawn_shell_process(self) -> subprocess.Popen: ...
@abstractmethod
def _read_temp_files(self, *paths: str) -> list[str]: ...
@abstractmethod
def _kill_shell_children(self): ...
@abstractmethod
def _execute_oneshot(self, command: str, cwd: str, *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict: ...
@abstractmethod
def _cleanup_temp_files(self): ...
_session_id: str = ""
_poll_interval: float = 0.01
@property
def _temp_prefix(self) -> str:
return f"/tmp/hermes-persistent-{self._session_id}"
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
def _init_persistent_shell(self):
self._shell_lock = threading.Lock()
self._shell_proc: subprocess.Popen | None = None
self._shell_alive: bool = False
self._shell_pid: int | None = None
self._session_id = uuid.uuid4().hex[:12]
p = self._temp_prefix
self._pshell_stdout = f"{p}-stdout"
self._pshell_stderr = f"{p}-stderr"
self._pshell_status = f"{p}-status"
self._pshell_cwd = f"{p}-cwd"
self._pshell_pid_file = f"{p}-pid"
self._shell_proc = self._spawn_shell_process()
self._shell_alive = True
self._drain_thread = threading.Thread(
target=self._drain_shell_output, daemon=True,
)
self._drain_thread.start()
init_script = (
f"export TERM=${{TERM:-dumb}}\n"
f"touch {self._pshell_stdout} {self._pshell_stderr} "
f"{self._pshell_status} {self._pshell_cwd} {self._pshell_pid_file}\n"
f"echo $$ > {self._pshell_pid_file}\n"
f"pwd > {self._pshell_cwd}\n"
)
self._send_to_shell(init_script)
deadline = time.monotonic() + 3.0
while time.monotonic() < deadline:
pid_str = self._read_temp_files(self._pshell_pid_file)[0].strip()
if pid_str.isdigit():
self._shell_pid = int(pid_str)
break
time.sleep(0.05)
else:
logger.warning("Could not read persistent shell PID")
self._shell_pid = None
if self._shell_pid:
logger.info(
"Persistent shell started (session=%s, pid=%d)",
self._session_id, self._shell_pid,
)
reported_cwd = self._read_temp_files(self._pshell_cwd)[0].strip()
if reported_cwd:
self.cwd = reported_cwd
def _cleanup_persistent_shell(self):
if self._shell_proc is None:
return
if self._session_id:
self._cleanup_temp_files()
try:
self._shell_proc.stdin.close()
except Exception:
pass
try:
self._shell_proc.terminate()
self._shell_proc.wait(timeout=3)
except subprocess.TimeoutExpired:
self._shell_proc.kill()
self._shell_alive = False
self._shell_proc = None
if hasattr(self, "_drain_thread") and self._drain_thread.is_alive():
self._drain_thread.join(timeout=1.0)
# ------------------------------------------------------------------
# execute() / cleanup() — shared dispatcher, subclasses inherit
# ------------------------------------------------------------------
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
if self.persistent:
return self._execute_persistent(
command, cwd, timeout=timeout, stdin_data=stdin_data,
)
return self._execute_oneshot(
command, cwd, timeout=timeout, stdin_data=stdin_data,
)
def cleanup(self):
if self.persistent:
self._cleanup_persistent_shell()
# ------------------------------------------------------------------
# Shell I/O
# ------------------------------------------------------------------
def _drain_shell_output(self):
try:
for _ in self._shell_proc.stdout:
pass
except Exception:
pass
self._shell_alive = False
def _send_to_shell(self, text: str):
if not self._shell_alive or self._shell_proc is None:
return
try:
self._shell_proc.stdin.write(text)
self._shell_proc.stdin.flush()
except (BrokenPipeError, OSError):
self._shell_alive = False
def _read_persistent_output(self) -> tuple[str, int, str]:
stdout, stderr, status_raw, cwd = self._read_temp_files(
self._pshell_stdout, self._pshell_stderr,
self._pshell_status, self._pshell_cwd,
)
output = self._merge_output(stdout, stderr)
status = status_raw.strip()
if ":" in status:
status = status.split(":", 1)[1]
try:
exit_code = int(status.strip())
except ValueError:
exit_code = 1
return output, exit_code, cwd.strip()
# ------------------------------------------------------------------
# Execution
# ------------------------------------------------------------------
def _execute_persistent(self, command: str, cwd: str, *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
if not self._shell_alive:
logger.info("Persistent shell died, restarting...")
self._init_persistent_shell()
exec_command, sudo_stdin = self._prepare_command(command)
effective_timeout = timeout or self.timeout
if stdin_data or sudo_stdin:
return self._execute_oneshot(
command, cwd, timeout=timeout, stdin_data=stdin_data,
)
with self._shell_lock:
return self._execute_persistent_locked(
exec_command, cwd, effective_timeout,
)
def _execute_persistent_locked(self, command: str, cwd: str,
timeout: int) -> dict:
work_dir = cwd or self.cwd
cmd_id = uuid.uuid4().hex[:8]
truncate = (
f": > {self._pshell_stdout}\n"
f": > {self._pshell_stderr}\n"
f": > {self._pshell_status}\n"
)
self._send_to_shell(truncate)
escaped = command.replace("'", "'\\''")
ipc_script = (
f"cd {shlex.quote(work_dir)}\n"
f"eval '{escaped}' < /dev/null > {self._pshell_stdout} 2> {self._pshell_stderr}\n"
f"__EC=$?\n"
f"pwd > {self._pshell_cwd}\n"
f"echo {cmd_id}:$__EC > {self._pshell_status}\n"
)
self._send_to_shell(ipc_script)
deadline = time.monotonic() + timeout
poll_interval = self._poll_interval
while True:
if is_interrupted():
self._kill_shell_children()
output, _, _ = self._read_persistent_output()
return {
"output": output + "\n[Command interrupted]",
"returncode": 130,
}
if time.monotonic() > deadline:
self._kill_shell_children()
output, _, _ = self._read_persistent_output()
if output:
return {
"output": output + f"\n[Command timed out after {timeout}s]",
"returncode": 124,
}
return self._timeout_result(timeout)
if not self._shell_alive:
return {
"output": "Persistent shell died during execution",
"returncode": 1,
}
status_content = self._read_temp_files(self._pshell_status)[0].strip()
if status_content.startswith(cmd_id + ":"):
break
time.sleep(poll_interval)
output, exit_code, new_cwd = self._read_persistent_output()
if new_cwd:
self.cwd = new_cwd
return {"output": output, "returncode": exit_code}
@staticmethod
def _merge_output(stdout: str, stderr: str) -> str:
parts = []
if stdout.strip():
parts.append(stdout.rstrip("\n"))
if stderr.strip():
parts.append(stderr.rstrip("\n"))
return "\n".join(parts)

View file

@ -0,0 +1,369 @@
"""Singularity/Apptainer persistent container environment.
Security-hardened with --containall, --no-home, capability dropping.
Supports configurable resource limits and optional filesystem persistence
via writable overlay directories that survive across sessions.
"""
import json
import logging
import os
import shutil
import subprocess
import tempfile
import threading
import uuid
from pathlib import Path
from typing import Any, Dict, Optional
from hermes_cli.config import get_hermes_home
from tools.environments.base import BaseEnvironment
from tools.interrupt import is_interrupted
logger = logging.getLogger(__name__)
_SNAPSHOT_STORE = get_hermes_home() / "singularity_snapshots.json"
def _find_singularity_executable() -> str:
"""Locate the apptainer or singularity CLI binary.
Returns the executable name (``"apptainer"`` or ``"singularity"``).
Raises ``RuntimeError`` with install instructions if neither is found.
"""
if shutil.which("apptainer"):
return "apptainer"
if shutil.which("singularity"):
return "singularity"
raise RuntimeError(
"Neither 'apptainer' nor 'singularity' was found in PATH. "
"Install Apptainer (https://apptainer.org/docs/admin/main/installation.html) "
"or Singularity and ensure the CLI is available."
)
def _ensure_singularity_available() -> str:
"""Preflight check: resolve the executable and verify it responds.
Returns the executable name on success.
Raises ``RuntimeError`` with an actionable message on failure.
"""
exe = _find_singularity_executable()
try:
result = subprocess.run(
[exe, "version"],
capture_output=True,
text=True,
timeout=10,
)
except FileNotFoundError:
raise RuntimeError(
f"Singularity backend selected but the resolved executable '{exe}' "
"could not be executed. Check your installation."
)
except subprocess.TimeoutExpired:
raise RuntimeError(
f"'{exe} version' timed out. The runtime may be misconfigured."
)
if result.returncode != 0:
stderr = result.stderr.strip()[:200]
raise RuntimeError(
f"'{exe} version' failed (exit code {result.returncode}): {stderr}"
)
return exe
def _load_snapshots() -> Dict[str, str]:
if _SNAPSHOT_STORE.exists():
try:
return json.loads(_SNAPSHOT_STORE.read_text())
except Exception:
pass
return {}
def _save_snapshots(data: Dict[str, str]) -> None:
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
# -------------------------------------------------------------------------
# Singularity helpers (scratch dir, SIF cache, SIF building)
# -------------------------------------------------------------------------
def _get_scratch_dir() -> Path:
"""Get the best directory for Singularity sandboxes.
Resolution order:
1. TERMINAL_SCRATCH_DIR (explicit override)
2. TERMINAL_SANDBOX_DIR / singularity (shared sandbox root)
3. /scratch (common on HPC clusters)
4. ~/.hermes/sandboxes/singularity (fallback)
"""
custom_scratch = os.getenv("TERMINAL_SCRATCH_DIR")
if custom_scratch:
scratch_path = Path(custom_scratch)
scratch_path.mkdir(parents=True, exist_ok=True)
return scratch_path
from tools.environments.base import get_sandbox_dir
sandbox = get_sandbox_dir() / "singularity"
scratch = Path("/scratch")
if scratch.exists() and os.access(scratch, os.W_OK):
user_scratch = scratch / os.getenv("USER", "hermes") / "hermes-agent"
user_scratch.mkdir(parents=True, exist_ok=True)
logger.info("Using /scratch for sandboxes: %s", user_scratch)
return user_scratch
sandbox.mkdir(parents=True, exist_ok=True)
return sandbox
def _get_apptainer_cache_dir() -> Path:
"""Get the Apptainer cache directory for SIF images."""
cache_dir = os.getenv("APPTAINER_CACHEDIR")
if cache_dir:
cache_path = Path(cache_dir)
cache_path.mkdir(parents=True, exist_ok=True)
return cache_path
scratch = _get_scratch_dir()
cache_path = scratch / ".apptainer"
cache_path.mkdir(parents=True, exist_ok=True)
return cache_path
_sif_build_lock = threading.Lock()
def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
"""Get or build a SIF image from a docker:// URL.
Returns the path unchanged if it's already a .sif file.
For docker:// URLs, checks the cache and builds if needed.
"""
if image.endswith('.sif') and Path(image).exists():
return image
if not image.startswith('docker://'):
return image
image_name = image.replace('docker://', '').replace('/', '-').replace(':', '-')
cache_dir = _get_apptainer_cache_dir()
sif_path = cache_dir / f"{image_name}.sif"
if sif_path.exists():
return str(sif_path)
with _sif_build_lock:
if sif_path.exists():
return str(sif_path)
logger.info("Building SIF image (one-time setup)...")
logger.info(" Source: %s", image)
logger.info(" Target: %s", sif_path)
tmp_dir = cache_dir / "tmp"
tmp_dir.mkdir(parents=True, exist_ok=True)
env = os.environ.copy()
env["APPTAINER_TMPDIR"] = str(tmp_dir)
env["APPTAINER_CACHEDIR"] = str(cache_dir)
try:
result = subprocess.run(
[executable, "build", str(sif_path), image],
capture_output=True, text=True, timeout=600, env=env,
)
if result.returncode != 0:
logger.warning("SIF build failed, falling back to docker:// URL")
logger.warning(" Error: %s", result.stderr[:500])
return image
logger.info("SIF image built successfully")
return str(sif_path)
except subprocess.TimeoutExpired:
logger.warning("SIF build timed out, falling back to docker:// URL")
if sif_path.exists():
sif_path.unlink()
return image
except Exception as e:
logger.warning("SIF build error: %s, falling back to docker:// URL", e)
return image
# -------------------------------------------------------------------------
# SingularityEnvironment
# -------------------------------------------------------------------------
class SingularityEnvironment(BaseEnvironment):
"""Hardened Singularity/Apptainer container with resource limits and persistence.
Security: --containall (isolated PID/IPC/mount namespaces, no host home mount),
--no-home, writable-tmpfs for scratch space. The container cannot see or modify
the host filesystem outside of explicitly bound paths.
Persistence: when enabled, the writable overlay directory is preserved across
sessions so installed packages and files survive cleanup/restore.
"""
def __init__(
self,
image: str,
cwd: str = "~",
timeout: int = 60,
cpu: float = 0,
memory: int = 0,
disk: int = 0,
persistent_filesystem: bool = False,
task_id: str = "default",
):
super().__init__(cwd=cwd, timeout=timeout)
self.executable = _ensure_singularity_available()
self.image = _get_or_build_sif(image, self.executable)
self.instance_id = f"hermes_{uuid.uuid4().hex[:12]}"
self._instance_started = False
self._persistent = persistent_filesystem
self._task_id = task_id
self._overlay_dir: Optional[Path] = None
# Resource limits
self._cpu = cpu
self._memory = memory
# Persistent overlay directory
if self._persistent:
overlay_base = _get_scratch_dir() / "hermes-overlays"
overlay_base.mkdir(parents=True, exist_ok=True)
self._overlay_dir = overlay_base / f"overlay-{task_id}"
self._overlay_dir.mkdir(parents=True, exist_ok=True)
self._start_instance()
def _start_instance(self):
cmd = [self.executable, "instance", "start"]
# Security: full isolation from host
cmd.extend(["--containall", "--no-home"])
# Writable layer
if self._persistent and self._overlay_dir:
# Persistent writable overlay -- survives across restarts
cmd.extend(["--overlay", str(self._overlay_dir)])
else:
cmd.append("--writable-tmpfs")
# Resource limits (cgroup-based, may require root or appropriate config)
if self._memory > 0:
cmd.extend(["--memory", f"{self._memory}M"])
if self._cpu > 0:
cmd.extend(["--cpus", str(self._cpu)])
cmd.extend([str(self.image), self.instance_id])
try:
result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
if result.returncode != 0:
raise RuntimeError(f"Failed to start instance: {result.stderr}")
self._instance_started = True
logger.info("Singularity instance %s started (persistent=%s)",
self.instance_id, self._persistent)
except subprocess.TimeoutExpired:
raise RuntimeError("Instance start timed out")
def execute(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
if not self._instance_started:
return {"output": "Instance not started", "returncode": -1}
effective_timeout = timeout or self.timeout
work_dir = cwd or self.cwd
exec_command, sudo_stdin = self._prepare_command(command)
# Merge sudo password (if any) with caller-supplied stdin_data.
if sudo_stdin is not None and stdin_data is not None:
effective_stdin = sudo_stdin + stdin_data
elif sudo_stdin is not None:
effective_stdin = sudo_stdin
else:
effective_stdin = stdin_data
# apptainer exec --pwd doesn't expand ~, so prepend a cd into the command
if work_dir == "~" or work_dir.startswith("~/"):
exec_command = f"cd {work_dir} && {exec_command}"
work_dir = "/tmp"
cmd = [self.executable, "exec", "--pwd", work_dir,
f"instance://{self.instance_id}",
"bash", "-c", exec_command]
try:
import time as _time
_output_chunks = []
proc = subprocess.Popen(
cmd,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
text=True,
)
if effective_stdin:
try:
proc.stdin.write(effective_stdin)
proc.stdin.close()
except Exception:
pass
def _drain():
try:
for line in proc.stdout:
_output_chunks.append(line)
except Exception:
pass
reader = threading.Thread(target=_drain, daemon=True)
reader.start()
deadline = _time.monotonic() + effective_timeout
while proc.poll() is None:
if is_interrupted():
proc.terminate()
try:
proc.wait(timeout=1)
except subprocess.TimeoutExpired:
proc.kill()
reader.join(timeout=2)
return {
"output": "".join(_output_chunks) + "\n[Command interrupted]",
"returncode": 130,
}
if _time.monotonic() > deadline:
proc.kill()
reader.join(timeout=2)
return self._timeout_result(effective_timeout)
_time.sleep(0.2)
reader.join(timeout=5)
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
except Exception as e:
return {"output": f"Singularity execution error: {e}", "returncode": 1}
def cleanup(self):
"""Stop the instance. If persistent, the overlay dir survives for next creation."""
if self._instance_started:
try:
subprocess.run(
[self.executable, "instance", "stop", self.instance_id],
capture_output=True, text=True, timeout=30,
)
logger.info("Singularity instance %s stopped", self.instance_id)
except Exception as e:
logger.warning("Failed to stop Singularity instance %s: %s", self.instance_id, e)
self._instance_started = False
# Record overlay path for persistence restoration
if self._persistent and self._overlay_dir:
snapshots = _load_snapshots()
snapshots[self._task_id] = str(self._overlay_dir)
_save_snapshots(snapshots)

View file

@ -0,0 +1,232 @@
"""SSH remote execution environment with ControlMaster connection persistence."""
import logging
import shutil
import subprocess
import tempfile
import threading
import time
from pathlib import Path
from tools.environments.base import BaseEnvironment
from tools.environments.persistent_shell import PersistentShellMixin
from tools.interrupt import is_interrupted
logger = logging.getLogger(__name__)
def _ensure_ssh_available() -> None:
"""Fail fast with a clear error when the SSH client is unavailable."""
if not shutil.which("ssh"):
raise RuntimeError(
"SSH is not installed or not in PATH. Install OpenSSH client: apt install openssh-client"
)
class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
"""Run commands on a remote machine over SSH.
Uses SSH ControlMaster for connection persistence so subsequent
commands are fast. Security benefit: the agent cannot modify its
own code since execution happens on a separate machine.
Foreground commands are interruptible: the local ssh process is killed
and a remote kill is attempted over the ControlMaster socket.
When ``persistent=True``, a single long-lived bash shell is kept alive
over SSH and state (cwd, env vars, shell variables) persists across
``execute()`` calls. Output capture uses file-based IPC on the remote
host (stdout/stderr/exit-code written to temp files, polled via fast
ControlMaster one-shot reads).
"""
def __init__(self, host: str, user: str, cwd: str = "~",
timeout: int = 60, port: int = 22, key_path: str = "",
persistent: bool = False):
super().__init__(cwd=cwd, timeout=timeout)
self.host = host
self.user = user
self.port = port
self.key_path = key_path
self.persistent = persistent
self.control_dir = Path(tempfile.gettempdir()) / "hermes-ssh"
self.control_dir.mkdir(parents=True, exist_ok=True)
self.control_socket = self.control_dir / f"{user}@{host}:{port}.sock"
_ensure_ssh_available()
self._establish_connection()
if self.persistent:
self._init_persistent_shell()
def _build_ssh_command(self, extra_args: list | None = None) -> list:
cmd = ["ssh"]
cmd.extend(["-o", f"ControlPath={self.control_socket}"])
cmd.extend(["-o", "ControlMaster=auto"])
cmd.extend(["-o", "ControlPersist=300"])
cmd.extend(["-o", "BatchMode=yes"])
cmd.extend(["-o", "StrictHostKeyChecking=accept-new"])
cmd.extend(["-o", "ConnectTimeout=10"])
if self.port != 22:
cmd.extend(["-p", str(self.port)])
if self.key_path:
cmd.extend(["-i", self.key_path])
if extra_args:
cmd.extend(extra_args)
cmd.append(f"{self.user}@{self.host}")
return cmd
def _establish_connection(self):
cmd = self._build_ssh_command()
cmd.append("echo 'SSH connection established'")
try:
result = subprocess.run(cmd, capture_output=True, text=True, timeout=15)
if result.returncode != 0:
error_msg = result.stderr.strip() or result.stdout.strip()
raise RuntimeError(f"SSH connection failed: {error_msg}")
except subprocess.TimeoutExpired:
raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out")
_poll_interval: float = 0.15
@property
def _temp_prefix(self) -> str:
return f"/tmp/hermes-ssh-{self._session_id}"
def _spawn_shell_process(self) -> subprocess.Popen:
cmd = self._build_ssh_command()
cmd.append("bash -l")
return subprocess.Popen(
cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL,
text=True,
)
def _read_temp_files(self, *paths: str) -> list[str]:
if len(paths) == 1:
cmd = self._build_ssh_command()
cmd.append(f"cat {paths[0]} 2>/dev/null")
try:
result = subprocess.run(
cmd, capture_output=True, text=True, timeout=10,
)
return [result.stdout]
except (subprocess.TimeoutExpired, OSError):
return [""]
delim = f"__HERMES_SEP_{self._session_id}__"
script = "; ".join(
f"cat {p} 2>/dev/null; echo '{delim}'" for p in paths
)
cmd = self._build_ssh_command()
cmd.append(script)
try:
result = subprocess.run(
cmd, capture_output=True, text=True, timeout=10,
)
parts = result.stdout.split(delim + "\n")
return [parts[i] if i < len(parts) else "" for i in range(len(paths))]
except (subprocess.TimeoutExpired, OSError):
return [""] * len(paths)
def _kill_shell_children(self):
if self._shell_pid is None:
return
cmd = self._build_ssh_command()
cmd.append(f"pkill -P {self._shell_pid} 2>/dev/null; true")
try:
subprocess.run(cmd, capture_output=True, timeout=5)
except (subprocess.TimeoutExpired, OSError):
pass
def _cleanup_temp_files(self):
cmd = self._build_ssh_command()
cmd.append(f"rm -f {self._temp_prefix}-*")
try:
subprocess.run(cmd, capture_output=True, timeout=5)
except (subprocess.TimeoutExpired, OSError):
pass
def _execute_oneshot(self, command: str, cwd: str = "", *,
timeout: int | None = None,
stdin_data: str | None = None) -> dict:
work_dir = cwd or self.cwd
exec_command, sudo_stdin = self._prepare_command(command)
wrapped = f'cd {work_dir} && {exec_command}'
effective_timeout = timeout or self.timeout
if sudo_stdin is not None and stdin_data is not None:
effective_stdin = sudo_stdin + stdin_data
elif sudo_stdin is not None:
effective_stdin = sudo_stdin
else:
effective_stdin = stdin_data
cmd = self._build_ssh_command()
cmd.append(wrapped)
kwargs = self._build_run_kwargs(timeout, effective_stdin)
kwargs.pop("timeout", None)
_output_chunks = []
proc = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
text=True,
)
if effective_stdin:
try:
proc.stdin.write(effective_stdin)
proc.stdin.close()
except (BrokenPipeError, OSError):
pass
def _drain():
try:
for line in proc.stdout:
_output_chunks.append(line)
except Exception:
pass
reader = threading.Thread(target=_drain, daemon=True)
reader.start()
deadline = time.monotonic() + effective_timeout
while proc.poll() is None:
if is_interrupted():
proc.terminate()
try:
proc.wait(timeout=1)
except subprocess.TimeoutExpired:
proc.kill()
reader.join(timeout=2)
return {
"output": "".join(_output_chunks) + "\n[Command interrupted]",
"returncode": 130,
}
if time.monotonic() > deadline:
proc.kill()
reader.join(timeout=2)
return self._timeout_result(effective_timeout)
time.sleep(0.2)
reader.join(timeout=5)
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
def cleanup(self):
super().cleanup()
if self.control_socket.exists():
try:
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
"-O", "exit", f"{self.user}@{self.host}"]
subprocess.run(cmd, capture_output=True, timeout=5)
except (OSError, subprocess.SubprocessError):
pass
try:
self.control_socket.unlink()
except OSError:
pass

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,522 @@
#!/usr/bin/env python3
"""File Tools Module - LLM agent file manipulation tools."""
import errno
import json
import logging
import os
import threading
from typing import Optional
from tools.file_operations import ShellFileOperations
from agent.redact import redact_sensitive_text
logger = logging.getLogger(__name__)
_EXPECTED_WRITE_ERRNOS = {errno.EACCES, errno.EPERM, errno.EROFS}
def _is_expected_write_exception(exc: Exception) -> bool:
"""Return True for expected write denials that should not hit error logs."""
if isinstance(exc, PermissionError):
return True
if isinstance(exc, OSError) and exc.errno in _EXPECTED_WRITE_ERRNOS:
return True
return False
_file_ops_lock = threading.Lock()
_file_ops_cache: dict = {}
# Track files read per task to detect re-read loops after context compression.
# Per task_id we store:
# "last_key": the key of the most recent read/search call (or None)
# "consecutive": how many times that exact call has been repeated in a row
# "read_history": set of (path, offset, limit) tuples for get_read_files_summary
_read_tracker_lock = threading.Lock()
_read_tracker: dict = {}
def _get_file_ops(task_id: str = "default") -> ShellFileOperations:
"""Get or create ShellFileOperations for a terminal environment.
Respects the TERMINAL_ENV setting -- if the task_id doesn't have an
environment yet, creates one using the configured backend (local, docker,
modal, etc.) rather than always defaulting to local.
Thread-safe: uses the same per-task creation locks as terminal_tool to
prevent duplicate sandbox creation from concurrent tool calls.
"""
from tools.terminal_tool import (
_active_environments, _env_lock, _create_environment,
_get_env_config, _last_activity, _start_cleanup_thread,
_check_disk_usage_warning,
_creation_locks, _creation_locks_lock,
)
import time
# Fast path: check cache -- but also verify the underlying environment
# is still alive (it may have been killed by the cleanup thread).
with _file_ops_lock:
cached = _file_ops_cache.get(task_id)
if cached is not None:
with _env_lock:
if task_id in _active_environments:
_last_activity[task_id] = time.time()
return cached
else:
# Environment was cleaned up -- invalidate stale cache entry
with _file_ops_lock:
_file_ops_cache.pop(task_id, None)
# Need to ensure the environment exists before building file_ops.
# Acquire per-task lock so only one thread creates the sandbox.
with _creation_locks_lock:
if task_id not in _creation_locks:
_creation_locks[task_id] = threading.Lock()
task_lock = _creation_locks[task_id]
with task_lock:
# Double-check: another thread may have created it while we waited
with _env_lock:
if task_id in _active_environments:
_last_activity[task_id] = time.time()
terminal_env = _active_environments[task_id]
else:
terminal_env = None
if terminal_env is None:
from tools.terminal_tool import _task_env_overrides
config = _get_env_config()
env_type = config["env_type"]
overrides = _task_env_overrides.get(task_id, {})
if env_type == "docker":
image = overrides.get("docker_image") or config["docker_image"]
elif env_type == "singularity":
image = overrides.get("singularity_image") or config["singularity_image"]
elif env_type == "modal":
image = overrides.get("modal_image") or config["modal_image"]
elif env_type == "daytona":
image = overrides.get("daytona_image") or config["daytona_image"]
else:
image = ""
cwd = overrides.get("cwd") or config["cwd"]
logger.info("Creating new %s environment for task %s...", env_type, task_id[:8])
container_config = None
if env_type in ("docker", "singularity", "modal", "daytona"):
container_config = {
"container_cpu": config.get("container_cpu", 1),
"container_memory": config.get("container_memory", 5120),
"container_disk": config.get("container_disk", 51200),
"container_persistent": config.get("container_persistent", True),
"docker_volumes": config.get("docker_volumes", []),
}
ssh_config = None
if env_type == "ssh":
ssh_config = {
"host": config.get("ssh_host", ""),
"user": config.get("ssh_user", ""),
"port": config.get("ssh_port", 22),
"key": config.get("ssh_key", ""),
"persistent": config.get("ssh_persistent", False),
}
local_config = None
if env_type == "local":
local_config = {
"persistent": config.get("local_persistent", False),
}
terminal_env = _create_environment(
env_type=env_type,
image=image,
cwd=cwd,
timeout=config["timeout"],
ssh_config=ssh_config,
container_config=container_config,
local_config=local_config,
task_id=task_id,
host_cwd=config.get("host_cwd"),
)
with _env_lock:
_active_environments[task_id] = terminal_env
_last_activity[task_id] = time.time()
_start_cleanup_thread()
logger.info("%s environment ready for task %s", env_type, task_id[:8])
# Build file_ops from the (guaranteed live) environment and cache it
file_ops = ShellFileOperations(terminal_env)
with _file_ops_lock:
_file_ops_cache[task_id] = file_ops
return file_ops
def clear_file_ops_cache(task_id: str = None):
"""Clear the file operations cache."""
with _file_ops_lock:
if task_id:
_file_ops_cache.pop(task_id, None)
else:
_file_ops_cache.clear()
def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str = "default") -> str:
"""Read a file with pagination and line numbers."""
try:
# Security: block direct reads of internal Hermes cache/index files
# to prevent prompt injection via catalog or hub metadata files.
import pathlib as _pathlib
_resolved = _pathlib.Path(path).expanduser().resolve()
_hermes_home = _pathlib.Path("~/.hermes").expanduser().resolve()
_blocked_dirs = [
_hermes_home / "skills" / ".hub" / "index-cache",
_hermes_home / "skills" / ".hub",
]
for _blocked in _blocked_dirs:
try:
_resolved.relative_to(_blocked)
return json.dumps({
"error": (
f"Access denied: {path} is an internal Hermes cache file "
"and cannot be read directly to prevent prompt injection. "
"Use the skills_list or skill_view tools instead."
)
})
except ValueError:
pass
file_ops = _get_file_ops(task_id)
result = file_ops.read_file(path, offset, limit)
if result.content:
result.content = redact_sensitive_text(result.content)
result_dict = result.to_dict()
# Track reads to detect *consecutive* re-read loops.
# The counter resets whenever any other tool is called in between,
# so only truly back-to-back identical reads trigger warnings/blocks.
read_key = ("read", path, offset, limit)
with _read_tracker_lock:
task_data = _read_tracker.setdefault(task_id, {
"last_key": None, "consecutive": 0, "read_history": set(),
})
task_data["read_history"].add((path, offset, limit))
if task_data["last_key"] == read_key:
task_data["consecutive"] += 1
else:
task_data["last_key"] = read_key
task_data["consecutive"] = 1
count = task_data["consecutive"]
if count >= 4:
# Hard block: stop returning content to break the loop
return json.dumps({
"error": (
f"BLOCKED: You have read this exact file region {count} times in a row. "
"The content has NOT changed. You already have this information. "
"STOP re-reading and proceed with your task."
),
"path": path,
"already_read": count,
}, ensure_ascii=False)
elif count >= 3:
result_dict["_warning"] = (
f"You have read this exact file region {count} times consecutively. "
"The content has not changed since your last read. Use the information you already have. "
"If you are stuck in a loop, stop reading and proceed with writing or responding."
)
return json.dumps(result_dict, ensure_ascii=False)
except Exception as e:
return json.dumps({"error": str(e)}, ensure_ascii=False)
def get_read_files_summary(task_id: str = "default") -> list:
"""Return a list of files read in this session for the given task.
Used by context compression to preserve file-read history across
compression boundaries.
"""
with _read_tracker_lock:
task_data = _read_tracker.get(task_id, {})
read_history = task_data.get("read_history", set())
seen_paths: dict = {}
for (path, offset, limit) in read_history:
if path not in seen_paths:
seen_paths[path] = []
seen_paths[path].append(f"lines {offset}-{offset + limit - 1}")
return [
{"path": p, "regions": regions}
for p, regions in sorted(seen_paths.items())
]
def clear_read_tracker(task_id: str = None):
"""Clear the read tracker.
Call with a task_id to clear just that task, or without to clear all.
Should be called when a session is destroyed to prevent memory leaks
in long-running gateway processes.
"""
with _read_tracker_lock:
if task_id:
_read_tracker.pop(task_id, None)
else:
_read_tracker.clear()
def notify_other_tool_call(task_id: str = "default"):
"""Reset consecutive read/search counter for a task.
Called by the tool dispatcher (model_tools.py) whenever a tool OTHER
than read_file / search_files is executed. This ensures we only warn
or block on *truly consecutive* repeated reads if the agent does
anything else in between (write, patch, terminal, etc.) the counter
resets and the next read is treated as fresh.
"""
with _read_tracker_lock:
task_data = _read_tracker.get(task_id)
if task_data:
task_data["last_key"] = None
task_data["consecutive"] = 0
def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
"""Write content to a file."""
try:
file_ops = _get_file_ops(task_id)
result = file_ops.write_file(path, content)
return json.dumps(result.to_dict(), ensure_ascii=False)
except Exception as e:
if _is_expected_write_exception(e):
logger.debug("write_file expected denial: %s: %s", type(e).__name__, e)
else:
logger.error("write_file error: %s: %s", type(e).__name__, e, exc_info=True)
return json.dumps({"error": str(e)}, ensure_ascii=False)
def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
new_string: str = None, replace_all: bool = False, patch: str = None,
task_id: str = "default") -> str:
"""Patch a file using replace mode or V4A patch format."""
try:
file_ops = _get_file_ops(task_id)
if mode == "replace":
if not path:
return json.dumps({"error": "path required"})
if old_string is None or new_string is None:
return json.dumps({"error": "old_string and new_string required"})
result = file_ops.patch_replace(path, old_string, new_string, replace_all)
elif mode == "patch":
if not patch:
return json.dumps({"error": "patch content required"})
result = file_ops.patch_v4a(patch)
else:
return json.dumps({"error": f"Unknown mode: {mode}"})
result_dict = result.to_dict()
result_json = json.dumps(result_dict, ensure_ascii=False)
# Hint when old_string not found — saves iterations where the agent
# retries with stale content instead of re-reading the file.
if result_dict.get("error") and "Could not find" in str(result_dict["error"]):
result_json += "\n\n[Hint: old_string not found. Use read_file to verify the current content, or search_files to locate the text.]"
return result_json
except Exception as e:
return json.dumps({"error": str(e)}, ensure_ascii=False)
def search_tool(pattern: str, target: str = "content", path: str = ".",
file_glob: str = None, limit: int = 50, offset: int = 0,
output_mode: str = "content", context: int = 0,
task_id: str = "default") -> str:
"""Search for content or files."""
try:
# Track searches to detect *consecutive* repeated search loops.
# Include pagination args so users can page through truncated
# results without tripping the repeated-search guard.
search_key = (
"search",
pattern,
target,
str(path),
file_glob or "",
limit,
offset,
)
with _read_tracker_lock:
task_data = _read_tracker.setdefault(task_id, {
"last_key": None, "consecutive": 0, "read_history": set(),
})
if task_data["last_key"] == search_key:
task_data["consecutive"] += 1
else:
task_data["last_key"] = search_key
task_data["consecutive"] = 1
count = task_data["consecutive"]
if count >= 4:
return json.dumps({
"error": (
f"BLOCKED: You have run this exact search {count} times in a row. "
"The results have NOT changed. You already have this information. "
"STOP re-searching and proceed with your task."
),
"pattern": pattern,
"already_searched": count,
}, ensure_ascii=False)
file_ops = _get_file_ops(task_id)
result = file_ops.search(
pattern=pattern, path=path, target=target, file_glob=file_glob,
limit=limit, offset=offset, output_mode=output_mode, context=context
)
if hasattr(result, 'matches'):
for m in result.matches:
if hasattr(m, 'content') and m.content:
m.content = redact_sensitive_text(m.content)
result_dict = result.to_dict()
if count >= 3:
result_dict["_warning"] = (
f"You have run this exact search {count} times consecutively. "
"The results have not changed. Use the information you already have."
)
result_json = json.dumps(result_dict, ensure_ascii=False)
# Hint when results were truncated — explicit next offset is clearer
# than relying on the model to infer it from total_count vs match count.
if result_dict.get("truncated"):
next_offset = offset + limit
result_json += f"\n\n[Hint: Results truncated. Use offset={next_offset} to see more, or narrow with a more specific pattern or file_glob.]"
return result_json
except Exception as e:
return json.dumps({"error": str(e)}, ensure_ascii=False)
FILE_TOOLS = [
{"name": "read_file", "function": read_file_tool},
{"name": "write_file", "function": write_file_tool},
{"name": "patch", "function": patch_tool},
{"name": "search_files", "function": search_tool}
]
def get_file_tools():
"""Get the list of file tool definitions."""
return FILE_TOOLS
# ---------------------------------------------------------------------------
# Schemas + Registry
# ---------------------------------------------------------------------------
from tools.registry import registry
def _check_file_reqs():
"""Lazy wrapper to avoid circular import with tools/__init__.py."""
from tools import check_file_requirements
return check_file_requirements()
READ_FILE_SCHEMA = {
"name": "read_file",
"description": "Read a text file with line numbers and pagination. Use this instead of cat/head/tail in terminal. Output format: 'LINE_NUM|CONTENT'. Suggests similar filenames if not found. Use offset and limit for large files. NOTE: Cannot read images or binary files — use vision_analyze for images.",
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string", "description": "Path to the file to read (absolute, relative, or ~/path)"},
"offset": {"type": "integer", "description": "Line number to start reading from (1-indexed, default: 1)", "default": 1, "minimum": 1},
"limit": {"type": "integer", "description": "Maximum number of lines to read (default: 500, max: 2000)", "default": 500, "maximum": 2000}
},
"required": ["path"]
}
}
WRITE_FILE_SCHEMA = {
"name": "write_file",
"description": "Write content to a file, completely replacing existing content. Use this instead of echo/cat heredoc in terminal. Creates parent directories automatically. OVERWRITES the entire file — use 'patch' for targeted edits.",
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string", "description": "Path to the file to write (will be created if it doesn't exist, overwritten if it does)"},
"content": {"type": "string", "description": "Complete content to write to the file"}
},
"required": ["path", "content"]
}
}
PATCH_SCHEMA = {
"name": "patch",
"description": "Targeted find-and-replace edits in files. Use this instead of sed/awk in terminal. Uses fuzzy matching (9 strategies) so minor whitespace/indentation differences won't break it. Returns a unified diff. Auto-runs syntax checks after editing.\n\nReplace mode (default): find a unique string and replace it.\nPatch mode: apply V4A multi-file patches for bulk changes.",
"parameters": {
"type": "object",
"properties": {
"mode": {"type": "string", "enum": ["replace", "patch"], "description": "Edit mode: 'replace' for targeted find-and-replace, 'patch' for V4A multi-file patches", "default": "replace"},
"path": {"type": "string", "description": "File path to edit (required for 'replace' mode)"},
"old_string": {"type": "string", "description": "Text to find in the file (required for 'replace' mode). Must be unique in the file unless replace_all=true. Include enough surrounding context to ensure uniqueness."},
"new_string": {"type": "string", "description": "Replacement text (required for 'replace' mode). Can be empty string to delete the matched text."},
"replace_all": {"type": "boolean", "description": "Replace all occurrences instead of requiring a unique match (default: false)", "default": False},
"patch": {"type": "string", "description": "V4A format patch content (required for 'patch' mode). Format:\n*** Begin Patch\n*** Update File: path/to/file\n@@ context hint @@\n context line\n-removed line\n+added line\n*** End Patch"}
},
"required": ["mode"]
}
}
SEARCH_FILES_SCHEMA = {
"name": "search_files",
"description": "Search file contents or find files by name. Use this instead of grep/rg/find/ls in terminal. Ripgrep-backed, faster than shell equivalents.\n\nContent search (target='content'): Regex search inside files. Output modes: full matches with line numbers, file paths only, or match counts.\n\nFile search (target='files'): Find files by glob pattern (e.g., '*.py', '*config*'). Also use this instead of ls — results sorted by modification time.",
"parameters": {
"type": "object",
"properties": {
"pattern": {"type": "string", "description": "Regex pattern for content search, or glob pattern (e.g., '*.py') for file search"},
"target": {"type": "string", "enum": ["content", "files"], "description": "'content' searches inside file contents, 'files' searches for files by name", "default": "content"},
"path": {"type": "string", "description": "Directory or file to search in (default: current working directory)", "default": "."},
"file_glob": {"type": "string", "description": "Filter files by pattern in grep mode (e.g., '*.py' to only search Python files)"},
"limit": {"type": "integer", "description": "Maximum number of results to return (default: 50)", "default": 50},
"offset": {"type": "integer", "description": "Skip first N results for pagination (default: 0)", "default": 0},
"output_mode": {"type": "string", "enum": ["content", "files_only", "count"], "description": "Output format for grep mode: 'content' shows matching lines with line numbers, 'files_only' lists file paths, 'count' shows match counts per file", "default": "content"},
"context": {"type": "integer", "description": "Number of context lines before and after each match (grep mode only)", "default": 0}
},
"required": ["pattern"]
}
}
def _handle_read_file(args, **kw):
tid = kw.get("task_id") or "default"
return read_file_tool(path=args.get("path", ""), offset=args.get("offset", 1), limit=args.get("limit", 500), task_id=tid)
def _handle_write_file(args, **kw):
tid = kw.get("task_id") or "default"
return write_file_tool(path=args.get("path", ""), content=args.get("content", ""), task_id=tid)
def _handle_patch(args, **kw):
tid = kw.get("task_id") or "default"
return patch_tool(
mode=args.get("mode", "replace"), path=args.get("path"),
old_string=args.get("old_string"), new_string=args.get("new_string"),
replace_all=args.get("replace_all", False), patch=args.get("patch"), task_id=tid)
def _handle_search_files(args, **kw):
tid = kw.get("task_id") or "default"
target_map = {"grep": "content", "find": "files"}
raw_target = args.get("target", "content")
target = target_map.get(raw_target, raw_target)
return search_tool(
pattern=args.get("pattern", ""), target=target, path=args.get("path", "."),
file_glob=args.get("file_glob"), limit=args.get("limit", 50), offset=args.get("offset", 0),
output_mode=args.get("output_mode", "content"), context=args.get("context", 0), task_id=tid)
registry.register(name="read_file", toolset="file", schema=READ_FILE_SCHEMA, handler=_handle_read_file, check_fn=_check_file_reqs, emoji="📖")
registry.register(name="write_file", toolset="file", schema=WRITE_FILE_SCHEMA, handler=_handle_write_file, check_fn=_check_file_reqs, emoji="✍️")
registry.register(name="patch", toolset="file", schema=PATCH_SCHEMA, handler=_handle_patch, check_fn=_check_file_reqs, emoji="🔧")
registry.register(name="search_files", toolset="file", schema=SEARCH_FILES_SCHEMA, handler=_handle_search_files, check_fn=_check_file_reqs, emoji="🔎")

View file

@ -0,0 +1,487 @@
#!/usr/bin/env python3
"""
Fuzzy Matching Module for File Operations
Implements a multi-strategy matching chain to robustly find and replace text,
accommodating variations in whitespace, indentation, and escaping common
in LLM-generated code.
The 8-strategy chain (inspired by OpenCode), tried in order:
1. Exact match - Direct string comparison
2. Line-trimmed - Strip leading/trailing whitespace per line
3. Whitespace normalized - Collapse multiple spaces/tabs to single space
4. Indentation flexible - Ignore indentation differences entirely
5. Escape normalized - Convert \\n literals to actual newlines
6. Trimmed boundary - Trim first/last line whitespace only
7. Block anchor - Match first+last lines, use similarity for middle
8. Context-aware - 50% line similarity threshold
Multi-occurrence matching is handled via the replace_all flag.
Usage:
from tools.fuzzy_match import fuzzy_find_and_replace
new_content, match_count, error = fuzzy_find_and_replace(
content="def foo():\\n pass",
old_string="def foo():",
new_string="def bar():",
replace_all=False
)
"""
import re
from typing import Tuple, Optional, List, Callable
from difflib import SequenceMatcher
UNICODE_MAP = {
"\u201c": '"', "\u201d": '"', # smart double quotes
"\u2018": "'", "\u2019": "'", # smart single quotes
"\u2014": "--", "\u2013": "-", # em/en dashes
"\u2026": "...", "\u00a0": " ", # ellipsis and non-breaking space
}
def _unicode_normalize(text: str) -> str:
"""Normalizes Unicode characters to their standard ASCII equivalents."""
for char, repl in UNICODE_MAP.items():
text = text.replace(char, repl)
return text
def fuzzy_find_and_replace(content: str, old_string: str, new_string: str,
replace_all: bool = False) -> Tuple[str, int, Optional[str]]:
"""
Find and replace text using a chain of increasingly fuzzy matching strategies.
Args:
content: The file content to search in
old_string: The text to find
new_string: The replacement text
replace_all: If True, replace all occurrences; if False, require uniqueness
Returns:
Tuple of (new_content, match_count, error_message)
- If successful: (modified_content, number_of_replacements, None)
- If failed: (original_content, 0, error_description)
"""
if not old_string:
return content, 0, "old_string cannot be empty"
if old_string == new_string:
return content, 0, "old_string and new_string are identical"
# Try each matching strategy in order
strategies: List[Tuple[str, Callable]] = [
("exact", _strategy_exact),
("line_trimmed", _strategy_line_trimmed),
("whitespace_normalized", _strategy_whitespace_normalized),
("indentation_flexible", _strategy_indentation_flexible),
("escape_normalized", _strategy_escape_normalized),
("trimmed_boundary", _strategy_trimmed_boundary),
("block_anchor", _strategy_block_anchor),
("context_aware", _strategy_context_aware),
]
for strategy_name, strategy_fn in strategies:
matches = strategy_fn(content, old_string)
if matches:
# Found matches with this strategy
if len(matches) > 1 and not replace_all:
return content, 0, (
f"Found {len(matches)} matches for old_string. "
f"Provide more context to make it unique, or use replace_all=True."
)
# Perform replacement
new_content = _apply_replacements(content, matches, new_string)
return new_content, len(matches), None
# No strategy found a match
return content, 0, "Could not find a match for old_string in the file"
def _apply_replacements(content: str, matches: List[Tuple[int, int]], new_string: str) -> str:
"""
Apply replacements at the given positions.
Args:
content: Original content
matches: List of (start, end) positions to replace
new_string: Replacement text
Returns:
Content with replacements applied
"""
# Sort matches by position (descending) to replace from end to start
# This preserves positions of earlier matches
sorted_matches = sorted(matches, key=lambda x: x[0], reverse=True)
result = content
for start, end in sorted_matches:
result = result[:start] + new_string + result[end:]
return result
# =============================================================================
# Matching Strategies
# =============================================================================
def _strategy_exact(content: str, pattern: str) -> List[Tuple[int, int]]:
"""Strategy 1: Exact string match."""
matches = []
start = 0
while True:
pos = content.find(pattern, start)
if pos == -1:
break
matches.append((pos, pos + len(pattern)))
start = pos + 1
return matches
def _strategy_line_trimmed(content: str, pattern: str) -> List[Tuple[int, int]]:
"""
Strategy 2: Match with line-by-line whitespace trimming.
Strips leading/trailing whitespace from each line before matching.
"""
# Normalize pattern and content by trimming each line
pattern_lines = [line.strip() for line in pattern.split('\n')]
pattern_normalized = '\n'.join(pattern_lines)
content_lines = content.split('\n')
content_normalized_lines = [line.strip() for line in content_lines]
# Build mapping from normalized positions back to original positions
return _find_normalized_matches(
content, content_lines, content_normalized_lines,
pattern, pattern_normalized
)
def _strategy_whitespace_normalized(content: str, pattern: str) -> List[Tuple[int, int]]:
"""
Strategy 3: Collapse multiple whitespace to single space.
"""
def normalize(s):
# Collapse multiple spaces/tabs to single space, preserve newlines
return re.sub(r'[ \t]+', ' ', s)
pattern_normalized = normalize(pattern)
content_normalized = normalize(content)
# Find in normalized, map back to original
matches_in_normalized = _strategy_exact(content_normalized, pattern_normalized)
if not matches_in_normalized:
return []
# Map positions back to original content
return _map_normalized_positions(content, content_normalized, matches_in_normalized)
def _strategy_indentation_flexible(content: str, pattern: str) -> List[Tuple[int, int]]:
"""
Strategy 4: Ignore indentation differences entirely.
Strips all leading whitespace from lines before matching.
"""
def strip_indent(s):
return '\n'.join(line.lstrip() for line in s.split('\n'))
pattern_stripped = strip_indent(pattern)
content_lines = content.split('\n')
content_stripped_lines = [line.lstrip() for line in content_lines]
pattern_lines = [line.lstrip() for line in pattern.split('\n')]
return _find_normalized_matches(
content, content_lines, content_stripped_lines,
pattern, '\n'.join(pattern_lines)
)
def _strategy_escape_normalized(content: str, pattern: str) -> List[Tuple[int, int]]:
"""
Strategy 5: Convert escape sequences to actual characters.
Handles \\n -> newline, \\t -> tab, etc.
"""
def unescape(s):
# Convert common escape sequences
return s.replace('\\n', '\n').replace('\\t', '\t').replace('\\r', '\r')
pattern_unescaped = unescape(pattern)
if pattern_unescaped == pattern:
# No escapes to convert, skip this strategy
return []
return _strategy_exact(content, pattern_unescaped)
def _strategy_trimmed_boundary(content: str, pattern: str) -> List[Tuple[int, int]]:
"""
Strategy 6: Trim whitespace from first and last lines only.
Useful when the pattern boundaries have whitespace differences.
"""
pattern_lines = pattern.split('\n')
if not pattern_lines:
return []
# Trim only first and last lines
pattern_lines[0] = pattern_lines[0].strip()
if len(pattern_lines) > 1:
pattern_lines[-1] = pattern_lines[-1].strip()
modified_pattern = '\n'.join(pattern_lines)
content_lines = content.split('\n')
# Search through content for matching block
matches = []
pattern_line_count = len(pattern_lines)
for i in range(len(content_lines) - pattern_line_count + 1):
block_lines = content_lines[i:i + pattern_line_count]
# Trim first and last of this block
check_lines = block_lines.copy()
check_lines[0] = check_lines[0].strip()
if len(check_lines) > 1:
check_lines[-1] = check_lines[-1].strip()
if '\n'.join(check_lines) == modified_pattern:
# Found match - calculate original positions
start_pos, end_pos = _calculate_line_positions(
content_lines, i, i + pattern_line_count, len(content)
)
matches.append((start_pos, end_pos))
return matches
def _strategy_block_anchor(content: str, pattern: str) -> List[Tuple[int, int]]:
"""
Strategy 7: Match by anchoring on first and last lines.
Adjusted with permissive thresholds and unicode normalization.
"""
# Normalize both strings for comparison while keeping original content for offset calculation
norm_pattern = _unicode_normalize(pattern)
norm_content = _unicode_normalize(content)
pattern_lines = norm_pattern.split('\n')
if len(pattern_lines) < 2:
return []
first_line = pattern_lines[0].strip()
last_line = pattern_lines[-1].strip()
# Use normalized lines for matching logic
norm_content_lines = norm_content.split('\n')
# BUT use original lines for calculating start/end positions to prevent index shift
orig_content_lines = content.split('\n')
pattern_line_count = len(pattern_lines)
potential_matches = []
for i in range(len(norm_content_lines) - pattern_line_count + 1):
if (norm_content_lines[i].strip() == first_line and
norm_content_lines[i + pattern_line_count - 1].strip() == last_line):
potential_matches.append(i)
matches = []
candidate_count = len(potential_matches)
# Thresholding logic: 0.10 for unique matches (max flexibility), 0.30 for multiple candidates
threshold = 0.10 if candidate_count == 1 else 0.30
for i in potential_matches:
if pattern_line_count <= 2:
similarity = 1.0
else:
# Compare normalized middle sections
content_middle = '\n'.join(norm_content_lines[i+1:i+pattern_line_count-1])
pattern_middle = '\n'.join(pattern_lines[1:-1])
similarity = SequenceMatcher(None, content_middle, pattern_middle).ratio()
if similarity >= threshold:
# Calculate positions using ORIGINAL lines to ensure correct character offsets in the file
start_pos, end_pos = _calculate_line_positions(
orig_content_lines, i, i + pattern_line_count, len(content)
)
matches.append((start_pos, end_pos))
return matches
def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]]:
"""
Strategy 8: Line-by-line similarity with 50% threshold.
Finds blocks where at least 50% of lines have high similarity.
"""
pattern_lines = pattern.split('\n')
content_lines = content.split('\n')
if not pattern_lines:
return []
matches = []
pattern_line_count = len(pattern_lines)
for i in range(len(content_lines) - pattern_line_count + 1):
block_lines = content_lines[i:i + pattern_line_count]
# Calculate line-by-line similarity
high_similarity_count = 0
for p_line, c_line in zip(pattern_lines, block_lines):
sim = SequenceMatcher(None, p_line.strip(), c_line.strip()).ratio()
if sim >= 0.80:
high_similarity_count += 1
# Need at least 50% of lines to have high similarity
if high_similarity_count >= len(pattern_lines) * 0.5:
start_pos, end_pos = _calculate_line_positions(
content_lines, i, i + pattern_line_count, len(content)
)
matches.append((start_pos, end_pos))
return matches
# =============================================================================
# Helper Functions
# =============================================================================
def _calculate_line_positions(content_lines: List[str], start_line: int,
end_line: int, content_length: int) -> Tuple[int, int]:
"""Calculate start and end character positions from line indices.
Args:
content_lines: List of lines (without newlines)
start_line: Starting line index (0-based)
end_line: Ending line index (exclusive, 0-based)
content_length: Total length of the original content string
Returns:
Tuple of (start_pos, end_pos) in the original content
"""
start_pos = sum(len(line) + 1 for line in content_lines[:start_line])
end_pos = sum(len(line) + 1 for line in content_lines[:end_line]) - 1
if end_pos >= content_length:
end_pos = content_length
return start_pos, end_pos
def _find_normalized_matches(content: str, content_lines: List[str],
content_normalized_lines: List[str],
pattern: str, pattern_normalized: str) -> List[Tuple[int, int]]:
"""
Find matches in normalized content and map back to original positions.
Args:
content: Original content string
content_lines: Original content split by lines
content_normalized_lines: Normalized content lines
pattern: Original pattern
pattern_normalized: Normalized pattern
Returns:
List of (start, end) positions in the original content
"""
pattern_norm_lines = pattern_normalized.split('\n')
num_pattern_lines = len(pattern_norm_lines)
matches = []
for i in range(len(content_normalized_lines) - num_pattern_lines + 1):
# Check if this block matches
block = '\n'.join(content_normalized_lines[i:i + num_pattern_lines])
if block == pattern_normalized:
# Found a match - calculate original positions
start_pos, end_pos = _calculate_line_positions(
content_lines, i, i + num_pattern_lines, len(content)
)
matches.append((start_pos, end_pos))
return matches
def _map_normalized_positions(original: str, normalized: str,
normalized_matches: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
"""
Map positions from normalized string back to original.
This is a best-effort mapping that works for whitespace normalization.
"""
if not normalized_matches:
return []
# Build character mapping from normalized to original
orig_to_norm = [] # orig_to_norm[i] = position in normalized
orig_idx = 0
norm_idx = 0
while orig_idx < len(original) and norm_idx < len(normalized):
if original[orig_idx] == normalized[norm_idx]:
orig_to_norm.append(norm_idx)
orig_idx += 1
norm_idx += 1
elif original[orig_idx] in ' \t' and normalized[norm_idx] == ' ':
# Original has space/tab, normalized collapsed to space
orig_to_norm.append(norm_idx)
orig_idx += 1
# Don't advance norm_idx yet - wait until all whitespace consumed
if orig_idx < len(original) and original[orig_idx] not in ' \t':
norm_idx += 1
elif original[orig_idx] in ' \t':
# Extra whitespace in original
orig_to_norm.append(norm_idx)
orig_idx += 1
else:
# Mismatch - shouldn't happen with our normalization
orig_to_norm.append(norm_idx)
orig_idx += 1
# Fill remaining
while orig_idx < len(original):
orig_to_norm.append(len(normalized))
orig_idx += 1
# Reverse mapping: for each normalized position, find original range
norm_to_orig_start = {}
norm_to_orig_end = {}
for orig_pos, norm_pos in enumerate(orig_to_norm):
if norm_pos not in norm_to_orig_start:
norm_to_orig_start[norm_pos] = orig_pos
norm_to_orig_end[norm_pos] = orig_pos
# Map matches
original_matches = []
for norm_start, norm_end in normalized_matches:
# Find original start
if norm_start in norm_to_orig_start:
orig_start = norm_to_orig_start[norm_start]
else:
# Find nearest
orig_start = min(i for i, n in enumerate(orig_to_norm) if n >= norm_start)
# Find original end
if norm_end - 1 in norm_to_orig_end:
orig_end = norm_to_orig_end[norm_end - 1] + 1
else:
orig_end = orig_start + (norm_end - norm_start)
# Expand to include trailing whitespace that was normalized
while orig_end < len(original) and original[orig_end] in ' \t':
orig_end += 1
original_matches.append((orig_start, min(orig_end, len(original))))
return original_matches

View file

@ -0,0 +1,490 @@
"""Home Assistant tool for controlling smart home devices via REST API.
Registers four LLM-callable tools:
- ``ha_list_entities`` -- list/filter entities by domain or area
- ``ha_get_state`` -- get detailed state of a single entity
- ``ha_list_services`` -- list available services (actions) per domain
- ``ha_call_service`` -- call a HA service (turn_on, turn_off, set_temperature, etc.)
Authentication uses a Long-Lived Access Token via ``HASS_TOKEN`` env var.
The HA instance URL is read from ``HASS_URL`` (default: http://homeassistant.local:8123).
"""
import asyncio
import json
import logging
import os
import re
from typing import Any, Dict, Optional
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Kept for backward compatibility (e.g. test monkeypatching); prefer _get_config().
_HASS_URL: str = ""
_HASS_TOKEN: str = ""
def _get_config():
"""Return (hass_url, hass_token) from env vars at call time."""
return (
(_HASS_URL or os.getenv("HASS_URL", "http://homeassistant.local:8123")).rstrip("/"),
_HASS_TOKEN or os.getenv("HASS_TOKEN", ""),
)
# Regex for valid HA entity_id format (e.g. "light.living_room", "sensor.temperature_1")
_ENTITY_ID_RE = re.compile(r"^[a-z_][a-z0-9_]*\.[a-z0-9_]+$")
# Service domains blocked for security -- these allow arbitrary code/command
# execution on the HA host or enable SSRF attacks on the local network.
# HA provides zero service-level access control; all safety must be in our layer.
_BLOCKED_DOMAINS = frozenset({
"shell_command", # arbitrary shell commands as root in HA container
"command_line", # sensors/switches that execute shell commands
"python_script", # sandboxed but can escalate via hass.services.call()
"pyscript", # scripting integration with broader access
"hassio", # addon control, host shutdown/reboot, stdin to containers
"rest_command", # HTTP requests from HA server (SSRF vector)
})
def _get_headers(token: str = "") -> Dict[str, str]:
"""Return authorization headers for HA REST API."""
if not token:
_, token = _get_config()
return {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
# ---------------------------------------------------------------------------
# Async helpers (called from sync handlers via run_until_complete)
# ---------------------------------------------------------------------------
def _filter_and_summarize(
states: list,
domain: Optional[str] = None,
area: Optional[str] = None,
) -> Dict[str, Any]:
"""Filter raw HA states by domain/area and return a compact summary."""
if domain:
states = [s for s in states if s.get("entity_id", "").startswith(f"{domain}.")]
if area:
area_lower = area.lower()
states = [
s for s in states
if area_lower in (s.get("attributes", {}).get("friendly_name", "") or "").lower()
or area_lower in (s.get("attributes", {}).get("area", "") or "").lower()
]
entities = []
for s in states:
entities.append({
"entity_id": s["entity_id"],
"state": s["state"],
"friendly_name": s.get("attributes", {}).get("friendly_name", ""),
})
return {"count": len(entities), "entities": entities}
async def _async_list_entities(
domain: Optional[str] = None,
area: Optional[str] = None,
) -> Dict[str, Any]:
"""Fetch entity states from HA and optionally filter by domain/area."""
import aiohttp
hass_url, hass_token = _get_config()
url = f"{hass_url}/api/states"
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=_get_headers(hass_token), timeout=aiohttp.ClientTimeout(total=15)) as resp:
resp.raise_for_status()
states = await resp.json()
return _filter_and_summarize(states, domain, area)
async def _async_get_state(entity_id: str) -> Dict[str, Any]:
"""Fetch detailed state of a single entity."""
import aiohttp
hass_url, hass_token = _get_config()
url = f"{hass_url}/api/states/{entity_id}"
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=_get_headers(hass_token), timeout=aiohttp.ClientTimeout(total=10)) as resp:
resp.raise_for_status()
data = await resp.json()
return {
"entity_id": data["entity_id"],
"state": data["state"],
"attributes": data.get("attributes", {}),
"last_changed": data.get("last_changed"),
"last_updated": data.get("last_updated"),
}
def _build_service_payload(
entity_id: Optional[str] = None,
data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Build the JSON payload for a HA service call."""
payload: Dict[str, Any] = {}
if data:
payload.update(data)
# entity_id parameter takes precedence over data["entity_id"]
if entity_id:
payload["entity_id"] = entity_id
return payload
def _parse_service_response(
domain: str,
service: str,
result: Any,
) -> Dict[str, Any]:
"""Parse HA service call response into a structured result."""
affected = []
if isinstance(result, list):
for s in result:
affected.append({
"entity_id": s.get("entity_id", ""),
"state": s.get("state", ""),
})
return {
"success": True,
"service": f"{domain}.{service}",
"affected_entities": affected,
}
async def _async_call_service(
domain: str,
service: str,
entity_id: Optional[str] = None,
data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Call a Home Assistant service."""
import aiohttp
hass_url, hass_token = _get_config()
url = f"{hass_url}/api/services/{domain}/{service}"
payload = _build_service_payload(entity_id, data)
async with aiohttp.ClientSession() as session:
async with session.post(
url,
headers=_get_headers(hass_token),
json=payload,
timeout=aiohttp.ClientTimeout(total=15),
) as resp:
resp.raise_for_status()
result = await resp.json()
return _parse_service_response(domain, service, result)
# ---------------------------------------------------------------------------
# Sync wrappers (handler signature: (args, **kw) -> str)
# ---------------------------------------------------------------------------
def _run_async(coro):
"""Run an async coroutine from a sync handler."""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
# Already inside an event loop -- create a new thread
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
future = pool.submit(asyncio.run, coro)
return future.result(timeout=30)
else:
return asyncio.run(coro)
def _handle_list_entities(args: dict, **kw) -> str:
"""Handler for ha_list_entities tool."""
domain = args.get("domain")
area = args.get("area")
try:
result = _run_async(_async_list_entities(domain=domain, area=area))
return json.dumps({"result": result})
except Exception as e:
logger.error("ha_list_entities error: %s", e)
return json.dumps({"error": f"Failed to list entities: {e}"})
def _handle_get_state(args: dict, **kw) -> str:
"""Handler for ha_get_state tool."""
entity_id = args.get("entity_id", "")
if not entity_id:
return json.dumps({"error": "Missing required parameter: entity_id"})
if not _ENTITY_ID_RE.match(entity_id):
return json.dumps({"error": f"Invalid entity_id format: {entity_id}"})
try:
result = _run_async(_async_get_state(entity_id))
return json.dumps({"result": result})
except Exception as e:
logger.error("ha_get_state error: %s", e)
return json.dumps({"error": f"Failed to get state for {entity_id}: {e}"})
def _handle_call_service(args: dict, **kw) -> str:
"""Handler for ha_call_service tool."""
domain = args.get("domain", "")
service = args.get("service", "")
if not domain or not service:
return json.dumps({"error": "Missing required parameters: domain and service"})
if domain in _BLOCKED_DOMAINS:
return json.dumps({
"error": f"Service domain '{domain}' is blocked for security. "
f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}"
})
entity_id = args.get("entity_id")
if entity_id and not _ENTITY_ID_RE.match(entity_id):
return json.dumps({"error": f"Invalid entity_id format: {entity_id}"})
data = args.get("data")
try:
result = _run_async(_async_call_service(domain, service, entity_id, data))
return json.dumps({"result": result})
except Exception as e:
logger.error("ha_call_service error: %s", e)
return json.dumps({"error": f"Failed to call {domain}.{service}: {e}"})
# ---------------------------------------------------------------------------
# List services
# ---------------------------------------------------------------------------
async def _async_list_services(domain: Optional[str] = None) -> Dict[str, Any]:
"""Fetch available services from HA and optionally filter by domain."""
import aiohttp
hass_url, hass_token = _get_config()
url = f"{hass_url}/api/services"
headers = {"Authorization": f"Bearer {hass_token}", "Content-Type": "application/json"}
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=15)) as resp:
resp.raise_for_status()
services = await resp.json()
if domain:
services = [s for s in services if s.get("domain") == domain]
# Compact the output for context efficiency
result = []
for svc_domain in services:
d = svc_domain.get("domain", "")
domain_services = {}
for svc_name, svc_info in svc_domain.get("services", {}).items():
svc_entry: Dict[str, Any] = {"description": svc_info.get("description", "")}
fields = svc_info.get("fields", {})
if fields:
svc_entry["fields"] = {
k: v.get("description", "") for k, v in fields.items()
if isinstance(v, dict)
}
domain_services[svc_name] = svc_entry
result.append({"domain": d, "services": domain_services})
return {"count": len(result), "domains": result}
def _handle_list_services(args: dict, **kw) -> str:
"""Handler for ha_list_services tool."""
domain = args.get("domain")
try:
result = _run_async(_async_list_services(domain=domain))
return json.dumps({"result": result})
except Exception as e:
logger.error("ha_list_services error: %s", e)
return json.dumps({"error": f"Failed to list services: {e}"})
# ---------------------------------------------------------------------------
# Availability check
# ---------------------------------------------------------------------------
def _check_ha_available() -> bool:
"""Tool is only available when HASS_TOKEN is set."""
return bool(os.getenv("HASS_TOKEN"))
# ---------------------------------------------------------------------------
# Tool schemas
# ---------------------------------------------------------------------------
HA_LIST_ENTITIES_SCHEMA = {
"name": "ha_list_entities",
"description": (
"List Home Assistant entities. Optionally filter by domain "
"(light, switch, climate, sensor, binary_sensor, cover, fan, etc.) "
"or by area name (living room, kitchen, bedroom, etc.)."
),
"parameters": {
"type": "object",
"properties": {
"domain": {
"type": "string",
"description": (
"Entity domain to filter by (e.g. 'light', 'switch', 'climate', "
"'sensor', 'binary_sensor', 'cover', 'fan', 'media_player'). "
"Omit to list all entities."
),
},
"area": {
"type": "string",
"description": (
"Area/room name to filter by (e.g. 'living room', 'kitchen'). "
"Matches against entity friendly names. Omit to list all."
),
},
},
"required": [],
},
}
HA_GET_STATE_SCHEMA = {
"name": "ha_get_state",
"description": (
"Get the detailed state of a single Home Assistant entity, including all "
"attributes (brightness, color, temperature setpoint, sensor readings, etc.)."
),
"parameters": {
"type": "object",
"properties": {
"entity_id": {
"type": "string",
"description": (
"The entity ID to query (e.g. 'light.living_room', "
"'climate.thermostat', 'sensor.temperature')."
),
},
},
"required": ["entity_id"],
},
}
HA_LIST_SERVICES_SCHEMA = {
"name": "ha_list_services",
"description": (
"List available Home Assistant services (actions) for device control. "
"Shows what actions can be performed on each device type and what "
"parameters they accept. Use this to discover how to control devices "
"found via ha_list_entities."
),
"parameters": {
"type": "object",
"properties": {
"domain": {
"type": "string",
"description": (
"Filter by domain (e.g. 'light', 'climate', 'switch'). "
"Omit to list services for all domains."
),
},
},
"required": [],
},
}
HA_CALL_SERVICE_SCHEMA = {
"name": "ha_call_service",
"description": (
"Call a Home Assistant service to control a device. Use ha_list_services "
"to discover available services and their parameters for each domain."
),
"parameters": {
"type": "object",
"properties": {
"domain": {
"type": "string",
"description": (
"Service domain (e.g. 'light', 'switch', 'climate', "
"'cover', 'media_player', 'fan', 'scene', 'script')."
),
},
"service": {
"type": "string",
"description": (
"Service name (e.g. 'turn_on', 'turn_off', 'toggle', "
"'set_temperature', 'set_hvac_mode', 'open_cover', "
"'close_cover', 'set_volume_level')."
),
},
"entity_id": {
"type": "string",
"description": (
"Target entity ID (e.g. 'light.living_room'). "
"Some services (like scene.turn_on) may not need this."
),
},
"data": {
"type": "object",
"description": (
"Additional service data. Examples: "
'{"brightness": 255, "color_name": "blue"} for lights, '
'{"temperature": 22, "hvac_mode": "heat"} for climate, '
'{"volume_level": 0.5} for media players.'
),
},
},
"required": ["domain", "service"],
},
}
# ---------------------------------------------------------------------------
# Registration
# ---------------------------------------------------------------------------
from tools.registry import registry
registry.register(
name="ha_list_entities",
toolset="homeassistant",
schema=HA_LIST_ENTITIES_SCHEMA,
handler=_handle_list_entities,
check_fn=_check_ha_available,
emoji="🏠",
)
registry.register(
name="ha_get_state",
toolset="homeassistant",
schema=HA_GET_STATE_SCHEMA,
handler=_handle_get_state,
check_fn=_check_ha_available,
emoji="🏠",
)
registry.register(
name="ha_list_services",
toolset="homeassistant",
schema=HA_LIST_SERVICES_SCHEMA,
handler=_handle_list_services,
check_fn=_check_ha_available,
emoji="🏠",
)
registry.register(
name="ha_call_service",
toolset="homeassistant",
schema=HA_CALL_SERVICE_SCHEMA,
handler=_handle_call_service,
check_fn=_check_ha_available,
emoji="🏠",
)

View file

@ -0,0 +1,264 @@
"""Honcho tools for user context retrieval.
Registers three complementary tools, ordered by capability:
honcho_context dialectic Q&A (LLM-powered, direct answers)
honcho_search semantic search (fast, no LLM, raw excerpts)
honcho_profile peer card (fast, no LLM, structured facts)
Use honcho_context when you need Honcho to synthesize an answer.
Use honcho_search or honcho_profile when you want raw data to reason
over yourself.
The session key is injected at runtime by the agent loop via
``set_session_context()``.
"""
import json
import logging
logger = logging.getLogger(__name__)
# ── Module-level state (injected by AIAgent at init time) ──
_session_manager = None # HonchoSessionManager instance
_session_key: str | None = None # Current session key (e.g., "telegram:123456")
def set_session_context(session_manager, session_key: str) -> None:
"""Register the active Honcho session manager and key.
Called by AIAgent.__init__ when Honcho is enabled.
"""
global _session_manager, _session_key
_session_manager = session_manager
_session_key = session_key
def clear_session_context() -> None:
"""Clear session context (for testing or shutdown)."""
global _session_manager, _session_key
_session_manager = None
_session_key = None
# ── Availability check ──
def _check_honcho_available() -> bool:
"""Tool is only available when Honcho is active."""
return _session_manager is not None and _session_key is not None
def _resolve_session_context(**kwargs):
"""Prefer the calling agent's session context over module-global fallback."""
session_manager = kwargs.get("honcho_manager") or _session_manager
session_key = kwargs.get("honcho_session_key") or _session_key
return session_manager, session_key
# ── honcho_profile ──
_PROFILE_SCHEMA = {
"name": "honcho_profile",
"description": (
"Retrieve the user's peer card from Honcho — a curated list of key facts "
"about them (name, role, preferences, communication style, patterns). "
"Fast, no LLM reasoning, minimal cost. "
"Use this at conversation start or when you need a quick factual snapshot. "
"Use honcho_context instead when you need Honcho to synthesize an answer."
),
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
}
def _handle_honcho_profile(args: dict, **kw) -> str:
session_manager, session_key = _resolve_session_context(**kw)
if not session_manager or not session_key:
return json.dumps({"error": "Honcho is not active for this session."})
try:
card = session_manager.get_peer_card(session_key)
if not card:
return json.dumps({"result": "No profile facts available yet. The user's profile builds over time through conversations."})
return json.dumps({"result": card})
except Exception as e:
logger.error("Error fetching Honcho peer card: %s", e)
return json.dumps({"error": f"Failed to fetch profile: {e}"})
# ── honcho_search ──
_SEARCH_SCHEMA = {
"name": "honcho_search",
"description": (
"Semantic search over Honcho's stored context about the user. "
"Returns raw excerpts ranked by relevance to your query — no LLM synthesis. "
"Cheaper and faster than honcho_context. "
"Good when you want to find specific past facts and reason over them yourself. "
"Use honcho_context when you need a direct synthesized answer."
),
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "What to search for in Honcho's memory (e.g. 'programming languages', 'past projects', 'timezone').",
},
"max_tokens": {
"type": "integer",
"description": "Token budget for returned context (default 800, max 2000).",
},
},
"required": ["query"],
},
}
def _handle_honcho_search(args: dict, **kw) -> str:
query = args.get("query", "")
if not query:
return json.dumps({"error": "Missing required parameter: query"})
session_manager, session_key = _resolve_session_context(**kw)
if not session_manager or not session_key:
return json.dumps({"error": "Honcho is not active for this session."})
max_tokens = min(int(args.get("max_tokens", 800)), 2000)
try:
result = session_manager.search_context(session_key, query, max_tokens=max_tokens)
if not result:
return json.dumps({"result": "No relevant context found."})
return json.dumps({"result": result})
except Exception as e:
logger.error("Error searching Honcho context: %s", e)
return json.dumps({"error": f"Failed to search context: {e}"})
# ── honcho_context (dialectic — LLM-powered) ──
_QUERY_SCHEMA = {
"name": "honcho_context",
"description": (
"Ask Honcho a natural language question and get a synthesized answer. "
"Uses Honcho's LLM (dialectic reasoning) — higher cost than honcho_profile or honcho_search. "
"Can query about any peer: the user (default), the AI assistant, or any named peer. "
"Examples: 'What are the user's main goals?', 'What has hermes been working on?', "
"'What is the user's technical expertise level?'"
),
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "A natural language question.",
},
"peer": {
"type": "string",
"description": "Which peer to query about: 'user' (default) or 'ai'. Omit for user.",
},
},
"required": ["query"],
},
}
def _handle_honcho_context(args: dict, **kw) -> str:
query = args.get("query", "")
if not query:
return json.dumps({"error": "Missing required parameter: query"})
session_manager, session_key = _resolve_session_context(**kw)
if not session_manager or not session_key:
return json.dumps({"error": "Honcho is not active for this session."})
peer_target = args.get("peer", "user")
try:
result = session_manager.dialectic_query(session_key, query, peer=peer_target)
return json.dumps({"result": result or "No result from Honcho."})
except Exception as e:
logger.error("Error querying Honcho context: %s", e)
return json.dumps({"error": f"Failed to query context: {e}"})
# ── honcho_conclude ──
_CONCLUDE_SCHEMA = {
"name": "honcho_conclude",
"description": (
"Write a conclusion about the user back to Honcho's memory. "
"Conclusions are persistent facts that build the user's profile — "
"preferences, corrections, clarifications, project context, or anything "
"the user tells you that should be remembered across sessions. "
"Use this when the user explicitly states a preference, corrects you, "
"or shares something they want remembered. "
"Examples: 'User prefers dark mode', 'User's project uses Python 3.11', "
"'User corrected: their name is spelled Eri not Eric'."
),
"parameters": {
"type": "object",
"properties": {
"conclusion": {
"type": "string",
"description": "A factual statement about the user to persist in memory.",
}
},
"required": ["conclusion"],
},
}
def _handle_honcho_conclude(args: dict, **kw) -> str:
conclusion = args.get("conclusion", "")
if not conclusion:
return json.dumps({"error": "Missing required parameter: conclusion"})
session_manager, session_key = _resolve_session_context(**kw)
if not session_manager or not session_key:
return json.dumps({"error": "Honcho is not active for this session."})
try:
ok = session_manager.create_conclusion(session_key, conclusion)
if ok:
return json.dumps({"result": f"Conclusion saved: {conclusion}"})
return json.dumps({"error": "Failed to save conclusion."})
except Exception as e:
logger.error("Error creating Honcho conclusion: %s", e)
return json.dumps({"error": f"Failed to save conclusion: {e}"})
# ── Registration ──
from tools.registry import registry
registry.register(
name="honcho_profile",
toolset="honcho",
schema=_PROFILE_SCHEMA,
handler=_handle_honcho_profile,
check_fn=_check_honcho_available,
emoji="🔮",
)
registry.register(
name="honcho_search",
toolset="honcho",
schema=_SEARCH_SCHEMA,
handler=_handle_honcho_search,
check_fn=_check_honcho_available,
emoji="🔮",
)
registry.register(
name="honcho_context",
toolset="honcho",
schema=_QUERY_SCHEMA,
handler=_handle_honcho_context,
check_fn=_check_honcho_available,
emoji="🔮",
)
registry.register(
name="honcho_conclude",
toolset="honcho",
schema=_CONCLUDE_SCHEMA,
handler=_handle_honcho_conclude,
check_fn=_check_honcho_available,
emoji="🔮",
)

View file

@ -0,0 +1,562 @@
#!/usr/bin/env python3
"""
Image Generation Tools Module
This module provides image generation tools using FAL.ai's FLUX 2 Pro model with
automatic upscaling via FAL.ai's Clarity Upscaler for enhanced image quality.
Available tools:
- image_generate_tool: Generate images from text prompts with automatic upscaling
Features:
- High-quality image generation using FLUX 2 Pro model
- Automatic 2x upscaling using Clarity Upscaler for enhanced quality
- Comprehensive parameter control (size, steps, guidance, etc.)
- Proper error handling and validation with fallback to original images
- Debug logging support
- Sync mode for immediate results
Usage:
from image_generation_tool import image_generate_tool
import asyncio
# Generate and automatically upscale an image
result = await image_generate_tool(
prompt="A serene mountain landscape with cherry blossoms",
image_size="landscape_4_3",
num_images=1
)
"""
import json
import logging
import os
import datetime
from typing import Dict, Any, Optional, Union
import fal_client
from tools.debug_helpers import DebugSession
logger = logging.getLogger(__name__)
# Configuration for image generation
DEFAULT_MODEL = "fal-ai/flux-2-pro"
DEFAULT_ASPECT_RATIO = "landscape"
DEFAULT_NUM_INFERENCE_STEPS = 50
DEFAULT_GUIDANCE_SCALE = 4.5
DEFAULT_NUM_IMAGES = 1
DEFAULT_OUTPUT_FORMAT = "png"
# Safety settings
ENABLE_SAFETY_CHECKER = False
SAFETY_TOLERANCE = "5" # Maximum tolerance (1-5, where 5 is most permissive)
# Aspect ratio mapping - simplified choices for model to select
ASPECT_RATIO_MAP = {
"landscape": "landscape_16_9",
"square": "square_hd",
"portrait": "portrait_16_9"
}
VALID_ASPECT_RATIOS = list(ASPECT_RATIO_MAP.keys())
# Configuration for automatic upscaling
UPSCALER_MODEL = "fal-ai/clarity-upscaler"
UPSCALER_FACTOR = 2
UPSCALER_SAFETY_CHECKER = False
UPSCALER_DEFAULT_PROMPT = "masterpiece, best quality, highres"
UPSCALER_NEGATIVE_PROMPT = "(worst quality, low quality, normal quality:2)"
UPSCALER_CREATIVITY = 0.35
UPSCALER_RESEMBLANCE = 0.6
UPSCALER_GUIDANCE_SCALE = 4
UPSCALER_NUM_INFERENCE_STEPS = 18
# Valid parameter values for validation based on FLUX 2 Pro documentation
VALID_IMAGE_SIZES = [
"square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"
]
VALID_OUTPUT_FORMATS = ["jpeg", "png"]
VALID_ACCELERATION_MODES = ["none", "regular", "high"]
_debug = DebugSession("image_tools", env_var="IMAGE_TOOLS_DEBUG")
def _validate_parameters(
image_size: Union[str, Dict[str, int]],
num_inference_steps: int,
guidance_scale: float,
num_images: int,
output_format: str,
acceleration: str = "none"
) -> Dict[str, Any]:
"""
Validate and normalize image generation parameters for FLUX 2 Pro model.
Args:
image_size: Either a preset string or custom size dict
num_inference_steps: Number of inference steps
guidance_scale: Guidance scale value
num_images: Number of images to generate
output_format: Output format for images
acceleration: Acceleration mode for generation speed
Returns:
Dict[str, Any]: Validated and normalized parameters
Raises:
ValueError: If any parameter is invalid
"""
validated = {}
# Validate image_size
if isinstance(image_size, str):
if image_size not in VALID_IMAGE_SIZES:
raise ValueError(f"Invalid image_size '{image_size}'. Must be one of: {VALID_IMAGE_SIZES}")
validated["image_size"] = image_size
elif isinstance(image_size, dict):
if "width" not in image_size or "height" not in image_size:
raise ValueError("Custom image_size must contain 'width' and 'height' keys")
if not isinstance(image_size["width"], int) or not isinstance(image_size["height"], int):
raise ValueError("Custom image_size width and height must be integers")
if image_size["width"] < 64 or image_size["height"] < 64:
raise ValueError("Custom image_size dimensions must be at least 64x64")
if image_size["width"] > 2048 or image_size["height"] > 2048:
raise ValueError("Custom image_size dimensions must not exceed 2048x2048")
validated["image_size"] = image_size
else:
raise ValueError("image_size must be either a preset string or a dict with width/height")
# Validate num_inference_steps
if not isinstance(num_inference_steps, int) or num_inference_steps < 1 or num_inference_steps > 100:
raise ValueError("num_inference_steps must be an integer between 1 and 100")
validated["num_inference_steps"] = num_inference_steps
# Validate guidance_scale (FLUX 2 Pro default is 4.5)
if not isinstance(guidance_scale, (int, float)) or guidance_scale < 0.1 or guidance_scale > 20.0:
raise ValueError("guidance_scale must be a number between 0.1 and 20.0")
validated["guidance_scale"] = float(guidance_scale)
# Validate num_images
if not isinstance(num_images, int) or num_images < 1 or num_images > 4:
raise ValueError("num_images must be an integer between 1 and 4")
validated["num_images"] = num_images
# Validate output_format
if output_format not in VALID_OUTPUT_FORMATS:
raise ValueError(f"Invalid output_format '{output_format}'. Must be one of: {VALID_OUTPUT_FORMATS}")
validated["output_format"] = output_format
# Validate acceleration
if acceleration not in VALID_ACCELERATION_MODES:
raise ValueError(f"Invalid acceleration '{acceleration}'. Must be one of: {VALID_ACCELERATION_MODES}")
validated["acceleration"] = acceleration
return validated
def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]:
"""
Upscale an image using FAL.ai's Clarity Upscaler.
Uses the synchronous fal_client API to avoid event loop lifecycle issues
when called from threaded contexts (e.g. gateway thread pool).
Args:
image_url (str): URL of the image to upscale
original_prompt (str): Original prompt used to generate the image
Returns:
Dict[str, Any]: Upscaled image data or None if upscaling fails
"""
try:
logger.info("Upscaling image with Clarity Upscaler...")
# Prepare arguments for upscaler
upscaler_arguments = {
"image_url": image_url,
"prompt": f"{UPSCALER_DEFAULT_PROMPT}, {original_prompt}",
"upscale_factor": UPSCALER_FACTOR,
"negative_prompt": UPSCALER_NEGATIVE_PROMPT,
"creativity": UPSCALER_CREATIVITY,
"resemblance": UPSCALER_RESEMBLANCE,
"guidance_scale": UPSCALER_GUIDANCE_SCALE,
"num_inference_steps": UPSCALER_NUM_INFERENCE_STEPS,
"enable_safety_checker": UPSCALER_SAFETY_CHECKER
}
# Use sync API — fal_client.submit() uses httpx.Client (no event loop).
# The async API (submit_async) caches a global httpx.AsyncClient via
# @cached_property, which breaks when asyncio.run() destroys the loop
# between calls (gateway thread-pool pattern).
handler = fal_client.submit(
UPSCALER_MODEL,
arguments=upscaler_arguments
)
# Get the upscaled result (sync — blocks until done)
result = handler.get()
if result and "image" in result:
upscaled_image = result["image"]
logger.info("Image upscaled successfully to %sx%s", upscaled_image.get('width', 'unknown'), upscaled_image.get('height', 'unknown'))
return {
"url": upscaled_image["url"],
"width": upscaled_image.get("width", 0),
"height": upscaled_image.get("height", 0),
"upscaled": True,
"upscale_factor": UPSCALER_FACTOR
}
else:
logger.error("Upscaler returned invalid response")
return None
except Exception as e:
logger.error("Error upscaling image: %s", e, exc_info=True)
return None
def image_generate_tool(
prompt: str,
aspect_ratio: str = DEFAULT_ASPECT_RATIO,
num_inference_steps: int = DEFAULT_NUM_INFERENCE_STEPS,
guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
num_images: int = DEFAULT_NUM_IMAGES,
output_format: str = DEFAULT_OUTPUT_FORMAT,
seed: Optional[int] = None
) -> str:
"""
Generate images from text prompts using FAL.ai's FLUX 2 Pro model with automatic upscaling.
Uses the synchronous fal_client API to avoid event loop lifecycle issues.
The async API's global httpx.AsyncClient (cached via @cached_property) breaks
when asyncio.run() destroys and recreates event loops between calls, which
happens in the gateway's thread-pool pattern.
Args:
prompt (str): The text prompt describing the desired image
aspect_ratio (str): Image aspect ratio - "landscape", "square", or "portrait" (default: "landscape")
num_inference_steps (int): Number of denoising steps (1-50, default: 50)
guidance_scale (float): How closely to follow prompt (0.1-20.0, default: 4.5)
num_images (int): Number of images to generate (1-4, default: 1)
output_format (str): Image format "jpeg" or "png" (default: "png")
seed (Optional[int]): Random seed for reproducible results (optional)
Returns:
str: JSON string containing minimal generation results:
{
"success": bool,
"image": str or None # URL of the upscaled image, or None if failed
}
"""
# Validate and map aspect_ratio to actual image_size
aspect_ratio_lower = aspect_ratio.lower().strip() if aspect_ratio else DEFAULT_ASPECT_RATIO
if aspect_ratio_lower not in ASPECT_RATIO_MAP:
logger.warning("Invalid aspect_ratio '%s', defaulting to '%s'", aspect_ratio, DEFAULT_ASPECT_RATIO)
aspect_ratio_lower = DEFAULT_ASPECT_RATIO
image_size = ASPECT_RATIO_MAP[aspect_ratio_lower]
debug_call_data = {
"parameters": {
"prompt": prompt,
"aspect_ratio": aspect_ratio,
"image_size": image_size,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
"num_images": num_images,
"output_format": output_format,
"seed": seed
},
"error": None,
"success": False,
"images_generated": 0,
"generation_time": 0
}
start_time = datetime.datetime.now()
try:
logger.info("Generating %s image(s) with FLUX 2 Pro: %s", num_images, prompt[:80])
# Validate prompt
if not prompt or not isinstance(prompt, str) or len(prompt.strip()) == 0:
raise ValueError("Prompt is required and must be a non-empty string")
# Check API key availability
if not os.getenv("FAL_KEY"):
raise ValueError("FAL_KEY environment variable not set")
# Validate other parameters
validated_params = _validate_parameters(
image_size, num_inference_steps, guidance_scale, num_images, output_format, "none"
)
# Prepare arguments for FAL.ai FLUX 2 Pro API
arguments = {
"prompt": prompt.strip(),
"image_size": validated_params["image_size"],
"num_inference_steps": validated_params["num_inference_steps"],
"guidance_scale": validated_params["guidance_scale"],
"num_images": validated_params["num_images"],
"output_format": validated_params["output_format"],
"enable_safety_checker": ENABLE_SAFETY_CHECKER,
"safety_tolerance": SAFETY_TOLERANCE,
"sync_mode": True # Use sync mode for immediate results
}
# Add seed if provided
if seed is not None and isinstance(seed, int):
arguments["seed"] = seed
logger.info("Submitting generation request to FAL.ai FLUX 2 Pro...")
logger.info(" Model: %s", DEFAULT_MODEL)
logger.info(" Aspect Ratio: %s -> %s", aspect_ratio_lower, image_size)
logger.info(" Steps: %s", validated_params['num_inference_steps'])
logger.info(" Guidance: %s", validated_params['guidance_scale'])
# Submit request to FAL.ai using sync API (avoids cached event loop issues)
handler = fal_client.submit(
DEFAULT_MODEL,
arguments=arguments
)
# Get the result (sync — blocks until done)
result = handler.get()
generation_time = (datetime.datetime.now() - start_time).total_seconds()
# Process the response
if not result or "images" not in result:
raise ValueError("Invalid response from FAL.ai API - no images returned")
images = result.get("images", [])
if not images:
raise ValueError("No images were generated")
# Format image data and upscale images
formatted_images = []
for img in images:
if isinstance(img, dict) and "url" in img:
original_image = {
"url": img["url"],
"width": img.get("width", 0),
"height": img.get("height", 0)
}
# Attempt to upscale the image
upscaled_image = _upscale_image(img["url"], prompt.strip())
if upscaled_image:
# Use upscaled image if successful
formatted_images.append(upscaled_image)
else:
# Fall back to original image if upscaling fails
logger.warning("Using original image as fallback")
original_image["upscaled"] = False
formatted_images.append(original_image)
if not formatted_images:
raise ValueError("No valid image URLs returned from API")
upscaled_count = sum(1 for img in formatted_images if img.get("upscaled", False))
logger.info("Generated %s image(s) in %.1fs (%s upscaled)", len(formatted_images), generation_time, upscaled_count)
# Prepare successful response - minimal format
response_data = {
"success": True,
"image": formatted_images[0]["url"] if formatted_images else None
}
debug_call_data["success"] = True
debug_call_data["images_generated"] = len(formatted_images)
debug_call_data["generation_time"] = generation_time
# Log debug information
_debug.log_call("image_generate_tool", debug_call_data)
_debug.save()
return json.dumps(response_data, indent=2, ensure_ascii=False)
except Exception as e:
generation_time = (datetime.datetime.now() - start_time).total_seconds()
error_msg = f"Error generating image: {str(e)}"
logger.error("%s", error_msg, exc_info=True)
# Prepare error response - minimal format
response_data = {
"success": False,
"image": None
}
debug_call_data["error"] = error_msg
debug_call_data["generation_time"] = generation_time
_debug.log_call("image_generate_tool", debug_call_data)
_debug.save()
return json.dumps(response_data, indent=2, ensure_ascii=False)
def check_fal_api_key() -> bool:
"""
Check if the FAL.ai API key is available in environment variables.
Returns:
bool: True if API key is set, False otherwise
"""
return bool(os.getenv("FAL_KEY"))
def check_image_generation_requirements() -> bool:
"""
Check if all requirements for image generation tools are met.
Returns:
bool: True if requirements are met, False otherwise
"""
try:
# Check API key
if not check_fal_api_key():
return False
# Check if fal_client is available
import fal_client
return True
except ImportError:
return False
def get_debug_session_info() -> Dict[str, Any]:
"""
Get information about the current debug session.
Returns:
Dict[str, Any]: Dictionary containing debug session information
"""
return _debug.get_session_info()
if __name__ == "__main__":
"""
Simple test/demo when run directly
"""
print("🎨 Image Generation Tools Module - FLUX 2 Pro + Auto Upscaling")
print("=" * 60)
# Check if API key is available
api_available = check_fal_api_key()
if not api_available:
print("❌ FAL_KEY environment variable not set")
print("Please set your API key: export FAL_KEY='your-key-here'")
print("Get API key at: https://fal.ai/")
exit(1)
else:
print("✅ FAL.ai API key found")
# Check if fal_client is available
try:
import fal_client
print("✅ fal_client library available")
except ImportError:
print("❌ fal_client library not found")
print("Please install: pip install fal-client")
exit(1)
print("🛠️ Image generation tools ready for use!")
print(f"🤖 Using model: {DEFAULT_MODEL}")
print(f"🔍 Auto-upscaling with: {UPSCALER_MODEL} ({UPSCALER_FACTOR}x)")
# Show debug mode status
if _debug.active:
print(f"🐛 Debug mode ENABLED - Session ID: {_debug.session_id}")
print(f" Debug logs will be saved to: ./logs/image_tools_debug_{_debug.session_id}.json")
else:
print("🐛 Debug mode disabled (set IMAGE_TOOLS_DEBUG=true to enable)")
print("\nBasic usage:")
print(" from image_generation_tool import image_generate_tool")
print(" import asyncio")
print("")
print(" async def main():")
print(" # Generate image with automatic 2x upscaling")
print(" result = await image_generate_tool(")
print(" prompt='A serene mountain landscape with cherry blossoms',")
print(" image_size='landscape_4_3',")
print(" num_images=1")
print(" )")
print(" print(result)")
print(" asyncio.run(main())")
print("\nSupported image sizes:")
for size in VALID_IMAGE_SIZES:
print(f" - {size}")
print(" - Custom: {'width': 512, 'height': 768} (if needed)")
print("\nAcceleration modes:")
for mode in VALID_ACCELERATION_MODES:
print(f" - {mode}")
print("\nExample prompts:")
print(" - 'A candid street photo of a woman with a pink bob and bold eyeliner'")
print(" - 'Modern architecture building with glass facade, sunset lighting'")
print(" - 'Abstract art with vibrant colors and geometric patterns'")
print(" - 'Portrait of a wise old owl perched on ancient tree branch'")
print(" - 'Futuristic cityscape with flying cars and neon lights'")
print("\nDebug mode:")
print(" # Enable debug logging")
print(" export IMAGE_TOOLS_DEBUG=true")
print(" # Debug logs capture all image generation calls and results")
print(" # Logs saved to: ./logs/image_tools_debug_UUID.json")
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
from tools.registry import registry
IMAGE_GENERATE_SCHEMA = {
"name": "image_generate",
"description": "Generate high-quality images from text prompts using FLUX 2 Pro model with automatic 2x upscaling. Creates detailed, artistic images that are automatically upscaled for hi-rez results. Returns a single upscaled image URL. Display it using markdown: ![description](URL)",
"parameters": {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "The text prompt describing the desired image. Be detailed and descriptive."
},
"aspect_ratio": {
"type": "string",
"enum": ["landscape", "square", "portrait"],
"description": "The aspect ratio of the generated image. 'landscape' is 16:9 wide, 'portrait' is 16:9 tall, 'square' is 1:1.",
"default": "landscape"
}
},
"required": ["prompt"]
}
}
def _handle_image_generate(args, **kw):
prompt = args.get("prompt", "")
if not prompt:
return json.dumps({"error": "prompt is required for image generation"})
return image_generate_tool(
prompt=prompt,
aspect_ratio=args.get("aspect_ratio", "landscape"),
num_inference_steps=50,
guidance_scale=4.5,
num_images=1,
output_format="png",
seed=None,
)
registry.register(
name="image_generate",
toolset="image_gen",
schema=IMAGE_GENERATE_SCHEMA,
handler=_handle_image_generate,
check_fn=check_image_generation_requirements,
requires_env=["FAL_KEY"],
is_async=False, # Switched to sync fal_client API to fix "Event loop is closed" in gateway
emoji="🎨",
)

View file

@ -0,0 +1,28 @@
"""Shared interrupt signaling for all tools.
Provides a global threading.Event that any tool can check to determine
if the user has requested an interrupt. The agent's interrupt() method
sets this event, and tools poll it during long-running operations.
Usage in tools:
from tools.interrupt import is_interrupted
if is_interrupted():
return {"output": "[interrupted]", "returncode": 130}
"""
import threading
_interrupt_event = threading.Event()
def set_interrupt(active: bool) -> None:
"""Called by the agent to signal or clear the interrupt."""
if active:
_interrupt_event.set()
else:
_interrupt_event.clear()
def is_interrupted() -> bool:
"""Check if an interrupt has been requested. Safe to call from any thread."""
return _interrupt_event.is_set()

View file

@ -0,0 +1,249 @@
"""Thin OAuth adapter for MCP HTTP servers.
Wraps the MCP SDK's built-in ``OAuthClientProvider`` (which implements
``httpx.Auth``) with Hermes-specific token storage and browser-based
authorization. The SDK handles all of the heavy lifting: PKCE generation,
metadata discovery, dynamic client registration, token exchange, and refresh.
Usage in mcp_tool.py::
from tools.mcp_oauth import build_oauth_auth
auth = build_oauth_auth(server_name, server_url)
# pass ``auth`` as the httpx auth parameter
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import socket
import threading
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Any
from urllib.parse import parse_qs, urlparse
logger = logging.getLogger(__name__)
_TOKEN_DIR_NAME = "mcp-tokens"
# ---------------------------------------------------------------------------
# Token storage — persists tokens + client info to ~/.hermes/mcp-tokens/
# ---------------------------------------------------------------------------
def _sanitize_server_name(name: str) -> str:
"""Sanitize server name for safe use as a filename."""
import re
clean = re.sub(r"[^\w\-]", "-", name.strip().lower())
clean = re.sub(r"-+", "-", clean).strip("-")
return clean[:60] or "unnamed"
class HermesTokenStorage:
"""File-backed token storage implementing the MCP SDK's TokenStorage protocol."""
def __init__(self, server_name: str):
self._server_name = _sanitize_server_name(server_name)
def _base_dir(self) -> Path:
home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
d = home / _TOKEN_DIR_NAME
d.mkdir(parents=True, exist_ok=True)
return d
def _tokens_path(self) -> Path:
return self._base_dir() / f"{self._server_name}.json"
def _client_path(self) -> Path:
return self._base_dir() / f"{self._server_name}.client.json"
# -- TokenStorage protocol (async) --
async def get_tokens(self):
data = self._read_json(self._tokens_path())
if not data:
return None
try:
from mcp.shared.auth import OAuthToken
return OAuthToken(**data)
except Exception:
return None
async def set_tokens(self, tokens) -> None:
self._write_json(self._tokens_path(), tokens.model_dump(exclude_none=True))
async def get_client_info(self):
data = self._read_json(self._client_path())
if not data:
return None
try:
from mcp.shared.auth import OAuthClientInformationFull
return OAuthClientInformationFull(**data)
except Exception:
return None
async def set_client_info(self, client_info) -> None:
self._write_json(self._client_path(), client_info.model_dump(exclude_none=True))
# -- helpers --
@staticmethod
def _read_json(path: Path) -> dict | None:
if not path.exists():
return None
try:
return json.loads(path.read_text(encoding="utf-8"))
except Exception:
return None
@staticmethod
def _write_json(path: Path, data: dict) -> None:
path.write_text(json.dumps(data, indent=2), encoding="utf-8")
try:
path.chmod(0o600)
except OSError:
pass
def remove(self) -> None:
"""Delete stored tokens and client info for this server."""
for p in (self._tokens_path(), self._client_path()):
try:
p.unlink(missing_ok=True)
except OSError:
pass
# ---------------------------------------------------------------------------
# Browser-based callback handler
# ---------------------------------------------------------------------------
def _find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def _make_callback_handler():
"""Create a callback handler class with instance-scoped result storage."""
result = {"auth_code": None, "state": None}
class Handler(BaseHTTPRequestHandler):
def do_GET(self):
qs = parse_qs(urlparse(self.path).query)
result["auth_code"] = (qs.get("code") or [None])[0]
result["state"] = (qs.get("state") or [None])[0]
self.send_response(200)
self.send_header("Content-Type", "text/html")
self.end_headers()
self.wfile.write(b"<html><body><h3>Authorization complete. You can close this tab.</h3></body></html>")
def log_message(self, *_args: Any) -> None:
pass
return Handler, result
# Port chosen at build time and shared with the callback handler via closure.
_oauth_port: int | None = None
async def _redirect_to_browser(auth_url: str) -> None:
"""Open the authorization URL in the user's browser."""
try:
if _can_open_browser():
webbrowser.open(auth_url)
print(f" Opened browser for authorization...")
else:
print(f"\n Open this URL to authorize:\n {auth_url}\n")
except Exception:
print(f"\n Open this URL to authorize:\n {auth_url}\n")
async def _wait_for_callback() -> tuple[str, str | None]:
"""Start a local HTTP server on the pre-registered port and wait for the OAuth redirect."""
global _oauth_port
port = _oauth_port or _find_free_port()
HandlerClass, result = _make_callback_handler()
server = HTTPServer(("127.0.0.1", port), HandlerClass)
def _serve():
server.timeout = 120
server.handle_request()
thread = threading.Thread(target=_serve, daemon=True)
thread.start()
for _ in range(1200): # 120 seconds
await asyncio.sleep(0.1)
if result["auth_code"] is not None:
break
server.server_close()
code = result["auth_code"] or ""
state = result["state"]
if not code:
print(" Browser callback timed out. Paste the authorization code manually:")
code = input(" Code: ").strip()
return code, state
def _can_open_browser() -> bool:
if os.environ.get("SSH_CLIENT") or os.environ.get("SSH_TTY"):
return False
if not os.environ.get("DISPLAY") and os.name != "nt" and "darwin" not in os.uname().sysname.lower():
return False
return True
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def build_oauth_auth(server_name: str, server_url: str):
"""Build an ``httpx.Auth`` handler for the given MCP server using OAuth 2.1 PKCE.
Uses the MCP SDK's ``OAuthClientProvider`` which handles discovery,
registration, PKCE, token exchange, and refresh automatically.
Returns an ``OAuthClientProvider`` instance (implements ``httpx.Auth``),
or ``None`` if the MCP SDK auth module is not available.
"""
try:
from mcp.client.auth import OAuthClientProvider
from mcp.shared.auth import OAuthClientMetadata
except ImportError:
logger.warning("MCP SDK auth module not available — OAuth disabled")
return None
global _oauth_port
_oauth_port = _find_free_port()
redirect_uri = f"http://127.0.0.1:{_oauth_port}/callback"
client_metadata = OAuthClientMetadata(
client_name="Hermes Agent",
redirect_uris=[redirect_uri],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
scope="openid profile email offline_access",
token_endpoint_auth_method="none",
)
storage = HermesTokenStorage(server_name)
return OAuthClientProvider(
server_url=server_url,
client_metadata=client_metadata,
storage=storage,
redirect_handler=_redirect_to_browser,
callback_handler=_wait_for_callback,
timeout=120.0,
)
def remove_oauth_tokens(server_name: str) -> None:
"""Delete stored OAuth tokens and client info for a server."""
HermesTokenStorage(server_name).remove()

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,547 @@
#!/usr/bin/env python3
"""
Memory Tool Module - Persistent Curated Memory
Provides bounded, file-backed memory that persists across sessions. Two stores:
- MEMORY.md: agent's personal notes and observations (environment facts, project
conventions, tool quirks, things learned)
- USER.md: what the agent knows about the user (preferences, communication style,
expectations, workflow habits)
Both are injected into the system prompt as a frozen snapshot at session start.
Mid-session writes update files on disk immediately (durable) but do NOT change
the system prompt -- this preserves the prefix cache for the entire session.
The snapshot refreshes on the next session start.
Entry delimiter: § (section sign). Entries can be multiline.
Character limits (not tokens) because char counts are model-independent.
Design:
- Single `memory` tool with action parameter: add, replace, remove, read
- replace/remove use short unique substring matching (not full text or IDs)
- Behavioral guidance lives in the tool schema description
- Frozen snapshot pattern: system prompt is stable, tool responses show live state
"""
import fcntl
import json
import logging
import os
import re
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Any, List, Optional
logger = logging.getLogger(__name__)
# Where memory files live
MEMORY_DIR = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "memories"
ENTRY_DELIMITER = "\n§\n"
# ---------------------------------------------------------------------------
# Memory content scanning — lightweight check for injection/exfiltration
# in content that gets injected into the system prompt.
# ---------------------------------------------------------------------------
_MEMORY_THREAT_PATTERNS = [
# Prompt injection
(r'ignore\s+(previous|all|above|prior)\s+instructions', "prompt_injection"),
(r'you\s+are\s+now\s+', "role_hijack"),
(r'do\s+not\s+tell\s+the\s+user', "deception_hide"),
(r'system\s+prompt\s+override', "sys_prompt_override"),
(r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"),
(r'act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)', "bypass_restrictions"),
# Exfiltration via curl/wget with secrets
(r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"),
(r'wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_wget"),
(r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass|\.npmrc|\.pypirc)', "read_secrets"),
# Persistence via shell rc
(r'authorized_keys', "ssh_backdoor"),
(r'\$HOME/\.ssh|\~/\.ssh', "ssh_access"),
(r'\$HOME/\.hermes/\.env|\~/\.hermes/\.env', "hermes_env"),
]
# Subset of invisible chars for injection detection
_INVISIBLE_CHARS = {
'\u200b', '\u200c', '\u200d', '\u2060', '\ufeff',
'\u202a', '\u202b', '\u202c', '\u202d', '\u202e',
}
def _scan_memory_content(content: str) -> Optional[str]:
"""Scan memory content for injection/exfil patterns. Returns error string if blocked."""
# Check invisible unicode
for char in _INVISIBLE_CHARS:
if char in content:
return f"Blocked: content contains invisible unicode character U+{ord(char):04X} (possible injection)."
# Check threat patterns
for pattern, pid in _MEMORY_THREAT_PATTERNS:
if re.search(pattern, content, re.IGNORECASE):
return f"Blocked: content matches threat pattern '{pid}'. Memory entries are injected into the system prompt and must not contain injection or exfiltration payloads."
return None
class MemoryStore:
"""
Bounded curated memory with file persistence. One instance per AIAgent.
Maintains two parallel states:
- _system_prompt_snapshot: frozen at load time, used for system prompt injection.
Never mutated mid-session. Keeps prefix cache stable.
- memory_entries / user_entries: live state, mutated by tool calls, persisted to disk.
Tool responses always reflect this live state.
"""
def __init__(self, memory_char_limit: int = 2200, user_char_limit: int = 1375):
self.memory_entries: List[str] = []
self.user_entries: List[str] = []
self.memory_char_limit = memory_char_limit
self.user_char_limit = user_char_limit
# Frozen snapshot for system prompt -- set once at load_from_disk()
self._system_prompt_snapshot: Dict[str, str] = {"memory": "", "user": ""}
def load_from_disk(self):
"""Load entries from MEMORY.md and USER.md, capture system prompt snapshot."""
MEMORY_DIR.mkdir(parents=True, exist_ok=True)
self.memory_entries = self._read_file(MEMORY_DIR / "MEMORY.md")
self.user_entries = self._read_file(MEMORY_DIR / "USER.md")
# Deduplicate entries (preserves order, keeps first occurrence)
self.memory_entries = list(dict.fromkeys(self.memory_entries))
self.user_entries = list(dict.fromkeys(self.user_entries))
# Capture frozen snapshot for system prompt injection
self._system_prompt_snapshot = {
"memory": self._render_block("memory", self.memory_entries),
"user": self._render_block("user", self.user_entries),
}
@staticmethod
@contextmanager
def _file_lock(path: Path):
"""Acquire an exclusive file lock for read-modify-write safety.
Uses a separate .lock file so the memory file itself can still be
atomically replaced via os.replace().
"""
lock_path = path.with_suffix(path.suffix + ".lock")
lock_path.parent.mkdir(parents=True, exist_ok=True)
fd = open(lock_path, "w")
try:
fcntl.flock(fd, fcntl.LOCK_EX)
yield
finally:
fcntl.flock(fd, fcntl.LOCK_UN)
fd.close()
@staticmethod
def _path_for(target: str) -> Path:
if target == "user":
return MEMORY_DIR / "USER.md"
return MEMORY_DIR / "MEMORY.md"
def _reload_target(self, target: str):
"""Re-read entries from disk into in-memory state.
Called under file lock to get the latest state before mutating.
"""
fresh = self._read_file(self._path_for(target))
fresh = list(dict.fromkeys(fresh)) # deduplicate
self._set_entries(target, fresh)
def save_to_disk(self, target: str):
"""Persist entries to the appropriate file. Called after every mutation."""
MEMORY_DIR.mkdir(parents=True, exist_ok=True)
self._write_file(self._path_for(target), self._entries_for(target))
def _entries_for(self, target: str) -> List[str]:
if target == "user":
return self.user_entries
return self.memory_entries
def _set_entries(self, target: str, entries: List[str]):
if target == "user":
self.user_entries = entries
else:
self.memory_entries = entries
def _char_count(self, target: str) -> int:
entries = self._entries_for(target)
if not entries:
return 0
return len(ENTRY_DELIMITER.join(entries))
def _char_limit(self, target: str) -> int:
if target == "user":
return self.user_char_limit
return self.memory_char_limit
def add(self, target: str, content: str) -> Dict[str, Any]:
"""Append a new entry. Returns error if it would exceed the char limit."""
content = content.strip()
if not content:
return {"success": False, "error": "Content cannot be empty."}
# Scan for injection/exfiltration before accepting
scan_error = _scan_memory_content(content)
if scan_error:
return {"success": False, "error": scan_error}
with self._file_lock(self._path_for(target)):
# Re-read from disk under lock to pick up writes from other sessions
self._reload_target(target)
entries = self._entries_for(target)
limit = self._char_limit(target)
# Reject exact duplicates
if content in entries:
return self._success_response(target, "Entry already exists (no duplicate added).")
# Calculate what the new total would be
new_entries = entries + [content]
new_total = len(ENTRY_DELIMITER.join(new_entries))
if new_total > limit:
current = self._char_count(target)
return {
"success": False,
"error": (
f"Memory at {current:,}/{limit:,} chars. "
f"Adding this entry ({len(content)} chars) would exceed the limit. "
f"Replace or remove existing entries first."
),
"current_entries": entries,
"usage": f"{current:,}/{limit:,}",
}
entries.append(content)
self._set_entries(target, entries)
self.save_to_disk(target)
return self._success_response(target, "Entry added.")
def replace(self, target: str, old_text: str, new_content: str) -> Dict[str, Any]:
"""Find entry containing old_text substring, replace it with new_content."""
old_text = old_text.strip()
new_content = new_content.strip()
if not old_text:
return {"success": False, "error": "old_text cannot be empty."}
if not new_content:
return {"success": False, "error": "new_content cannot be empty. Use 'remove' to delete entries."}
# Scan replacement content for injection/exfiltration
scan_error = _scan_memory_content(new_content)
if scan_error:
return {"success": False, "error": scan_error}
with self._file_lock(self._path_for(target)):
self._reload_target(target)
entries = self._entries_for(target)
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
if len(matches) == 0:
return {"success": False, "error": f"No entry matched '{old_text}'."}
if len(matches) > 1:
# If all matches are identical (exact duplicates), operate on the first one
unique_texts = set(e for _, e in matches)
if len(unique_texts) > 1:
previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches]
return {
"success": False,
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
"matches": previews,
}
# All identical -- safe to replace just the first
idx = matches[0][0]
limit = self._char_limit(target)
# Check that replacement doesn't blow the budget
test_entries = entries.copy()
test_entries[idx] = new_content
new_total = len(ENTRY_DELIMITER.join(test_entries))
if new_total > limit:
return {
"success": False,
"error": (
f"Replacement would put memory at {new_total:,}/{limit:,} chars. "
f"Shorten the new content or remove other entries first."
),
}
entries[idx] = new_content
self._set_entries(target, entries)
self.save_to_disk(target)
return self._success_response(target, "Entry replaced.")
def remove(self, target: str, old_text: str) -> Dict[str, Any]:
"""Remove the entry containing old_text substring."""
old_text = old_text.strip()
if not old_text:
return {"success": False, "error": "old_text cannot be empty."}
with self._file_lock(self._path_for(target)):
self._reload_target(target)
entries = self._entries_for(target)
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
if len(matches) == 0:
return {"success": False, "error": f"No entry matched '{old_text}'."}
if len(matches) > 1:
# If all matches are identical (exact duplicates), remove the first one
unique_texts = set(e for _, e in matches)
if len(unique_texts) > 1:
previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches]
return {
"success": False,
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
"matches": previews,
}
# All identical -- safe to remove just the first
idx = matches[0][0]
entries.pop(idx)
self._set_entries(target, entries)
self.save_to_disk(target)
return self._success_response(target, "Entry removed.")
def format_for_system_prompt(self, target: str) -> Optional[str]:
"""
Return the frozen snapshot for system prompt injection.
This returns the state captured at load_from_disk() time, NOT the live
state. Mid-session writes do not affect this. This keeps the system
prompt stable across all turns, preserving the prefix cache.
Returns None if the snapshot is empty (no entries at load time).
"""
block = self._system_prompt_snapshot.get(target, "")
return block if block else None
# -- Internal helpers --
def _success_response(self, target: str, message: str = None) -> Dict[str, Any]:
entries = self._entries_for(target)
current = self._char_count(target)
limit = self._char_limit(target)
pct = int((current / limit) * 100) if limit > 0 else 0
resp = {
"success": True,
"target": target,
"entries": entries,
"usage": f"{pct}% — {current:,}/{limit:,} chars",
"entry_count": len(entries),
}
if message:
resp["message"] = message
return resp
def _render_block(self, target: str, entries: List[str]) -> str:
"""Render a system prompt block with header and usage indicator."""
if not entries:
return ""
limit = self._char_limit(target)
content = ENTRY_DELIMITER.join(entries)
current = len(content)
pct = int((current / limit) * 100) if limit > 0 else 0
if target == "user":
header = f"USER PROFILE (who the user is) [{pct}% — {current:,}/{limit:,} chars]"
else:
header = f"MEMORY (your personal notes) [{pct}% — {current:,}/{limit:,} chars]"
separator = "" * 46
return f"{separator}\n{header}\n{separator}\n{content}"
@staticmethod
def _read_file(path: Path) -> List[str]:
"""Read a memory file and split into entries.
No file locking needed: _write_file uses atomic rename, so readers
always see either the previous complete file or the new complete file.
"""
if not path.exists():
return []
try:
raw = path.read_text(encoding="utf-8")
except (OSError, IOError):
return []
if not raw.strip():
return []
# Use ENTRY_DELIMITER for consistency with _write_file. Splitting by "§"
# alone would incorrectly split entries that contain "§" in their content.
entries = [e.strip() for e in raw.split(ENTRY_DELIMITER)]
return [e for e in entries if e]
@staticmethod
def _write_file(path: Path, entries: List[str]):
"""Write entries to a memory file using atomic temp-file + rename.
Previous implementation used open("w") + flock, but "w" truncates the
file *before* the lock is acquired, creating a race window where
concurrent readers see an empty file. Atomic rename avoids this:
readers always see either the old complete file or the new one.
"""
content = ENTRY_DELIMITER.join(entries) if entries else ""
try:
# Write to temp file in same directory (same filesystem for atomic rename)
fd, tmp_path = tempfile.mkstemp(
dir=str(path.parent), suffix=".tmp", prefix=".mem_"
)
try:
with os.fdopen(fd, "w", encoding="utf-8") as f:
f.write(content)
f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, str(path)) # Atomic on same filesystem
except BaseException:
# Clean up temp file on any failure
try:
os.unlink(tmp_path)
except OSError:
pass
raise
except (OSError, IOError) as e:
raise RuntimeError(f"Failed to write memory file {path}: {e}")
def memory_tool(
action: str,
target: str = "memory",
content: str = None,
old_text: str = None,
store: Optional[MemoryStore] = None,
) -> str:
"""
Single entry point for the memory tool. Dispatches to MemoryStore methods.
Returns JSON string with results.
"""
if store is None:
return json.dumps({"success": False, "error": "Memory is not available. It may be disabled in config or this environment."}, ensure_ascii=False)
if target not in ("memory", "user"):
return json.dumps({"success": False, "error": f"Invalid target '{target}'. Use 'memory' or 'user'."}, ensure_ascii=False)
if action == "add":
if not content:
return json.dumps({"success": False, "error": "Content is required for 'add' action."}, ensure_ascii=False)
result = store.add(target, content)
elif action == "replace":
if not old_text:
return json.dumps({"success": False, "error": "old_text is required for 'replace' action."}, ensure_ascii=False)
if not content:
return json.dumps({"success": False, "error": "content is required for 'replace' action."}, ensure_ascii=False)
result = store.replace(target, old_text, content)
elif action == "remove":
if not old_text:
return json.dumps({"success": False, "error": "old_text is required for 'remove' action."}, ensure_ascii=False)
result = store.remove(target, old_text)
else:
return json.dumps({"success": False, "error": f"Unknown action '{action}'. Use: add, replace, remove"}, ensure_ascii=False)
return json.dumps(result, ensure_ascii=False)
def check_memory_requirements() -> bool:
"""Memory tool has no external requirements -- always available."""
return True
# =============================================================================
# OpenAI Function-Calling Schema
# =============================================================================
MEMORY_SCHEMA = {
"name": "memory",
"description": (
"Save durable information to persistent memory that survives across sessions. "
"Memory is injected into future turns, so keep it compact and focused on facts "
"that will still matter later.\n\n"
"WHEN TO SAVE (do this proactively, don't wait to be asked):\n"
"- User corrects you or says 'remember this' / 'don't do that again'\n"
"- User shares a preference, habit, or personal detail (name, role, timezone, coding style)\n"
"- You discover something about the environment (OS, installed tools, project structure)\n"
"- You learn a convention, API quirk, or workflow specific to this user's setup\n"
"- You identify a stable fact that will be useful again in future sessions\n\n"
"PRIORITY: User preferences and corrections > environment facts > procedural knowledge. "
"The most valuable memory prevents the user from having to repeat themselves.\n\n"
"Do NOT save task progress, session outcomes, completed-work logs, or temporary TODO "
"state to memory; use session_search to recall those from past transcripts.\n"
"If you've discovered a new way to do something, solved a problem that could be "
"necessary later, save it as a skill with the skill tool.\n\n"
"TWO TARGETS:\n"
"- 'user': who the user is -- name, role, preferences, communication style, pet peeves\n"
"- 'memory': your notes -- environment facts, project conventions, tool quirks, lessons learned\n\n"
"ACTIONS: add (new entry), replace (update existing -- old_text identifies it), "
"remove (delete -- old_text identifies it).\n\n"
"SKIP: trivial/obvious info, things easily re-discovered, raw data dumps, and temporary task state."
),
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["add", "replace", "remove"],
"description": "The action to perform."
},
"target": {
"type": "string",
"enum": ["memory", "user"],
"description": "Which memory store: 'memory' for personal notes, 'user' for user profile."
},
"content": {
"type": "string",
"description": "The entry content. Required for 'add' and 'replace'."
},
"old_text": {
"type": "string",
"description": "Short unique substring identifying the entry to replace or remove."
},
},
"required": ["action", "target"],
},
}
# --- Registry ---
from tools.registry import registry
registry.register(
name="memory",
toolset="memory",
schema=MEMORY_SCHEMA,
handler=lambda args, **kw: memory_tool(
action=args.get("action", ""),
target=args.get("target", "memory"),
content=args.get("content"),
old_text=args.get("old_text"),
store=kw.get("store")),
check_fn=check_memory_requirements,
emoji="🧠",
)

View file

@ -0,0 +1,548 @@
#!/usr/bin/env python3
"""
Mixture-of-Agents Tool Module
This module implements the Mixture-of-Agents (MoA) methodology that leverages
the collective strengths of multiple LLMs through a layered architecture to
achieve state-of-the-art performance on complex reasoning tasks.
Based on the research paper: "Mixture-of-Agents Enhances Large Language Model Capabilities"
by Junlin Wang et al. (arXiv:2406.04692v1)
Key Features:
- Multi-layer LLM collaboration for enhanced reasoning
- Parallel processing of reference models for efficiency
- Intelligent aggregation and synthesis of diverse responses
- Specialized for extremely difficult problems requiring intense reasoning
- Optimized for coding, mathematics, and complex analytical tasks
Available Tool:
- mixture_of_agents_tool: Process complex queries using multiple frontier models
Architecture:
1. Reference models generate diverse initial responses in parallel
2. Aggregator model synthesizes responses into a high-quality output
3. Multiple layers can be used for iterative refinement (future enhancement)
Models Used (via OpenRouter):
- Reference Models: claude-opus-4.6, gemini-3-pro-preview, gpt-5.4-pro, deepseek-v3.2
- Aggregator Model: claude-opus-4.6 (highest capability for synthesis)
Configuration:
To customize the MoA setup, modify the configuration constants at the top of this file:
- REFERENCE_MODELS: List of models for generating diverse initial responses
- AGGREGATOR_MODEL: Model used to synthesize the final response
- REFERENCE_TEMPERATURE/AGGREGATOR_TEMPERATURE: Sampling temperatures
- MIN_SUCCESSFUL_REFERENCES: Minimum successful models needed to proceed
Usage:
from mixture_of_agents_tool import mixture_of_agents_tool
import asyncio
# Process a complex query
result = await mixture_of_agents_tool(
user_prompt="Solve this complex mathematical proof..."
)
"""
import json
import logging
import os
import asyncio
import datetime
from typing import Dict, Any, List, Optional
from tools.openrouter_client import get_async_client as _get_openrouter_client, check_api_key as check_openrouter_api_key
from tools.debug_helpers import DebugSession
logger = logging.getLogger(__name__)
# Configuration for MoA processing
# Reference models - these generate diverse initial responses in parallel.
# Keep this list aligned with current top-tier OpenRouter frontier options.
REFERENCE_MODELS = [
"anthropic/claude-opus-4.6",
"google/gemini-3-pro-preview",
"openai/gpt-5.4-pro",
"deepseek/deepseek-v3.2",
]
# Aggregator model - synthesizes reference responses into final output.
# Prefer the strongest synthesis model in the current OpenRouter lineup.
AGGREGATOR_MODEL = "anthropic/claude-opus-4.6"
# Temperature settings optimized for MoA performance
REFERENCE_TEMPERATURE = 0.6 # Balanced creativity for diverse perspectives
AGGREGATOR_TEMPERATURE = 0.4 # Focused synthesis for consistency
# Failure handling configuration
MIN_SUCCESSFUL_REFERENCES = 1 # Minimum successful reference models needed to proceed
# System prompt for the aggregator model (from the research paper)
AGGREGATOR_SYSTEM_PROMPT = """You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
Responses from models:"""
_debug = DebugSession("moa_tools", env_var="MOA_TOOLS_DEBUG")
def _construct_aggregator_prompt(system_prompt: str, responses: List[str]) -> str:
"""
Construct the final system prompt for the aggregator including all model responses.
Args:
system_prompt (str): Base system prompt for aggregation
responses (List[str]): List of responses from reference models
Returns:
str: Complete system prompt with enumerated responses
"""
response_text = "\n".join([f"{i+1}. {response}" for i, response in enumerate(responses)])
return f"{system_prompt}\n\n{response_text}"
async def _run_reference_model_safe(
model: str,
user_prompt: str,
temperature: float = REFERENCE_TEMPERATURE,
max_tokens: int = 32000,
max_retries: int = 6
) -> tuple[str, str, bool]:
"""
Run a single reference model with retry logic and graceful failure handling.
Args:
model (str): Model identifier to use
user_prompt (str): The user's query
temperature (float): Sampling temperature for response generation
max_tokens (int): Maximum tokens in response
max_retries (int): Maximum number of retry attempts
Returns:
tuple[str, str, bool]: (model_name, response_content_or_error, success_flag)
"""
for attempt in range(max_retries):
try:
logger.info("Querying %s (attempt %s/%s)", model, attempt + 1, max_retries)
# Build parameters for the API call
api_params = {
"model": model,
"messages": [{"role": "user", "content": user_prompt}],
"extra_body": {
"reasoning": {
"enabled": True,
"effort": "xhigh"
}
}
}
# GPT models (especially gpt-4o-mini) don't support custom temperature values
# Only include temperature for non-GPT models
if not model.lower().startswith('gpt-'):
api_params["temperature"] = temperature
response = await _get_openrouter_client().chat.completions.create(**api_params)
content = response.choices[0].message.content.strip()
logger.info("%s responded (%s characters)", model, len(content))
return model, content, True
except Exception as e:
error_str = str(e)
# Keep retry-path logging concise; full tracebacks are reserved for
# terminal failure paths so long-running MoA retries don't flood logs.
if "invalid" in error_str.lower():
logger.warning("%s invalid request error (attempt %s): %s", model, attempt + 1, error_str)
elif "rate" in error_str.lower() or "limit" in error_str.lower():
logger.warning("%s rate limit error (attempt %s): %s", model, attempt + 1, error_str)
else:
logger.warning("%s unknown error (attempt %s): %s", model, attempt + 1, error_str)
if attempt < max_retries - 1:
# Exponential backoff for rate limiting: 2s, 4s, 8s, 16s, 32s, 60s
sleep_time = min(2 ** (attempt + 1), 60)
logger.info("Retrying in %ss...", sleep_time)
await asyncio.sleep(sleep_time)
else:
error_msg = f"{model} failed after {max_retries} attempts: {error_str}"
logger.error("%s", error_msg, exc_info=True)
return model, error_msg, False
async def _run_aggregator_model(
system_prompt: str,
user_prompt: str,
temperature: float = AGGREGATOR_TEMPERATURE,
max_tokens: int = None
) -> str:
"""
Run the aggregator model to synthesize the final response.
Args:
system_prompt (str): System prompt with all reference responses
user_prompt (str): Original user query
temperature (float): Focused temperature for consistent aggregation
max_tokens (int): Maximum tokens in final response
Returns:
str: Synthesized final response
"""
logger.info("Running aggregator model: %s", AGGREGATOR_MODEL)
# Build parameters for the API call
api_params = {
"model": AGGREGATOR_MODEL,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
"extra_body": {
"reasoning": {
"enabled": True,
"effort": "xhigh"
}
}
}
# GPT models (especially gpt-4o-mini) don't support custom temperature values
# Only include temperature for non-GPT models
if not AGGREGATOR_MODEL.lower().startswith('gpt-'):
api_params["temperature"] = temperature
response = await _get_openrouter_client().chat.completions.create(**api_params)
content = response.choices[0].message.content.strip()
logger.info("Aggregation complete (%s characters)", len(content))
return content
async def mixture_of_agents_tool(
user_prompt: str,
reference_models: Optional[List[str]] = None,
aggregator_model: Optional[str] = None
) -> str:
"""
Process a complex query using the Mixture-of-Agents methodology.
This tool leverages multiple frontier language models to collaboratively solve
extremely difficult problems requiring intense reasoning. It's particularly
effective for:
- Complex mathematical proofs and calculations
- Advanced coding problems and algorithm design
- Multi-step analytical reasoning tasks
- Problems requiring diverse domain expertise
- Tasks where single models show limitations
The MoA approach uses a fixed 2-layer architecture:
1. Layer 1: Multiple reference models generate diverse responses in parallel (temp=0.6)
2. Layer 2: Aggregator model synthesizes the best elements into final response (temp=0.4)
Args:
user_prompt (str): The complex query or problem to solve
reference_models (Optional[List[str]]): Custom reference models to use
aggregator_model (Optional[str]): Custom aggregator model to use
Returns:
str: JSON string containing the MoA results with the following structure:
{
"success": bool,
"response": str,
"models_used": {
"reference_models": List[str],
"aggregator_model": str
},
"processing_time": float
}
Raises:
Exception: If MoA processing fails or API key is not set
"""
start_time = datetime.datetime.now()
debug_call_data = {
"parameters": {
"user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt,
"reference_models": reference_models or REFERENCE_MODELS,
"aggregator_model": aggregator_model or AGGREGATOR_MODEL,
"reference_temperature": REFERENCE_TEMPERATURE,
"aggregator_temperature": AGGREGATOR_TEMPERATURE,
"min_successful_references": MIN_SUCCESSFUL_REFERENCES
},
"error": None,
"success": False,
"reference_responses_count": 0,
"failed_models_count": 0,
"failed_models": [],
"final_response_length": 0,
"processing_time_seconds": 0,
"models_used": {}
}
try:
logger.info("Starting Mixture-of-Agents processing...")
logger.info("Query: %s", user_prompt[:100])
# Validate API key availability
if not os.getenv("OPENROUTER_API_KEY"):
raise ValueError("OPENROUTER_API_KEY environment variable not set")
# Use provided models or defaults
ref_models = reference_models or REFERENCE_MODELS
agg_model = aggregator_model or AGGREGATOR_MODEL
logger.info("Using %s reference models in 2-layer MoA architecture", len(ref_models))
# Layer 1: Generate diverse responses from reference models (with failure handling)
logger.info("Layer 1: Generating reference responses...")
model_results = await asyncio.gather(*[
_run_reference_model_safe(model, user_prompt, REFERENCE_TEMPERATURE)
for model in ref_models
])
# Separate successful and failed responses
successful_responses = []
failed_models = []
for model_name, content, success in model_results:
if success:
successful_responses.append(content)
else:
failed_models.append(model_name)
successful_count = len(successful_responses)
failed_count = len(failed_models)
logger.info("Reference model results: %s successful, %s failed", successful_count, failed_count)
if failed_models:
logger.warning("Failed models: %s", ', '.join(failed_models))
# Check if we have enough successful responses to proceed
if successful_count < MIN_SUCCESSFUL_REFERENCES:
raise ValueError(f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses.")
debug_call_data["reference_responses_count"] = successful_count
debug_call_data["failed_models_count"] = failed_count
debug_call_data["failed_models"] = failed_models
# Layer 2: Aggregate responses using the aggregator model
logger.info("Layer 2: Synthesizing final response...")
aggregator_system_prompt = _construct_aggregator_prompt(
AGGREGATOR_SYSTEM_PROMPT,
successful_responses
)
final_response = await _run_aggregator_model(
aggregator_system_prompt,
user_prompt,
AGGREGATOR_TEMPERATURE
)
# Calculate processing time
end_time = datetime.datetime.now()
processing_time = (end_time - start_time).total_seconds()
logger.info("MoA processing completed in %.2f seconds", processing_time)
# Prepare successful response (only final aggregated result, minimal fields)
result = {
"success": True,
"response": final_response,
"models_used": {
"reference_models": ref_models,
"aggregator_model": agg_model
}
}
debug_call_data["success"] = True
debug_call_data["final_response_length"] = len(final_response)
debug_call_data["processing_time_seconds"] = processing_time
debug_call_data["models_used"] = result["models_used"]
# Log debug information
_debug.log_call("mixture_of_agents_tool", debug_call_data)
_debug.save()
return json.dumps(result, indent=2, ensure_ascii=False)
except Exception as e:
error_msg = f"Error in MoA processing: {str(e)}"
logger.error("%s", error_msg, exc_info=True)
# Calculate processing time even for errors
end_time = datetime.datetime.now()
processing_time = (end_time - start_time).total_seconds()
# Prepare error response (minimal fields)
result = {
"success": False,
"response": "MoA processing failed. Please try again or use a single model for this query.",
"models_used": {
"reference_models": reference_models or REFERENCE_MODELS,
"aggregator_model": aggregator_model or AGGREGATOR_MODEL
},
"error": error_msg
}
debug_call_data["error"] = error_msg
debug_call_data["processing_time_seconds"] = processing_time
_debug.log_call("mixture_of_agents_tool", debug_call_data)
_debug.save()
return json.dumps(result, indent=2, ensure_ascii=False)
def check_moa_requirements() -> bool:
"""
Check if all requirements for MoA tools are met.
Returns:
bool: True if requirements are met, False otherwise
"""
return check_openrouter_api_key()
def get_debug_session_info() -> Dict[str, Any]:
"""
Get information about the current debug session.
Returns:
Dict[str, Any]: Dictionary containing debug session information
"""
return _debug.get_session_info()
def get_available_models() -> Dict[str, List[str]]:
"""
Get information about available models for MoA processing.
Returns:
Dict[str, List[str]]: Dictionary with reference and aggregator models
"""
return {
"reference_models": REFERENCE_MODELS,
"aggregator_models": [AGGREGATOR_MODEL],
"supported_models": REFERENCE_MODELS + [AGGREGATOR_MODEL]
}
def get_moa_configuration() -> Dict[str, Any]:
"""
Get the current MoA configuration settings.
Returns:
Dict[str, Any]: Dictionary containing all configuration parameters
"""
return {
"reference_models": REFERENCE_MODELS,
"aggregator_model": AGGREGATOR_MODEL,
"reference_temperature": REFERENCE_TEMPERATURE,
"aggregator_temperature": AGGREGATOR_TEMPERATURE,
"min_successful_references": MIN_SUCCESSFUL_REFERENCES,
"total_reference_models": len(REFERENCE_MODELS),
"failure_tolerance": f"{len(REFERENCE_MODELS) - MIN_SUCCESSFUL_REFERENCES}/{len(REFERENCE_MODELS)} models can fail"
}
if __name__ == "__main__":
"""
Simple test/demo when run directly
"""
print("🤖 Mixture-of-Agents Tool Module")
print("=" * 50)
# Check if API key is available
api_available = check_openrouter_api_key()
if not api_available:
print("❌ OPENROUTER_API_KEY environment variable not set")
print("Please set your API key: export OPENROUTER_API_KEY='your-key-here'")
print("Get API key at: https://openrouter.ai/")
exit(1)
else:
print("✅ OpenRouter API key found")
print("🛠️ MoA tools ready for use!")
# Show current configuration
config = get_moa_configuration()
print(f"\n⚙️ Current Configuration:")
print(f" 🤖 Reference models ({len(config['reference_models'])}): {', '.join(config['reference_models'])}")
print(f" 🧠 Aggregator model: {config['aggregator_model']}")
print(f" 🌡️ Reference temperature: {config['reference_temperature']}")
print(f" 🌡️ Aggregator temperature: {config['aggregator_temperature']}")
print(f" 🛡️ Failure tolerance: {config['failure_tolerance']}")
print(f" 📊 Minimum successful models: {config['min_successful_references']}")
# Show debug mode status
if _debug.active:
print(f"\n🐛 Debug mode ENABLED - Session ID: {_debug.session_id}")
print(f" Debug logs will be saved to: ./logs/moa_tools_debug_{_debug.session_id}.json")
else:
print("\n🐛 Debug mode disabled (set MOA_TOOLS_DEBUG=true to enable)")
print("\nBasic usage:")
print(" from mixture_of_agents_tool import mixture_of_agents_tool")
print(" import asyncio")
print("")
print(" async def main():")
print(" result = await mixture_of_agents_tool(")
print(" user_prompt='Solve this complex mathematical proof...'")
print(" )")
print(" print(result)")
print(" asyncio.run(main())")
print("\nBest use cases:")
print(" - Complex mathematical proofs and calculations")
print(" - Advanced coding problems and algorithm design")
print(" - Multi-step analytical reasoning tasks")
print(" - Problems requiring diverse domain expertise")
print(" - Tasks where single models show limitations")
print("\nPerformance characteristics:")
print(" - Higher latency due to multiple model calls")
print(" - Significantly improved quality for complex tasks")
print(" - Parallel processing for efficiency")
print(f" - Optimized temperatures: {REFERENCE_TEMPERATURE} for reference models, {AGGREGATOR_TEMPERATURE} for aggregation")
print(" - Token-efficient: only returns final aggregated response")
print(" - Resilient: continues with partial model failures")
print(f" - Configurable: easy to modify models and settings at top of file")
print(" - State-of-the-art results on challenging benchmarks")
print("\nDebug mode:")
print(" # Enable debug logging")
print(" export MOA_TOOLS_DEBUG=true")
print(" # Debug logs capture all MoA processing steps and metrics")
print(" # Logs saved to: ./logs/moa_tools_debug_UUID.json")
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
from tools.registry import registry
MOA_SCHEMA = {
"name": "mixture_of_agents",
"description": "Route a hard problem through multiple frontier LLMs collaboratively. Makes 5 API calls (4 reference models + 1 aggregator) with maximum reasoning effort — use sparingly for genuinely difficult problems. Best for: complex math, advanced algorithms, multi-step analytical reasoning, problems benefiting from diverse perspectives.",
"parameters": {
"type": "object",
"properties": {
"user_prompt": {
"type": "string",
"description": "The complex query or problem to solve using multiple AI models. Should be a challenging problem that benefits from diverse perspectives and collaborative reasoning."
}
},
"required": ["user_prompt"]
}
}
registry.register(
name="mixture_of_agents",
toolset="moa",
schema=MOA_SCHEMA,
handler=lambda args, **kw: mixture_of_agents_tool(user_prompt=args.get("user_prompt", "")),
check_fn=check_moa_requirements,
requires_env=["OPENROUTER_API_KEY"],
is_async=True,
emoji="🧠",
)

View file

@ -0,0 +1 @@
So I just tried Neuphonic and Im genuinely impressed. It's super responsive, it sounds clean, supports voice cloning, and the agent feature is fun to play with too. Highly recommend it for podcasts, conversations, or even just messing around with voiceovers.

Binary file not shown.

View file

@ -0,0 +1,104 @@
#!/usr/bin/env python3
"""Standalone NeuTTS synthesis helper.
Called by tts_tool.py via subprocess to keep the TTS model (~500MB)
in a separate process that exits after synthesis no lingering memory.
Usage:
python -m tools.neutts_synth --text "Hello" --out output.wav \
--ref-audio samples/jo.wav --ref-text samples/jo.txt
Requires: python -m pip install -U neutts[all]
System: apt install espeak-ng (or brew install espeak-ng)
"""
import argparse
import struct
import sys
from pathlib import Path
def _write_wav(path: str, samples, sample_rate: int = 24000) -> None:
"""Write a WAV file from float32 samples (no soundfile dependency)."""
import numpy as np
if not isinstance(samples, np.ndarray):
samples = np.array(samples, dtype=np.float32)
samples = samples.flatten()
# Clamp and convert to int16
samples = np.clip(samples, -1.0, 1.0)
pcm = (samples * 32767).astype(np.int16)
num_channels = 1
bits_per_sample = 16
byte_rate = sample_rate * num_channels * (bits_per_sample // 8)
block_align = num_channels * (bits_per_sample // 8)
data_size = len(pcm) * (bits_per_sample // 8)
with open(path, "wb") as f:
f.write(b"RIFF")
f.write(struct.pack("<I", 36 + data_size))
f.write(b"WAVE")
f.write(b"fmt ")
f.write(struct.pack("<IHHIIHH", 16, 1, num_channels, sample_rate,
byte_rate, block_align, bits_per_sample))
f.write(b"data")
f.write(struct.pack("<I", data_size))
f.write(pcm.tobytes())
def main():
parser = argparse.ArgumentParser(description="NeuTTS synthesis helper")
parser.add_argument("--text", required=True, help="Text to synthesize")
parser.add_argument("--out", required=True, help="Output WAV path")
parser.add_argument("--ref-audio", required=True, help="Reference voice audio path")
parser.add_argument("--ref-text", required=True, help="Reference voice transcript path")
parser.add_argument("--model", default="neuphonic/neutts-air-q4-gguf",
help="HuggingFace backbone model repo")
parser.add_argument("--device", default="cpu", help="Device (cpu/cuda/mps)")
args = parser.parse_args()
# Validate inputs
ref_audio = Path(args.ref_audio).expanduser()
ref_text_path = Path(args.ref_text).expanduser()
if not ref_audio.exists():
print(f"Error: reference audio not found: {ref_audio}", file=sys.stderr)
sys.exit(1)
if not ref_text_path.exists():
print(f"Error: reference text not found: {ref_text_path}", file=sys.stderr)
sys.exit(1)
ref_text = ref_text_path.read_text(encoding="utf-8").strip()
# Import and run NeuTTS
try:
from neutts import NeuTTS
except ImportError:
print("Error: neutts not installed. Run: python -m pip install -U neutts[all]", file=sys.stderr)
sys.exit(1)
tts = NeuTTS(
backbone_repo=args.model,
backbone_device=args.device,
codec_repo="neuphonic/neucodec",
codec_device=args.device,
)
ref_codes = tts.encode_reference(str(ref_audio))
wav = tts.infer(args.text, ref_codes, ref_text)
# Write output
out_path = Path(args.out)
out_path.parent.mkdir(parents=True, exist_ok=True)
try:
import soundfile as sf
sf.write(str(out_path), wav, 24000)
except ImportError:
_write_wav(str(out_path), wav, 24000)
print(f"OK: {out_path}", file=sys.stderr)
if __name__ == "__main__":
main()

View file

@ -0,0 +1,33 @@
"""Shared OpenRouter API client for Hermes tools.
Provides a single lazy-initialized AsyncOpenAI client that all tool modules
can share. Routes through the centralized provider router in
agent/auxiliary_client.py so auth, headers, and API format are handled
consistently.
"""
import os
_client = None
def get_async_client():
"""Return a shared async OpenAI-compatible client for OpenRouter.
The client is created lazily on first call and reused thereafter.
Uses the centralized provider router for auth and client construction.
Raises ValueError if OPENROUTER_API_KEY is not set.
"""
global _client
if _client is None:
from agent.auxiliary_client import resolve_provider_client
client, _model = resolve_provider_client("openrouter", async_mode=True)
if client is None:
raise ValueError("OPENROUTER_API_KEY environment variable not set")
_client = client
return _client
def check_api_key() -> bool:
"""Check whether the OpenRouter API key is present."""
return bool(os.getenv("OPENROUTER_API_KEY"))

View file

@ -0,0 +1,438 @@
#!/usr/bin/env python3
"""
V4A Patch Format Parser
Parses the V4A patch format used by codex, cline, and other coding agents.
V4A Format:
*** Begin Patch
*** Update File: path/to/file.py
@@ optional context hint @@
context line (space prefix)
-removed line (minus prefix)
+added line (plus prefix)
*** Add File: path/to/new.py
+new file content
+line 2
*** Delete File: path/to/old.py
*** Move File: old/path.py -> new/path.py
*** End Patch
Usage:
from tools.patch_parser import parse_v4a_patch, apply_v4a_operations
operations, error = parse_v4a_patch(patch_content)
if error:
print(f"Parse error: {error}")
else:
result = apply_v4a_operations(operations, file_ops)
"""
import re
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Any
from enum import Enum
class OperationType(Enum):
ADD = "add"
UPDATE = "update"
DELETE = "delete"
MOVE = "move"
@dataclass
class HunkLine:
"""A single line in a patch hunk."""
prefix: str # ' ', '-', or '+'
content: str
@dataclass
class Hunk:
"""A group of changes within a file."""
context_hint: Optional[str] = None
lines: List[HunkLine] = field(default_factory=list)
@dataclass
class PatchOperation:
"""A single operation in a V4A patch."""
operation: OperationType
file_path: str
new_path: Optional[str] = None # For move operations
hunks: List[Hunk] = field(default_factory=list)
content: Optional[str] = None # For add file operations
def parse_v4a_patch(patch_content: str) -> Tuple[List[PatchOperation], Optional[str]]:
"""
Parse a V4A format patch.
Args:
patch_content: The patch text in V4A format
Returns:
Tuple of (operations, error_message)
- If successful: (list_of_operations, None)
- If failed: ([], error_description)
"""
lines = patch_content.split('\n')
operations: List[PatchOperation] = []
# Find patch boundaries
start_idx = None
end_idx = None
for i, line in enumerate(lines):
if '*** Begin Patch' in line or '***Begin Patch' in line:
start_idx = i
elif '*** End Patch' in line or '***End Patch' in line:
end_idx = i
break
if start_idx is None:
# Try to parse without explicit begin marker
start_idx = -1
if end_idx is None:
end_idx = len(lines)
# Parse operations between boundaries
i = start_idx + 1
current_op: Optional[PatchOperation] = None
current_hunk: Optional[Hunk] = None
while i < end_idx:
line = lines[i]
# Check for file operation markers
update_match = re.match(r'\*\*\*\s*Update\s+File:\s*(.+)', line)
add_match = re.match(r'\*\*\*\s*Add\s+File:\s*(.+)', line)
delete_match = re.match(r'\*\*\*\s*Delete\s+File:\s*(.+)', line)
move_match = re.match(r'\*\*\*\s*Move\s+File:\s*(.+?)\s*->\s*(.+)', line)
if update_match:
# Save previous operation
if current_op:
if current_hunk and current_hunk.lines:
current_op.hunks.append(current_hunk)
operations.append(current_op)
current_op = PatchOperation(
operation=OperationType.UPDATE,
file_path=update_match.group(1).strip()
)
current_hunk = None
elif add_match:
if current_op:
if current_hunk and current_hunk.lines:
current_op.hunks.append(current_hunk)
operations.append(current_op)
current_op = PatchOperation(
operation=OperationType.ADD,
file_path=add_match.group(1).strip()
)
current_hunk = Hunk()
elif delete_match:
if current_op:
if current_hunk and current_hunk.lines:
current_op.hunks.append(current_hunk)
operations.append(current_op)
current_op = PatchOperation(
operation=OperationType.DELETE,
file_path=delete_match.group(1).strip()
)
operations.append(current_op)
current_op = None
current_hunk = None
elif move_match:
if current_op:
if current_hunk and current_hunk.lines:
current_op.hunks.append(current_hunk)
operations.append(current_op)
current_op = PatchOperation(
operation=OperationType.MOVE,
file_path=move_match.group(1).strip(),
new_path=move_match.group(2).strip()
)
operations.append(current_op)
current_op = None
current_hunk = None
elif line.startswith('@@'):
# Context hint / hunk marker
if current_op:
if current_hunk and current_hunk.lines:
current_op.hunks.append(current_hunk)
# Extract context hint
hint_match = re.match(r'@@\s*(.+?)\s*@@', line)
hint = hint_match.group(1) if hint_match else None
current_hunk = Hunk(context_hint=hint)
elif current_op and line:
# Parse hunk line
if current_hunk is None:
current_hunk = Hunk()
if line.startswith('+'):
current_hunk.lines.append(HunkLine('+', line[1:]))
elif line.startswith('-'):
current_hunk.lines.append(HunkLine('-', line[1:]))
elif line.startswith(' '):
current_hunk.lines.append(HunkLine(' ', line[1:]))
elif line.startswith('\\'):
# "\ No newline at end of file" marker - skip
pass
else:
# Treat as context line (implicit space prefix)
current_hunk.lines.append(HunkLine(' ', line))
i += 1
# Don't forget the last operation
if current_op:
if current_hunk and current_hunk.lines:
current_op.hunks.append(current_hunk)
operations.append(current_op)
return operations, None
def apply_v4a_operations(operations: List[PatchOperation],
file_ops: Any) -> 'PatchResult':
"""
Apply V4A patch operations using a file operations interface.
Args:
operations: List of PatchOperation from parse_v4a_patch
file_ops: Object with read_file, write_file methods
Returns:
PatchResult with results of all operations
"""
# Import here to avoid circular imports
from tools.file_operations import PatchResult
files_modified = []
files_created = []
files_deleted = []
all_diffs = []
errors = []
for op in operations:
try:
if op.operation == OperationType.ADD:
result = _apply_add(op, file_ops)
if result[0]:
files_created.append(op.file_path)
all_diffs.append(result[1])
else:
errors.append(f"Failed to add {op.file_path}: {result[1]}")
elif op.operation == OperationType.DELETE:
result = _apply_delete(op, file_ops)
if result[0]:
files_deleted.append(op.file_path)
all_diffs.append(result[1])
else:
errors.append(f"Failed to delete {op.file_path}: {result[1]}")
elif op.operation == OperationType.MOVE:
result = _apply_move(op, file_ops)
if result[0]:
files_modified.append(f"{op.file_path} -> {op.new_path}")
all_diffs.append(result[1])
else:
errors.append(f"Failed to move {op.file_path}: {result[1]}")
elif op.operation == OperationType.UPDATE:
result = _apply_update(op, file_ops)
if result[0]:
files_modified.append(op.file_path)
all_diffs.append(result[1])
else:
errors.append(f"Failed to update {op.file_path}: {result[1]}")
except Exception as e:
errors.append(f"Error processing {op.file_path}: {str(e)}")
# Run lint on all modified/created files
lint_results = {}
for f in files_modified + files_created:
if hasattr(file_ops, '_check_lint'):
lint_result = file_ops._check_lint(f)
lint_results[f] = lint_result.to_dict()
combined_diff = '\n'.join(all_diffs)
if errors:
return PatchResult(
success=False,
diff=combined_diff,
files_modified=files_modified,
files_created=files_created,
files_deleted=files_deleted,
lint=lint_results if lint_results else None,
error='; '.join(errors)
)
return PatchResult(
success=True,
diff=combined_diff,
files_modified=files_modified,
files_created=files_created,
files_deleted=files_deleted,
lint=lint_results if lint_results else None
)
def _apply_add(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
"""Apply an add file operation."""
# Extract content from hunks (all + lines)
content_lines = []
for hunk in op.hunks:
for line in hunk.lines:
if line.prefix == '+':
content_lines.append(line.content)
content = '\n'.join(content_lines)
result = file_ops.write_file(op.file_path, content)
if result.error:
return False, result.error
diff = f"--- /dev/null\n+++ b/{op.file_path}\n"
diff += '\n'.join(f"+{line}" for line in content_lines)
return True, diff
def _apply_delete(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
"""Apply a delete file operation."""
# Read file first for diff
read_result = file_ops.read_file(op.file_path)
if read_result.error and "not found" in read_result.error.lower():
# File doesn't exist, nothing to delete
return True, f"# {op.file_path} already deleted or doesn't exist"
# Delete directly via shell command using the underlying environment
rm_result = file_ops._exec(f"rm -f {file_ops._escape_shell_arg(op.file_path)}")
if rm_result.exit_code != 0:
return False, rm_result.stdout
diff = f"--- a/{op.file_path}\n+++ /dev/null\n# File deleted"
return True, diff
def _apply_move(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
"""Apply a move file operation."""
# Use shell mv command
mv_result = file_ops._exec(
f"mv {file_ops._escape_shell_arg(op.file_path)} {file_ops._escape_shell_arg(op.new_path)}"
)
if mv_result.exit_code != 0:
return False, mv_result.stdout
diff = f"# Moved: {op.file_path} -> {op.new_path}"
return True, diff
def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
"""Apply an update file operation."""
# Read current content
read_result = file_ops.read_file(op.file_path, limit=10000)
if read_result.error:
return False, f"Cannot read file: {read_result.error}"
# Parse content (remove line numbers)
current_lines = []
for line in read_result.content.split('\n'):
if re.match(r'^\s*\d+\|', line):
# Line format: " 123|content"
parts = line.split('|', 1)
if len(parts) == 2:
current_lines.append(parts[1])
else:
current_lines.append(line)
else:
current_lines.append(line)
current_content = '\n'.join(current_lines)
# Apply each hunk
new_content = current_content
for hunk in op.hunks:
# Build search pattern from context and removed lines
search_lines = []
replace_lines = []
for line in hunk.lines:
if line.prefix == ' ':
search_lines.append(line.content)
replace_lines.append(line.content)
elif line.prefix == '-':
search_lines.append(line.content)
elif line.prefix == '+':
replace_lines.append(line.content)
if search_lines:
search_pattern = '\n'.join(search_lines)
replacement = '\n'.join(replace_lines)
# Use fuzzy matching
from tools.fuzzy_match import fuzzy_find_and_replace
new_content, count, error = fuzzy_find_and_replace(
new_content, search_pattern, replacement, replace_all=False
)
if error and count == 0:
# Try with context hint if available
if hunk.context_hint:
# Find the context hint location and search nearby
hint_pos = new_content.find(hunk.context_hint)
if hint_pos != -1:
# Search in a window around the hint
window_start = max(0, hint_pos - 500)
window_end = min(len(new_content), hint_pos + 2000)
window = new_content[window_start:window_end]
window_new, count, error = fuzzy_find_and_replace(
window, search_pattern, replacement, replace_all=False
)
if count > 0:
new_content = new_content[:window_start] + window_new + new_content[window_end:]
error = None
if error:
return False, f"Could not apply hunk: {error}"
# Write new content
write_result = file_ops.write_file(op.file_path, new_content)
if write_result.error:
return False, write_result.error
# Generate diff
import difflib
diff_lines = difflib.unified_diff(
current_content.splitlines(keepends=True),
new_content.splitlines(keepends=True),
fromfile=f"a/{op.file_path}",
tofile=f"b/{op.file_path}"
)
diff = ''.join(diff_lines)
return True, diff

View file

@ -0,0 +1,891 @@
"""
Process Registry -- In-memory registry for managed background processes.
Tracks processes spawned via terminal(background=true), providing:
- Output buffering (rolling 200KB window)
- Status polling and log retrieval
- Blocking wait with interrupt support
- Process killing
- Crash recovery via JSON checkpoint file
- Session-scoped tracking for gateway reset protection
Background processes execute THROUGH the environment interface -- nothing
runs on the host machine unless TERMINAL_ENV=local. For Docker, Singularity,
Modal, Daytona, and SSH backends, the command runs inside the sandbox.
Usage:
from tools.process_registry import process_registry
# Spawn a background process (called from terminal_tool)
session = process_registry.spawn(env, "pytest -v", task_id="task_123")
# Poll for status
result = process_registry.poll(session.id)
# Block until done
result = process_registry.wait(session.id, timeout=300)
# Kill it
process_registry.kill(session.id)
"""
import json
import logging
import os
import platform
import shlex
import shutil
import signal
import subprocess
import threading
import time
import uuid
_IS_WINDOWS = platform.system() == "Windows"
from tools.environments.local import _find_shell, _sanitize_subprocess_env
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional
from hermes_cli.config import get_hermes_home
logger = logging.getLogger(__name__)
# Checkpoint file for crash recovery (gateway only)
CHECKPOINT_PATH = get_hermes_home() / "processes.json"
# Limits
MAX_OUTPUT_CHARS = 200_000 # 200KB rolling output buffer
FINISHED_TTL_SECONDS = 1800 # Keep finished processes for 30 minutes
MAX_PROCESSES = 64 # Max concurrent tracked processes (LRU pruning)
@dataclass
class ProcessSession:
"""A tracked background process with output buffering."""
id: str # Unique session ID ("proc_xxxxxxxxxxxx")
command: str # Original command string
task_id: str = "" # Task/sandbox isolation key
session_key: str = "" # Gateway session key (for reset protection)
pid: Optional[int] = None # OS process ID
process: Optional[subprocess.Popen] = None # Popen handle (local only)
env_ref: Any = None # Reference to the environment object
cwd: Optional[str] = None # Working directory
started_at: float = 0.0 # time.time() of spawn
exited: bool = False # Whether the process has finished
exit_code: Optional[int] = None # Exit code (None if still running)
output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS)
max_output_chars: int = MAX_OUTPUT_CHARS
detached: bool = False # True if recovered from crash (no pipe)
# Watcher/notification metadata (persisted for crash recovery)
watcher_platform: str = ""
watcher_chat_id: str = ""
watcher_thread_id: str = ""
watcher_interval: int = 0 # 0 = no watcher configured
_lock: threading.Lock = field(default_factory=threading.Lock)
_reader_thread: Optional[threading.Thread] = field(default=None, repr=False)
_pty: Any = field(default=None, repr=False) # ptyprocess handle (when use_pty=True)
class ProcessRegistry:
"""
In-memory registry of running and finished background processes.
Thread-safe. Accessed from:
- Executor threads (terminal_tool, process tool handlers)
- Gateway asyncio loop (watcher tasks, session reset checks)
- Cleanup thread (sandbox reaping coordination)
"""
_SHELL_NOISE_SUBSTRINGS = (
"bash: cannot set terminal process group",
"bash: no job control in this shell",
"no job control in this shell",
"cannot set terminal process group",
"tcsetattr: Inappropriate ioctl for device",
)
def __init__(self):
self._running: Dict[str, ProcessSession] = {}
self._finished: Dict[str, ProcessSession] = {}
self._lock = threading.Lock()
# Side-channel for check_interval watchers (gateway reads after agent run)
self.pending_watchers: List[Dict[str, Any]] = []
@staticmethod
def _clean_shell_noise(text: str) -> str:
"""Strip shell startup warnings from the beginning of output."""
lines = text.split("\n")
while lines and any(noise in lines[0] for noise in ProcessRegistry._SHELL_NOISE_SUBSTRINGS):
lines.pop(0)
return "\n".join(lines)
# ----- Spawn -----
def spawn_local(
self,
command: str,
cwd: str = None,
task_id: str = "",
session_key: str = "",
env_vars: dict = None,
use_pty: bool = False,
) -> ProcessSession:
"""
Spawn a background process locally.
Only for TERMINAL_ENV=local. Other backends use spawn_via_env().
Args:
use_pty: If True, use a pseudo-terminal via ptyprocess for interactive
CLI tools (Codex, Claude Code, Python REPL). Falls back to
subprocess.Popen if ptyprocess is not installed.
"""
session = ProcessSession(
id=f"proc_{uuid.uuid4().hex[:12]}",
command=command,
task_id=task_id,
session_key=session_key,
cwd=cwd or os.getcwd(),
started_at=time.time(),
)
if use_pty:
# Try PTY mode for interactive CLI tools
try:
if _IS_WINDOWS:
from winpty import PtyProcess as _PtyProcessCls
else:
from ptyprocess import PtyProcess as _PtyProcessCls
user_shell = _find_shell()
pty_env = _sanitize_subprocess_env(os.environ, env_vars)
pty_env["PYTHONUNBUFFERED"] = "1"
pty_proc = _PtyProcessCls.spawn(
[user_shell, "-lic", command],
cwd=session.cwd,
env=pty_env,
dimensions=(30, 120),
)
session.pid = pty_proc.pid
# Store the pty handle on the session for read/write
session._pty = pty_proc
# PTY reader thread
reader = threading.Thread(
target=self._pty_reader_loop,
args=(session,),
daemon=True,
name=f"proc-pty-reader-{session.id}",
)
session._reader_thread = reader
reader.start()
with self._lock:
self._prune_if_needed()
self._running[session.id] = session
self._write_checkpoint()
return session
except ImportError:
logger.warning("ptyprocess not installed, falling back to pipe mode")
except Exception as e:
logger.warning("PTY spawn failed (%s), falling back to pipe mode", e)
# Standard Popen path (non-PTY or PTY fallback)
# Use the user's login shell for consistency with LocalEnvironment --
# ensures rc files are sourced and user tools are available.
user_shell = _find_shell()
# Force unbuffered output for Python scripts so progress is visible
# during background execution (libraries like tqdm/datasets buffer when
# stdout is a pipe, hiding output from process(action="poll")).
bg_env = _sanitize_subprocess_env(os.environ, env_vars)
bg_env["PYTHONUNBUFFERED"] = "1"
proc = subprocess.Popen(
[user_shell, "-lic", command],
text=True,
cwd=session.cwd,
env=bg_env,
encoding="utf-8",
errors="replace",
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
stdin=subprocess.PIPE,
preexec_fn=None if _IS_WINDOWS else os.setsid,
)
session.process = proc
session.pid = proc.pid
# Start output reader thread
reader = threading.Thread(
target=self._reader_loop,
args=(session,),
daemon=True,
name=f"proc-reader-{session.id}",
)
session._reader_thread = reader
reader.start()
with self._lock:
self._prune_if_needed()
self._running[session.id] = session
self._write_checkpoint()
return session
def spawn_via_env(
self,
env: Any,
command: str,
cwd: str = None,
task_id: str = "",
session_key: str = "",
timeout: int = 10,
) -> ProcessSession:
"""
Spawn a background process through a non-local environment backend.
For Docker/Singularity/Modal/Daytona/SSH: runs the command inside the sandbox
using the environment's execute() interface. We wrap the command to
capture the in-sandbox PID and redirect output to a log file inside
the sandbox, then poll the log via subsequent execute() calls.
This is less capable than local spawn (no live stdout pipe, no stdin),
but it ensures the command runs in the correct sandbox context.
"""
session = ProcessSession(
id=f"proc_{uuid.uuid4().hex[:12]}",
command=command,
task_id=task_id,
session_key=session_key,
cwd=cwd,
started_at=time.time(),
env_ref=env,
)
# Run the command in the sandbox with output capture
log_path = f"/tmp/hermes_bg_{session.id}.log"
pid_path = f"/tmp/hermes_bg_{session.id}.pid"
quoted_command = shlex.quote(command)
bg_command = (
f"nohup bash -c {quoted_command} > {log_path} 2>&1 & "
f"echo $! > {pid_path} && cat {pid_path}"
)
try:
result = env.execute(bg_command, timeout=timeout)
output = result.get("output", "").strip()
# Try to extract the PID from the output
for line in output.splitlines():
line = line.strip()
if line.isdigit():
session.pid = int(line)
break
except Exception as e:
session.exited = True
session.exit_code = -1
session.output_buffer = f"Failed to start: {e}"
if not session.exited:
# Start a poller thread that periodically reads the log file
reader = threading.Thread(
target=self._env_poller_loop,
args=(session, env, log_path, pid_path),
daemon=True,
name=f"proc-poller-{session.id}",
)
session._reader_thread = reader
reader.start()
with self._lock:
self._prune_if_needed()
self._running[session.id] = session
self._write_checkpoint()
return session
# ----- Reader / Poller Threads -----
def _reader_loop(self, session: ProcessSession):
"""Background thread: read stdout from a local Popen process."""
first_chunk = True
try:
while True:
chunk = session.process.stdout.read(4096)
if not chunk:
break
if first_chunk:
chunk = self._clean_shell_noise(chunk)
first_chunk = False
with session._lock:
session.output_buffer += chunk
if len(session.output_buffer) > session.max_output_chars:
session.output_buffer = session.output_buffer[-session.max_output_chars:]
except Exception as e:
logger.debug("Process stdout reader ended: %s", e)
# Process exited
try:
session.process.wait(timeout=5)
except Exception as e:
logger.debug("Process wait timed out or failed: %s", e)
session.exited = True
session.exit_code = session.process.returncode
self._move_to_finished(session)
def _env_poller_loop(
self, session: ProcessSession, env: Any, log_path: str, pid_path: str
):
"""Background thread: poll a sandbox log file for non-local backends."""
while not session.exited:
time.sleep(2) # Poll every 2 seconds
try:
# Read new output from the log file
result = env.execute(f"cat {log_path} 2>/dev/null", timeout=10)
new_output = result.get("output", "")
if new_output:
with session._lock:
session.output_buffer = new_output
if len(session.output_buffer) > session.max_output_chars:
session.output_buffer = session.output_buffer[-session.max_output_chars:]
# Check if process is still running
check = env.execute(
f"kill -0 $(cat {pid_path} 2>/dev/null) 2>/dev/null; echo $?",
timeout=5,
)
check_output = check.get("output", "").strip()
if check_output and check_output.splitlines()[-1].strip() != "0":
# Process has exited -- get exit code
exit_result = env.execute(
f"wait $(cat {pid_path} 2>/dev/null) 2>/dev/null; echo $?",
timeout=5,
)
exit_str = exit_result.get("output", "").strip()
try:
session.exit_code = int(exit_str.splitlines()[-1].strip())
except (ValueError, IndexError):
session.exit_code = -1
session.exited = True
self._move_to_finished(session)
return
except Exception:
# Environment might be gone (sandbox reaped, etc.)
session.exited = True
session.exit_code = -1
self._move_to_finished(session)
return
def _pty_reader_loop(self, session: ProcessSession):
"""Background thread: read output from a PTY process."""
pty = session._pty
try:
while pty.isalive():
try:
chunk = pty.read(4096)
if chunk:
# ptyprocess returns bytes
text = chunk if isinstance(chunk, str) else chunk.decode("utf-8", errors="replace")
with session._lock:
session.output_buffer += text
if len(session.output_buffer) > session.max_output_chars:
session.output_buffer = session.output_buffer[-session.max_output_chars:]
except EOFError:
break
except Exception:
break
except Exception as e:
logger.debug("PTY stdout reader ended: %s", e)
# Process exited
try:
pty.wait()
except Exception as e:
logger.debug("PTY wait timed out or failed: %s", e)
session.exited = True
session.exit_code = pty.exitstatus if hasattr(pty, 'exitstatus') else -1
self._move_to_finished(session)
def _move_to_finished(self, session: ProcessSession):
"""Move a session from running to finished."""
with self._lock:
self._running.pop(session.id, None)
self._finished[session.id] = session
self._write_checkpoint()
# ----- Query Methods -----
def get(self, session_id: str) -> Optional[ProcessSession]:
"""Get a session by ID (running or finished)."""
with self._lock:
return self._running.get(session_id) or self._finished.get(session_id)
def poll(self, session_id: str) -> dict:
"""Check status and get new output for a background process."""
from tools.ansi_strip import strip_ansi
session = self.get(session_id)
if session is None:
return {"status": "not_found", "error": f"No process with ID {session_id}"}
with session._lock:
output_preview = strip_ansi(session.output_buffer[-1000:]) if session.output_buffer else ""
result = {
"session_id": session.id,
"command": session.command,
"status": "exited" if session.exited else "running",
"pid": session.pid,
"uptime_seconds": int(time.time() - session.started_at),
"output_preview": output_preview,
}
if session.exited:
result["exit_code"] = session.exit_code
if session.detached:
result["detached"] = True
result["note"] = "Process recovered after restart -- output history unavailable"
return result
def read_log(self, session_id: str, offset: int = 0, limit: int = 200) -> dict:
"""Read the full output log with optional pagination by lines."""
from tools.ansi_strip import strip_ansi
session = self.get(session_id)
if session is None:
return {"status": "not_found", "error": f"No process with ID {session_id}"}
with session._lock:
full_output = strip_ansi(session.output_buffer)
lines = full_output.splitlines()
total_lines = len(lines)
# Default: last N lines
if offset == 0 and limit > 0:
selected = lines[-limit:]
else:
selected = lines[offset:offset + limit]
return {
"session_id": session.id,
"status": "exited" if session.exited else "running",
"output": "\n".join(selected),
"total_lines": total_lines,
"showing": f"{len(selected)} lines",
}
def wait(self, session_id: str, timeout: int = None) -> dict:
"""
Block until a process exits, timeout, or interrupt.
Args:
session_id: The process to wait for.
timeout: Max seconds to block. Falls back to TERMINAL_TIMEOUT config.
Returns:
dict with status ("exited", "timeout", "interrupted", "not_found")
and output snapshot.
"""
from tools.ansi_strip import strip_ansi
from tools.terminal_tool import _interrupt_event
default_timeout = int(os.getenv("TERMINAL_TIMEOUT", "180"))
max_timeout = default_timeout
requested_timeout = timeout
timeout_note = None
if requested_timeout and requested_timeout > max_timeout:
effective_timeout = max_timeout
timeout_note = (
f"Requested wait of {requested_timeout}s was clamped "
f"to configured limit of {max_timeout}s"
)
else:
effective_timeout = requested_timeout or max_timeout
session = self.get(session_id)
if session is None:
return {"status": "not_found", "error": f"No process with ID {session_id}"}
deadline = time.monotonic() + effective_timeout
while time.monotonic() < deadline:
if session.exited:
result = {
"status": "exited",
"exit_code": session.exit_code,
"output": strip_ansi(session.output_buffer[-2000:]),
}
if timeout_note:
result["timeout_note"] = timeout_note
return result
if _interrupt_event.is_set():
result = {
"status": "interrupted",
"output": strip_ansi(session.output_buffer[-1000:]),
"note": "User sent a new message -- wait interrupted",
}
if timeout_note:
result["timeout_note"] = timeout_note
return result
time.sleep(1)
result = {
"status": "timeout",
"output": strip_ansi(session.output_buffer[-1000:]),
}
if timeout_note:
result["timeout_note"] = timeout_note
else:
result["timeout_note"] = f"Waited {effective_timeout}s, process still running"
return result
def kill_process(self, session_id: str) -> dict:
"""Kill a background process."""
session = self.get(session_id)
if session is None:
return {"status": "not_found", "error": f"No process with ID {session_id}"}
if session.exited:
return {
"status": "already_exited",
"exit_code": session.exit_code,
}
# Kill via PTY, Popen (local), or env execute (non-local)
try:
if session._pty:
# PTY process -- terminate via ptyprocess
try:
session._pty.terminate(force=True)
except Exception:
if session.pid:
os.kill(session.pid, signal.SIGTERM)
elif session.process:
# Local process -- kill the process group
try:
if _IS_WINDOWS:
session.process.terminate()
else:
os.killpg(os.getpgid(session.process.pid), signal.SIGTERM)
except (ProcessLookupError, PermissionError):
session.process.kill()
elif session.env_ref and session.pid:
# Non-local -- kill inside sandbox
session.env_ref.execute(f"kill {session.pid} 2>/dev/null", timeout=5)
session.exited = True
session.exit_code = -15 # SIGTERM
self._move_to_finished(session)
self._write_checkpoint()
return {"status": "killed", "session_id": session.id}
except Exception as e:
return {"status": "error", "error": str(e)}
def write_stdin(self, session_id: str, data: str) -> dict:
"""Send raw data to a running process's stdin (no newline appended)."""
session = self.get(session_id)
if session is None:
return {"status": "not_found", "error": f"No process with ID {session_id}"}
if session.exited:
return {"status": "already_exited", "error": "Process has already finished"}
# PTY mode -- write through pty handle (expects bytes)
if hasattr(session, '_pty') and session._pty:
try:
pty_data = data.encode("utf-8") if isinstance(data, str) else data
session._pty.write(pty_data)
return {"status": "ok", "bytes_written": len(data)}
except Exception as e:
return {"status": "error", "error": str(e)}
# Popen mode -- write through stdin pipe
if not session.process or not session.process.stdin:
return {"status": "error", "error": "Process stdin not available (non-local backend or stdin closed)"}
try:
session.process.stdin.write(data)
session.process.stdin.flush()
return {"status": "ok", "bytes_written": len(data)}
except Exception as e:
return {"status": "error", "error": str(e)}
def submit_stdin(self, session_id: str, data: str = "") -> dict:
"""Send data + newline to a running process's stdin (like pressing Enter)."""
return self.write_stdin(session_id, data + "\n")
def list_sessions(self, task_id: str = None) -> list:
"""List all running and recently-finished processes."""
with self._lock:
all_sessions = list(self._running.values()) + list(self._finished.values())
if task_id:
all_sessions = [s for s in all_sessions if s.task_id == task_id]
result = []
for s in all_sessions:
entry = {
"session_id": s.id,
"command": s.command[:200],
"cwd": s.cwd,
"pid": s.pid,
"started_at": time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime(s.started_at)),
"uptime_seconds": int(time.time() - s.started_at),
"status": "exited" if s.exited else "running",
"output_preview": s.output_buffer[-200:] if s.output_buffer else "",
}
if s.exited:
entry["exit_code"] = s.exit_code
if s.detached:
entry["detached"] = True
result.append(entry)
return result
# ----- Session/Task Queries (for gateway integration) -----
def has_active_processes(self, task_id: str) -> bool:
"""Check if there are active (running) processes for a task_id."""
with self._lock:
return any(
s.task_id == task_id and not s.exited
for s in self._running.values()
)
def has_active_for_session(self, session_key: str) -> bool:
"""Check if there are active processes for a gateway session key."""
with self._lock:
return any(
s.session_key == session_key and not s.exited
for s in self._running.values()
)
def kill_all(self, task_id: str = None) -> int:
"""Kill all running processes, optionally filtered by task_id. Returns count killed."""
with self._lock:
targets = [
s for s in self._running.values()
if (task_id is None or s.task_id == task_id) and not s.exited
]
killed = 0
for session in targets:
result = self.kill_process(session.id)
if result.get("status") in ("killed", "already_exited"):
killed += 1
return killed
# ----- Cleanup / Pruning -----
def _prune_if_needed(self):
"""Remove oldest finished sessions if over MAX_PROCESSES. Must hold _lock."""
# First prune expired finished sessions
now = time.time()
expired = [
sid for sid, s in self._finished.items()
if (now - s.started_at) > FINISHED_TTL_SECONDS
]
for sid in expired:
del self._finished[sid]
# If still over limit, remove oldest finished
total = len(self._running) + len(self._finished)
if total >= MAX_PROCESSES and self._finished:
oldest_id = min(self._finished, key=lambda sid: self._finished[sid].started_at)
del self._finished[oldest_id]
def cleanup_expired(self):
"""Public method to prune expired finished sessions."""
with self._lock:
self._prune_if_needed()
# ----- Checkpoint (crash recovery) -----
def _write_checkpoint(self):
"""Write running process metadata to checkpoint file atomically."""
try:
with self._lock:
entries = []
for s in self._running.values():
if not s.exited:
entries.append({
"session_id": s.id,
"command": s.command,
"pid": s.pid,
"cwd": s.cwd,
"started_at": s.started_at,
"task_id": s.task_id,
"session_key": s.session_key,
"watcher_platform": s.watcher_platform,
"watcher_chat_id": s.watcher_chat_id,
"watcher_thread_id": s.watcher_thread_id,
"watcher_interval": s.watcher_interval,
})
# Atomic write to avoid corruption on crash
from utils import atomic_json_write
atomic_json_write(CHECKPOINT_PATH, entries)
except Exception as e:
logger.debug("Failed to write checkpoint file: %s", e, exc_info=True)
def recover_from_checkpoint(self) -> int:
"""
On gateway startup, probe PIDs from checkpoint file.
Returns the number of processes recovered as detached.
"""
if not CHECKPOINT_PATH.exists():
return 0
try:
entries = json.loads(CHECKPOINT_PATH.read_text(encoding="utf-8"))
except Exception:
return 0
recovered = 0
for entry in entries:
pid = entry.get("pid")
if not pid:
continue
# Check if PID is still alive
alive = False
try:
os.kill(pid, 0)
alive = True
except (ProcessLookupError, PermissionError):
pass
if alive:
session = ProcessSession(
id=entry["session_id"],
command=entry.get("command", "unknown"),
task_id=entry.get("task_id", ""),
session_key=entry.get("session_key", ""),
pid=pid,
cwd=entry.get("cwd"),
started_at=entry.get("started_at", time.time()),
detached=True, # Can't read output, but can report status + kill
watcher_platform=entry.get("watcher_platform", ""),
watcher_chat_id=entry.get("watcher_chat_id", ""),
watcher_thread_id=entry.get("watcher_thread_id", ""),
watcher_interval=entry.get("watcher_interval", 0),
)
with self._lock:
self._running[session.id] = session
recovered += 1
logger.info("Recovered detached process: %s (pid=%d)", session.command[:60], pid)
# Re-enqueue watcher so gateway can resume notifications
if session.watcher_interval > 0:
self.pending_watchers.append({
"session_id": session.id,
"check_interval": session.watcher_interval,
"session_key": session.session_key,
"platform": session.watcher_platform,
"chat_id": session.watcher_chat_id,
"thread_id": session.watcher_thread_id,
})
# Clear the checkpoint (will be rewritten as processes finish)
try:
from utils import atomic_json_write
atomic_json_write(CHECKPOINT_PATH, [])
except Exception as e:
logger.debug("Could not clear checkpoint file: %s", e, exc_info=True)
return recovered
# Module-level singleton
process_registry = ProcessRegistry()
# ---------------------------------------------------------------------------
# Registry -- the "process" tool schema + handler
# ---------------------------------------------------------------------------
from tools.registry import registry
PROCESS_SCHEMA = {
"name": "process",
"description": (
"Manage background processes started with terminal(background=true). "
"Actions: 'list' (show all), 'poll' (check status + new output), "
"'log' (full output with pagination), 'wait' (block until done or timeout), "
"'kill' (terminate), 'write' (send raw stdin data without newline), "
"'submit' (send data + Enter, for answering prompts)."
),
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["list", "poll", "log", "wait", "kill", "write", "submit"],
"description": "Action to perform on background processes"
},
"session_id": {
"type": "string",
"description": "Process session ID (from terminal background output). Required for all actions except 'list'."
},
"data": {
"type": "string",
"description": "Text to send to process stdin (for 'write' and 'submit' actions)"
},
"timeout": {
"type": "integer",
"description": "Max seconds to block for 'wait' action. Returns partial output on timeout.",
"minimum": 1
},
"offset": {
"type": "integer",
"description": "Line offset for 'log' action (default: last 200 lines)"
},
"limit": {
"type": "integer",
"description": "Max lines to return for 'log' action",
"minimum": 1
}
},
"required": ["action"]
}
}
def _handle_process(args, **kw):
import json as _json
task_id = kw.get("task_id")
action = args.get("action", "")
# Coerce to string — some models send session_id as an integer
session_id = str(args.get("session_id", "")) if args.get("session_id") is not None else ""
if action == "list":
return _json.dumps({"processes": process_registry.list_sessions(task_id=task_id)}, ensure_ascii=False)
elif action in ("poll", "log", "wait", "kill", "write", "submit"):
if not session_id:
return _json.dumps({"error": f"session_id is required for {action}"}, ensure_ascii=False)
if action == "poll":
return _json.dumps(process_registry.poll(session_id), ensure_ascii=False)
elif action == "log":
return _json.dumps(process_registry.read_log(
session_id, offset=args.get("offset", 0), limit=args.get("limit", 200)), ensure_ascii=False)
elif action == "wait":
return _json.dumps(process_registry.wait(session_id, timeout=args.get("timeout")), ensure_ascii=False)
elif action == "kill":
return _json.dumps(process_registry.kill_process(session_id), ensure_ascii=False)
elif action == "write":
return _json.dumps(process_registry.write_stdin(session_id, str(args.get("data", ""))), ensure_ascii=False)
elif action == "submit":
return _json.dumps(process_registry.submit_stdin(session_id, str(args.get("data", ""))), ensure_ascii=False)
return _json.dumps({"error": f"Unknown process action: {action}. Use: list, poll, log, wait, kill, write, submit"}, ensure_ascii=False)
registry.register(
name="process",
toolset="terminal",
schema=PROCESS_SCHEMA,
handler=_handle_process,
emoji="⚙️",
)

View file

@ -0,0 +1,237 @@
"""Central registry for all hermes-agent tools.
Each tool file calls ``registry.register()`` at module level to declare its
schema, handler, toolset membership, and availability check. ``model_tools.py``
queries the registry instead of maintaining its own parallel data structures.
Import chain (circular-import safe):
tools/registry.py (no imports from model_tools or tool files)
^
tools/*.py (import from tools.registry at module level)
^
model_tools.py (imports tools.registry + all tool modules)
^
run_agent.py, cli.py, batch_runner.py, etc.
"""
import json
import logging
from typing import Any, Callable, Dict, List, Optional, Set
logger = logging.getLogger(__name__)
class ToolEntry:
"""Metadata for a single registered tool."""
__slots__ = (
"name", "toolset", "schema", "handler", "check_fn",
"requires_env", "is_async", "description", "emoji",
)
def __init__(self, name, toolset, schema, handler, check_fn,
requires_env, is_async, description, emoji):
self.name = name
self.toolset = toolset
self.schema = schema
self.handler = handler
self.check_fn = check_fn
self.requires_env = requires_env
self.is_async = is_async
self.description = description
self.emoji = emoji
class ToolRegistry:
"""Singleton registry that collects tool schemas + handlers from tool files."""
def __init__(self):
self._tools: Dict[str, ToolEntry] = {}
self._toolset_checks: Dict[str, Callable] = {}
# ------------------------------------------------------------------
# Registration
# ------------------------------------------------------------------
def register(
self,
name: str,
toolset: str,
schema: dict,
handler: Callable,
check_fn: Callable = None,
requires_env: list = None,
is_async: bool = False,
description: str = "",
emoji: str = "",
):
"""Register a tool. Called at module-import time by each tool file."""
self._tools[name] = ToolEntry(
name=name,
toolset=toolset,
schema=schema,
handler=handler,
check_fn=check_fn,
requires_env=requires_env or [],
is_async=is_async,
description=description or schema.get("description", ""),
emoji=emoji,
)
if check_fn and toolset not in self._toolset_checks:
self._toolset_checks[toolset] = check_fn
# ------------------------------------------------------------------
# Schema retrieval
# ------------------------------------------------------------------
def get_definitions(self, tool_names: Set[str], quiet: bool = False) -> List[dict]:
"""Return OpenAI-format tool schemas for the requested tool names.
Only tools whose ``check_fn()`` returns True (or have no check_fn)
are included.
"""
result = []
for name in sorted(tool_names):
entry = self._tools.get(name)
if not entry:
continue
if entry.check_fn:
try:
if not entry.check_fn():
if not quiet:
logger.debug("Tool %s unavailable (check failed)", name)
continue
except Exception:
if not quiet:
logger.debug("Tool %s check raised; skipping", name)
continue
result.append({"type": "function", "function": entry.schema})
return result
# ------------------------------------------------------------------
# Dispatch
# ------------------------------------------------------------------
def dispatch(self, name: str, args: dict, **kwargs) -> str:
"""Execute a tool handler by name.
* Async handlers are bridged automatically via ``_run_async()``.
* All exceptions are caught and returned as ``{"error": "..."}``
for consistent error format.
"""
entry = self._tools.get(name)
if not entry:
return json.dumps({"error": f"Unknown tool: {name}"})
try:
if entry.is_async:
from model_tools import _run_async
return _run_async(entry.handler(args, **kwargs))
return entry.handler(args, **kwargs)
except Exception as e:
logger.exception("Tool %s dispatch error: %s", name, e)
return json.dumps({"error": f"Tool execution failed: {type(e).__name__}: {e}"})
# ------------------------------------------------------------------
# Query helpers (replace redundant dicts in model_tools.py)
# ------------------------------------------------------------------
def get_all_tool_names(self) -> List[str]:
"""Return sorted list of all registered tool names."""
return sorted(self._tools.keys())
def get_toolset_for_tool(self, name: str) -> Optional[str]:
"""Return the toolset a tool belongs to, or None."""
entry = self._tools.get(name)
return entry.toolset if entry else None
def get_emoji(self, name: str, default: str = "") -> str:
"""Return the emoji for a tool, or *default* if unset."""
entry = self._tools.get(name)
return (entry.emoji if entry and entry.emoji else default)
def get_tool_to_toolset_map(self) -> Dict[str, str]:
"""Return ``{tool_name: toolset_name}`` for every registered tool."""
return {name: e.toolset for name, e in self._tools.items()}
def is_toolset_available(self, toolset: str) -> bool:
"""Check if a toolset's requirements are met.
Returns False (rather than crashing) when the check function raises
an unexpected exception (e.g. network error, missing import, bad config).
"""
check = self._toolset_checks.get(toolset)
if not check:
return True
try:
return bool(check())
except Exception:
logger.debug("Toolset %s check raised; marking unavailable", toolset)
return False
def check_toolset_requirements(self) -> Dict[str, bool]:
"""Return ``{toolset: available_bool}`` for every toolset."""
toolsets = set(e.toolset for e in self._tools.values())
return {ts: self.is_toolset_available(ts) for ts in sorted(toolsets)}
def get_available_toolsets(self) -> Dict[str, dict]:
"""Return toolset metadata for UI display."""
toolsets: Dict[str, dict] = {}
for entry in self._tools.values():
ts = entry.toolset
if ts not in toolsets:
toolsets[ts] = {
"available": self.is_toolset_available(ts),
"tools": [],
"description": "",
"requirements": [],
}
toolsets[ts]["tools"].append(entry.name)
if entry.requires_env:
for env in entry.requires_env:
if env not in toolsets[ts]["requirements"]:
toolsets[ts]["requirements"].append(env)
return toolsets
def get_toolset_requirements(self) -> Dict[str, dict]:
"""Build a TOOLSET_REQUIREMENTS-compatible dict for backward compat."""
result: Dict[str, dict] = {}
for entry in self._tools.values():
ts = entry.toolset
if ts not in result:
result[ts] = {
"name": ts,
"env_vars": [],
"check_fn": self._toolset_checks.get(ts),
"setup_url": None,
"tools": [],
}
if entry.name not in result[ts]["tools"]:
result[ts]["tools"].append(entry.name)
for env in entry.requires_env:
if env not in result[ts]["env_vars"]:
result[ts]["env_vars"].append(env)
return result
def check_tool_availability(self, quiet: bool = False):
"""Return (available_toolsets, unavailable_info) like the old function."""
available = []
unavailable = []
seen = set()
for entry in self._tools.values():
ts = entry.toolset
if ts in seen:
continue
seen.add(ts)
if self.is_toolset_available(ts):
available.append(ts)
else:
unavailable.append({
"name": ts,
"env_vars": entry.requires_env,
"tools": [e.name for e in self._tools.values() if e.toolset == ts],
})
return available, unavailable
# Module-level singleton
registry = ToolRegistry()

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,691 @@
"""Send Message Tool -- cross-channel messaging via platform APIs.
Sends a message to a user or channel on any connected messaging platform
(Telegram, Discord, Slack). Supports listing available targets and resolving
human-friendly channel names to IDs. Works in both CLI and gateway contexts.
"""
import json
import logging
import os
import re
import ssl
import time
logger = logging.getLogger(__name__)
_TELEGRAM_TOPIC_TARGET_RE = re.compile(r"^\s*(-?\d+)(?::(\d+))?\s*$")
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".3gp"}
_AUDIO_EXTS = {".ogg", ".opus", ".mp3", ".wav", ".m4a"}
_VOICE_EXTS = {".ogg", ".opus"}
SEND_MESSAGE_SCHEMA = {
"name": "send_message",
"description": (
"Send a message to a connected messaging platform, or list available targets.\n\n"
"IMPORTANT: When the user asks to send to a specific channel or person "
"(not just a bare platform name), call send_message(action='list') FIRST to see "
"available targets, then send to the correct one.\n"
"If the user just says a platform name like 'send to telegram', send directly "
"to the home channel without listing first."
),
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["send", "list"],
"description": "Action to perform. 'send' (default) sends a message. 'list' returns all available channels/contacts across connected platforms."
},
"target": {
"type": "string",
"description": "Delivery target. Format: 'platform' (uses home channel), 'platform:#channel-name', 'platform:chat_id', or Telegram topic 'telegram:chat_id:thread_id'. Examples: 'telegram', 'telegram:-1001234567890:17585', 'discord:#bot-home', 'slack:#engineering', 'signal:+15551234567'"
},
"message": {
"type": "string",
"description": "The message text to send"
}
},
"required": []
}
}
def send_message_tool(args, **kw):
"""Handle cross-channel send_message tool calls."""
action = args.get("action", "send")
if action == "list":
return _handle_list()
return _handle_send(args)
def _handle_list():
"""Return formatted list of available messaging targets."""
try:
from gateway.channel_directory import format_directory_for_display
return json.dumps({"targets": format_directory_for_display()})
except Exception as e:
return json.dumps({"error": f"Failed to load channel directory: {e}"})
def _handle_send(args):
"""Send a message to a platform target."""
target = args.get("target", "")
message = args.get("message", "")
if not target or not message:
return json.dumps({"error": "Both 'target' and 'message' are required when action='send'"})
parts = target.split(":", 1)
platform_name = parts[0].strip().lower()
target_ref = parts[1].strip() if len(parts) > 1 else None
chat_id = None
thread_id = None
if target_ref:
chat_id, thread_id, is_explicit = _parse_target_ref(platform_name, target_ref)
else:
is_explicit = False
# Resolve human-friendly channel names to numeric IDs
if target_ref and not is_explicit:
try:
from gateway.channel_directory import resolve_channel_name
resolved = resolve_channel_name(platform_name, target_ref)
if resolved:
chat_id, thread_id, _ = _parse_target_ref(platform_name, resolved)
else:
return json.dumps({
"error": f"Could not resolve '{target_ref}' on {platform_name}. "
f"Use send_message(action='list') to see available targets."
})
except Exception:
return json.dumps({
"error": f"Could not resolve '{target_ref}' on {platform_name}. "
f"Try using a numeric channel ID instead."
})
from tools.interrupt import is_interrupted
if is_interrupted():
return json.dumps({"error": "Interrupted"})
try:
from gateway.config import load_gateway_config, Platform
config = load_gateway_config()
except Exception as e:
return json.dumps({"error": f"Failed to load gateway config: {e}"})
platform_map = {
"telegram": Platform.TELEGRAM,
"discord": Platform.DISCORD,
"slack": Platform.SLACK,
"whatsapp": Platform.WHATSAPP,
"signal": Platform.SIGNAL,
"matrix": Platform.MATRIX,
"mattermost": Platform.MATTERMOST,
"homeassistant": Platform.HOMEASSISTANT,
"dingtalk": Platform.DINGTALK,
"email": Platform.EMAIL,
"sms": Platform.SMS,
}
platform = platform_map.get(platform_name)
if not platform:
avail = ", ".join(platform_map.keys())
return json.dumps({"error": f"Unknown platform: {platform_name}. Available: {avail}"})
pconfig = config.platforms.get(platform)
if not pconfig or not pconfig.enabled:
return json.dumps({"error": f"Platform '{platform_name}' is not configured. Set up credentials in ~/.hermes/config.yaml or environment variables."})
from gateway.platforms.base import BasePlatformAdapter
media_files, cleaned_message = BasePlatformAdapter.extract_media(message)
mirror_text = cleaned_message.strip() or _describe_media_for_mirror(media_files)
used_home_channel = False
if not chat_id:
home = config.get_home_channel(platform)
if home:
chat_id = home.chat_id
used_home_channel = True
else:
return json.dumps({
"error": f"No home channel set for {platform_name} to determine where to send the message. "
f"Either specify a channel directly with '{platform_name}:CHANNEL_NAME', "
f"or set a home channel via: hermes config set {platform_name.upper()}_HOME_CHANNEL <channel_id>"
})
duplicate_skip = _maybe_skip_cron_duplicate_send(platform_name, chat_id, thread_id)
if duplicate_skip:
return json.dumps(duplicate_skip)
try:
from model_tools import _run_async
result = _run_async(
_send_to_platform(
platform,
pconfig,
chat_id,
cleaned_message,
thread_id=thread_id,
media_files=media_files,
)
)
if used_home_channel and isinstance(result, dict) and result.get("success"):
result["note"] = f"Sent to {platform_name} home channel (chat_id: {chat_id})"
# Mirror the sent message into the target's gateway session
if isinstance(result, dict) and result.get("success") and mirror_text:
try:
from gateway.mirror import mirror_to_session
source_label = os.getenv("HERMES_SESSION_PLATFORM", "cli")
if mirror_to_session(platform_name, chat_id, mirror_text, source_label=source_label, thread_id=thread_id):
result["mirrored"] = True
except Exception:
pass
return json.dumps(result)
except Exception as e:
return json.dumps({"error": f"Send failed: {e}"})
def _parse_target_ref(platform_name: str, target_ref: str):
"""Parse a tool target into chat_id/thread_id and whether it is explicit."""
if platform_name == "telegram":
match = _TELEGRAM_TOPIC_TARGET_RE.fullmatch(target_ref)
if match:
return match.group(1), match.group(2), True
if target_ref.lstrip("-").isdigit():
return target_ref, None, True
return None, None, False
def _describe_media_for_mirror(media_files):
"""Return a human-readable mirror summary when a message only contains media."""
if not media_files:
return ""
if len(media_files) == 1:
media_path, is_voice = media_files[0]
ext = os.path.splitext(media_path)[1].lower()
if is_voice and ext in _VOICE_EXTS:
return "[Sent voice message]"
if ext in _IMAGE_EXTS:
return "[Sent image attachment]"
if ext in _VIDEO_EXTS:
return "[Sent video attachment]"
if ext in _AUDIO_EXTS:
return "[Sent audio attachment]"
return "[Sent document attachment]"
return f"[Sent {len(media_files)} media attachments]"
def _get_cron_auto_delivery_target():
"""Return the cron scheduler's auto-delivery target for the current run, if any."""
platform = os.getenv("HERMES_CRON_AUTO_DELIVER_PLATFORM", "").strip().lower()
chat_id = os.getenv("HERMES_CRON_AUTO_DELIVER_CHAT_ID", "").strip()
if not platform or not chat_id:
return None
thread_id = os.getenv("HERMES_CRON_AUTO_DELIVER_THREAD_ID", "").strip() or None
return {
"platform": platform,
"chat_id": chat_id,
"thread_id": thread_id,
}
def _maybe_skip_cron_duplicate_send(platform_name: str, chat_id: str, thread_id: str | None):
"""Skip redundant cron send_message calls when the scheduler will auto-deliver there."""
auto_target = _get_cron_auto_delivery_target()
if not auto_target:
return None
same_target = (
auto_target["platform"] == platform_name
and str(auto_target["chat_id"]) == str(chat_id)
and auto_target.get("thread_id") == thread_id
)
if not same_target:
return None
target_label = f"{platform_name}:{chat_id}"
if thread_id is not None:
target_label += f":{thread_id}"
return {
"success": True,
"skipped": True,
"reason": "cron_auto_delivery_duplicate_target",
"target": target_label,
"note": (
f"Skipped send_message to {target_label}. This cron job will already auto-deliver "
"its final response to that same target. Put the intended user-facing content in "
"your final response instead, or use a different target if you want an additional message."
),
}
async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None, media_files=None):
"""Route a message to the appropriate platform sender.
Long messages are automatically chunked to fit within platform limits
using the same smart-splitting algorithm as the gateway adapters
(preserves code-block boundaries, adds part indicators).
"""
from gateway.config import Platform
from gateway.platforms.base import BasePlatformAdapter
from gateway.platforms.telegram import TelegramAdapter
from gateway.platforms.discord import DiscordAdapter
from gateway.platforms.slack import SlackAdapter
media_files = media_files or []
# Platform message length limits (from adapter class attributes)
_MAX_LENGTHS = {
Platform.TELEGRAM: TelegramAdapter.MAX_MESSAGE_LENGTH,
Platform.DISCORD: DiscordAdapter.MAX_MESSAGE_LENGTH,
Platform.SLACK: SlackAdapter.MAX_MESSAGE_LENGTH,
}
# Smart-chunk the message to fit within platform limits.
# For short messages or platforms without a known limit this is a no-op.
max_len = _MAX_LENGTHS.get(platform)
if max_len:
chunks = BasePlatformAdapter.truncate_message(message, max_len)
else:
chunks = [message]
# --- Telegram: special handling for media attachments ---
if platform == Platform.TELEGRAM:
last_result = None
for i, chunk in enumerate(chunks):
is_last = (i == len(chunks) - 1)
result = await _send_telegram(
pconfig.token,
chat_id,
chunk,
media_files=media_files if is_last else [],
thread_id=thread_id,
)
if isinstance(result, dict) and result.get("error"):
return result
last_result = result
return last_result
# --- Non-Telegram platforms ---
if media_files and not message.strip():
return {
"error": (
f"send_message MEDIA delivery is currently only supported for telegram; "
f"target {platform.value} had only media attachments"
)
}
warning = None
if media_files:
warning = (
f"MEDIA attachments were omitted for {platform.value}; "
"native send_message media delivery is currently only supported for telegram"
)
last_result = None
for chunk in chunks:
if platform == Platform.DISCORD:
result = await _send_discord(pconfig.token, chat_id, chunk)
elif platform == Platform.SLACK:
result = await _send_slack(pconfig.token, chat_id, chunk)
elif platform == Platform.WHATSAPP:
result = await _send_whatsapp(pconfig.extra, chat_id, chunk)
elif platform == Platform.SIGNAL:
result = await _send_signal(pconfig.extra, chat_id, chunk)
elif platform == Platform.EMAIL:
result = await _send_email(pconfig.extra, chat_id, chunk)
elif platform == Platform.SMS:
result = await _send_sms(pconfig.api_key, chat_id, chunk)
else:
result = {"error": f"Direct sending not yet implemented for {platform.value}"}
if isinstance(result, dict) and result.get("error"):
return result
last_result = result
if warning and isinstance(last_result, dict) and last_result.get("success"):
warnings = list(last_result.get("warnings", []))
warnings.append(warning)
last_result["warnings"] = warnings
return last_result
async def _send_telegram(token, chat_id, message, media_files=None, thread_id=None):
"""Send via Telegram Bot API (one-shot, no polling needed).
Applies markdownMarkdownV2 formatting (same as the gateway adapter)
so that bold, links, and headers render correctly. If the message
already contains HTML tags, it is sent with ``parse_mode='HTML'``
instead, bypassing MarkdownV2 conversion.
"""
try:
from telegram import Bot
from telegram.constants import ParseMode
# Auto-detect HTML tags — if present, skip MarkdownV2 and send as HTML.
# Inspired by github.com/ashaney — PR #1568.
_has_html = bool(re.search(r'<[a-zA-Z/][^>]*>', message))
if _has_html:
formatted = message
send_parse_mode = ParseMode.HTML
else:
# Reuse the gateway adapter's format_message for markdown→MarkdownV2
try:
from gateway.platforms.telegram import TelegramAdapter, _escape_mdv2, _strip_mdv2
_adapter = TelegramAdapter.__new__(TelegramAdapter)
formatted = _adapter.format_message(message)
except Exception:
# Fallback: send as-is if formatting unavailable
formatted = message
send_parse_mode = ParseMode.MARKDOWN_V2
bot = Bot(token=token)
int_chat_id = int(chat_id)
media_files = media_files or []
thread_kwargs = {}
if thread_id is not None:
thread_kwargs["message_thread_id"] = int(thread_id)
last_msg = None
warnings = []
if formatted.strip():
try:
last_msg = await bot.send_message(
chat_id=int_chat_id, text=formatted,
parse_mode=send_parse_mode, **thread_kwargs
)
except Exception as md_error:
# Parse failed, fall back to plain text
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower() or "html" in str(md_error).lower():
logger.warning("Parse mode %s failed in _send_telegram, falling back to plain text: %s", send_parse_mode, md_error)
if not _has_html:
try:
from gateway.platforms.telegram import _strip_mdv2
plain = _strip_mdv2(formatted)
except Exception:
plain = message
else:
plain = message
last_msg = await bot.send_message(
chat_id=int_chat_id, text=plain,
parse_mode=None, **thread_kwargs
)
else:
raise
for media_path, is_voice in media_files:
if not os.path.exists(media_path):
warning = f"Media file not found, skipping: {media_path}"
logger.warning(warning)
warnings.append(warning)
continue
ext = os.path.splitext(media_path)[1].lower()
try:
with open(media_path, "rb") as f:
if ext in _IMAGE_EXTS:
last_msg = await bot.send_photo(
chat_id=int_chat_id, photo=f, **thread_kwargs
)
elif ext in _VIDEO_EXTS:
last_msg = await bot.send_video(
chat_id=int_chat_id, video=f, **thread_kwargs
)
elif ext in _VOICE_EXTS and is_voice:
last_msg = await bot.send_voice(
chat_id=int_chat_id, voice=f, **thread_kwargs
)
elif ext in _AUDIO_EXTS:
last_msg = await bot.send_audio(
chat_id=int_chat_id, audio=f, **thread_kwargs
)
else:
last_msg = await bot.send_document(
chat_id=int_chat_id, document=f, **thread_kwargs
)
except Exception as e:
warning = f"Failed to send media {media_path}: {e}"
logger.error(warning)
warnings.append(warning)
if last_msg is None:
error = "No deliverable text or media remained after processing MEDIA tags"
if warnings:
return {"error": error, "warnings": warnings}
return {"error": error}
result = {
"success": True,
"platform": "telegram",
"chat_id": chat_id,
"message_id": str(last_msg.message_id),
}
if warnings:
result["warnings"] = warnings
return result
except ImportError:
return {"error": "python-telegram-bot not installed. Run: pip install python-telegram-bot"}
except Exception as e:
return {"error": f"Telegram send failed: {e}"}
async def _send_discord(token, chat_id, message):
"""Send a single message via Discord REST API (no websocket client needed).
Chunking is handled by _send_to_platform() before this is called.
"""
try:
import aiohttp
except ImportError:
return {"error": "aiohttp not installed. Run: pip install aiohttp"}
try:
url = f"https://discord.com/api/v10/channels/{chat_id}/messages"
headers = {"Authorization": f"Bot {token}", "Content-Type": "application/json"}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json={"content": message}) as resp:
if resp.status not in (200, 201):
body = await resp.text()
return {"error": f"Discord API error ({resp.status}): {body}"}
data = await resp.json()
return {"success": True, "platform": "discord", "chat_id": chat_id, "message_id": data.get("id")}
except Exception as e:
return {"error": f"Discord send failed: {e}"}
async def _send_slack(token, chat_id, message):
"""Send via Slack Web API."""
try:
import aiohttp
except ImportError:
return {"error": "aiohttp not installed. Run: pip install aiohttp"}
try:
url = "https://slack.com/api/chat.postMessage"
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json={"channel": chat_id, "text": message}) as resp:
data = await resp.json()
if data.get("ok"):
return {"success": True, "platform": "slack", "chat_id": chat_id, "message_id": data.get("ts")}
return {"error": f"Slack API error: {data.get('error', 'unknown')}"}
except Exception as e:
return {"error": f"Slack send failed: {e}"}
async def _send_whatsapp(extra, chat_id, message):
"""Send via the local WhatsApp bridge HTTP API."""
try:
import aiohttp
except ImportError:
return {"error": "aiohttp not installed. Run: pip install aiohttp"}
try:
bridge_port = extra.get("bridge_port", 3000)
async with aiohttp.ClientSession() as session:
async with session.post(
f"http://localhost:{bridge_port}/send",
json={"chatId": chat_id, "message": message},
timeout=aiohttp.ClientTimeout(total=30),
) as resp:
if resp.status == 200:
data = await resp.json()
return {
"success": True,
"platform": "whatsapp",
"chat_id": chat_id,
"message_id": data.get("messageId"),
}
body = await resp.text()
return {"error": f"WhatsApp bridge error ({resp.status}): {body}"}
except Exception as e:
return {"error": f"WhatsApp send failed: {e}"}
async def _send_signal(extra, chat_id, message):
"""Send via signal-cli JSON-RPC API."""
try:
import httpx
except ImportError:
return {"error": "httpx not installed"}
try:
http_url = extra.get("http_url", "http://127.0.0.1:8080").rstrip("/")
account = extra.get("account", "")
if not account:
return {"error": "Signal account not configured"}
params = {"account": account, "message": message}
if chat_id.startswith("group:"):
params["groupId"] = chat_id[6:]
else:
params["recipient"] = [chat_id]
payload = {
"jsonrpc": "2.0",
"method": "send",
"params": params,
"id": f"send_{int(time.time() * 1000)}",
}
async with httpx.AsyncClient(timeout=30.0) as client:
resp = await client.post(f"{http_url}/api/v1/rpc", json=payload)
resp.raise_for_status()
data = resp.json()
if "error" in data:
return {"error": f"Signal RPC error: {data['error']}"}
return {"success": True, "platform": "signal", "chat_id": chat_id}
except Exception as e:
return {"error": f"Signal send failed: {e}"}
async def _send_email(extra, chat_id, message):
"""Send via SMTP (one-shot, no persistent connection needed)."""
import smtplib
from email.mime.text import MIMEText
address = extra.get("address") or os.getenv("EMAIL_ADDRESS", "")
password = os.getenv("EMAIL_PASSWORD", "")
smtp_host = extra.get("smtp_host") or os.getenv("EMAIL_SMTP_HOST", "")
smtp_port = int(os.getenv("EMAIL_SMTP_PORT", "587"))
if not all([address, password, smtp_host]):
return {"error": "Email not configured (EMAIL_ADDRESS, EMAIL_PASSWORD, EMAIL_SMTP_HOST required)"}
try:
msg = MIMEText(message, "plain", "utf-8")
msg["From"] = address
msg["To"] = chat_id
msg["Subject"] = "Hermes Agent"
server = smtplib.SMTP(smtp_host, smtp_port)
server.starttls(context=ssl.create_default_context())
server.login(address, password)
server.send_message(msg)
server.quit()
return {"success": True, "platform": "email", "chat_id": chat_id}
except Exception as e:
return {"error": f"Email send failed: {e}"}
async def _send_sms(auth_token, chat_id, message):
"""Send a single SMS via Twilio REST API.
Uses HTTP Basic auth (Account SID : Auth Token) and form-encoded POST.
Chunking is handled by _send_to_platform() before this is called.
"""
try:
import aiohttp
except ImportError:
return {"error": "aiohttp not installed. Run: pip install aiohttp"}
import base64
account_sid = os.getenv("TWILIO_ACCOUNT_SID", "")
from_number = os.getenv("TWILIO_PHONE_NUMBER", "")
if not account_sid or not auth_token or not from_number:
return {"error": "SMS not configured (TWILIO_ACCOUNT_SID, TWILIO_AUTH_TOKEN, TWILIO_PHONE_NUMBER required)"}
# Strip markdown — SMS renders it as literal characters
message = re.sub(r"\*\*(.+?)\*\*", r"\1", message, flags=re.DOTALL)
message = re.sub(r"\*(.+?)\*", r"\1", message, flags=re.DOTALL)
message = re.sub(r"__(.+?)__", r"\1", message, flags=re.DOTALL)
message = re.sub(r"_(.+?)_", r"\1", message, flags=re.DOTALL)
message = re.sub(r"```[a-z]*\n?", "", message)
message = re.sub(r"`(.+?)`", r"\1", message)
message = re.sub(r"^#{1,6}\s+", "", message, flags=re.MULTILINE)
message = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", message)
message = re.sub(r"\n{3,}", "\n\n", message)
message = message.strip()
try:
creds = f"{account_sid}:{auth_token}"
encoded = base64.b64encode(creds.encode("ascii")).decode("ascii")
url = f"https://api.twilio.com/2010-04-01/Accounts/{account_sid}/Messages.json"
headers = {"Authorization": f"Basic {encoded}"}
async with aiohttp.ClientSession() as session:
form_data = aiohttp.FormData()
form_data.add_field("From", from_number)
form_data.add_field("To", chat_id)
form_data.add_field("Body", message)
async with session.post(url, data=form_data, headers=headers) as resp:
body = await resp.json()
if resp.status >= 400:
error_msg = body.get("message", str(body))
return {"error": f"Twilio API error ({resp.status}): {error_msg}"}
msg_sid = body.get("sid", "")
return {"success": True, "platform": "sms", "chat_id": chat_id, "message_id": msg_sid}
except Exception as e:
return {"error": f"SMS send failed: {e}"}
def _check_send_message():
"""Gate send_message on gateway running (always available on messaging platforms)."""
platform = os.getenv("HERMES_SESSION_PLATFORM", "")
if platform and platform != "local":
return True
try:
from gateway.status import is_gateway_running
return is_gateway_running()
except Exception:
return False
# --- Registry ---
from tools.registry import registry
registry.register(
name="send_message",
toolset="messaging",
schema=SEND_MESSAGE_SCHEMA,
handler=send_message_tool,
check_fn=_check_send_message,
emoji="📨",
)

View file

@ -0,0 +1,420 @@
#!/usr/bin/env python3
"""
Session Search Tool - Long-Term Conversation Recall
Searches past session transcripts in SQLite via FTS5, then summarizes the top
matching sessions using a cheap/fast model (same pattern as web_extract).
Returns focused summaries of past conversations rather than raw transcripts,
keeping the main model's context window clean.
Flow:
1. FTS5 search finds matching messages ranked by relevance
2. Groups by session, takes the top N unique sessions (default 3)
3. Loads each session's conversation, truncates to ~100k chars centered on matches
4. Sends to Gemini Flash with a focused summarization prompt
5. Returns per-session summaries with metadata
"""
import asyncio
import concurrent.futures
import json
import os
import logging
from typing import Dict, Any, List, Optional, Union
from agent.auxiliary_client import async_call_llm
MAX_SESSION_CHARS = 100_000
MAX_SUMMARY_TOKENS = 10000
def _format_timestamp(ts: Union[int, float, str, None]) -> str:
"""Convert a Unix timestamp (float/int) or ISO string to a human-readable date.
Returns "unknown" for None, str(ts) if conversion fails.
"""
if ts is None:
return "unknown"
try:
if isinstance(ts, (int, float)):
from datetime import datetime
dt = datetime.fromtimestamp(ts)
return dt.strftime("%B %d, %Y at %I:%M %p")
if isinstance(ts, str):
if ts.replace(".", "").replace("-", "").isdigit():
from datetime import datetime
dt = datetime.fromtimestamp(float(ts))
return dt.strftime("%B %d, %Y at %I:%M %p")
return ts
except (ValueError, OSError, OverflowError) as e:
# Log specific errors for debugging while gracefully handling edge cases
logging.debug("Failed to format timestamp %s: %s", ts, e, exc_info=True)
except Exception as e:
logging.debug("Unexpected error formatting timestamp %s: %s", ts, e, exc_info=True)
return str(ts)
def _format_conversation(messages: List[Dict[str, Any]]) -> str:
"""Format session messages into a readable transcript for summarization."""
parts = []
for msg in messages:
role = msg.get("role", "unknown").upper()
content = msg.get("content") or ""
tool_name = msg.get("tool_name")
if role == "TOOL" and tool_name:
# Truncate long tool outputs
if len(content) > 500:
content = content[:250] + "\n...[truncated]...\n" + content[-250:]
parts.append(f"[TOOL:{tool_name}]: {content}")
elif role == "ASSISTANT":
# Include tool call names if present
tool_calls = msg.get("tool_calls")
if tool_calls and isinstance(tool_calls, list):
tc_names = []
for tc in tool_calls:
if isinstance(tc, dict):
name = tc.get("name") or tc.get("function", {}).get("name", "?")
tc_names.append(name)
if tc_names:
parts.append(f"[ASSISTANT]: [Called: {', '.join(tc_names)}]")
if content:
parts.append(f"[ASSISTANT]: {content}")
else:
parts.append(f"[ASSISTANT]: {content}")
else:
parts.append(f"[{role}]: {content}")
return "\n\n".join(parts)
def _truncate_around_matches(
full_text: str, query: str, max_chars: int = MAX_SESSION_CHARS
) -> str:
"""
Truncate a conversation transcript to max_chars, centered around
where the query terms appear. Keeps content near matches, trims the edges.
"""
if len(full_text) <= max_chars:
return full_text
# Find the first occurrence of any query term
query_terms = query.lower().split()
text_lower = full_text.lower()
first_match = len(full_text)
for term in query_terms:
pos = text_lower.find(term)
if pos != -1 and pos < first_match:
first_match = pos
if first_match == len(full_text):
# No match found, take from the start
first_match = 0
# Center the window around the first match
half = max_chars // 2
start = max(0, first_match - half)
end = min(len(full_text), start + max_chars)
if end - start < max_chars:
start = max(0, end - max_chars)
truncated = full_text[start:end]
prefix = "...[earlier conversation truncated]...\n\n" if start > 0 else ""
suffix = "\n\n...[later conversation truncated]..." if end < len(full_text) else ""
return prefix + truncated + suffix
async def _summarize_session(
conversation_text: str, query: str, session_meta: Dict[str, Any]
) -> Optional[str]:
"""Summarize a single session conversation focused on the search query."""
system_prompt = (
"You are reviewing a past conversation transcript to help recall what happened. "
"Summarize the conversation with a focus on the search topic. Include:\n"
"1. What the user asked about or wanted to accomplish\n"
"2. What actions were taken and what the outcomes were\n"
"3. Key decisions, solutions found, or conclusions reached\n"
"4. Any specific commands, files, URLs, or technical details that were important\n"
"5. Anything left unresolved or notable\n\n"
"Be thorough but concise. Preserve specific details (commands, paths, error messages) "
"that would be useful to recall. Write in past tense as a factual recap."
)
source = session_meta.get("source", "unknown")
started = _format_timestamp(session_meta.get("started_at"))
user_prompt = (
f"Search topic: {query}\n"
f"Session source: {source}\n"
f"Session date: {started}\n\n"
f"CONVERSATION TRANSCRIPT:\n{conversation_text}\n\n"
f"Summarize this conversation with focus on: {query}"
)
max_retries = 3
for attempt in range(max_retries):
try:
response = await async_call_llm(
task="session_search",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
temperature=0.1,
max_tokens=MAX_SUMMARY_TOKENS,
)
return response.choices[0].message.content.strip()
except RuntimeError:
logging.warning("No auxiliary model available for session summarization")
return None
except Exception as e:
if attempt < max_retries - 1:
await asyncio.sleep(1 * (attempt + 1))
else:
logging.warning(
"Session summarization failed after %d attempts: %s",
max_retries,
e,
exc_info=True,
)
return None
def session_search(
query: str,
role_filter: str = None,
limit: int = 3,
db=None,
current_session_id: str = None,
) -> str:
"""
Search past sessions and return focused summaries of matching conversations.
Uses FTS5 to find matches, then summarizes the top sessions with Gemini Flash.
The current session is excluded from results since the agent already has that context.
"""
if db is None:
return json.dumps({"success": False, "error": "Session database not available."}, ensure_ascii=False)
if not query or not query.strip():
return json.dumps({"success": False, "error": "Query cannot be empty."}, ensure_ascii=False)
query = query.strip()
limit = min(limit, 5) # Cap at 5 sessions to avoid excessive LLM calls
try:
# Parse role filter
role_list = None
if role_filter and role_filter.strip():
role_list = [r.strip() for r in role_filter.split(",") if r.strip()]
# FTS5 search -- get matches ranked by relevance
raw_results = db.search_messages(
query=query,
role_filter=role_list,
limit=50, # Get more matches to find unique sessions
offset=0,
)
if not raw_results:
return json.dumps({
"success": True,
"query": query,
"results": [],
"count": 0,
"message": "No matching sessions found.",
}, ensure_ascii=False)
# Resolve child sessions to their parent — delegation stores detailed
# content in child sessions, but the user's conversation is the parent.
def _resolve_to_parent(session_id: str) -> str:
"""Walk delegation chain to find the root parent session ID."""
visited = set()
sid = session_id
while sid and sid not in visited:
visited.add(sid)
try:
session = db.get_session(sid)
if not session:
break
parent = session.get("parent_session_id")
if parent:
sid = parent
else:
break
except Exception as e:
logging.debug(
"Error resolving parent for session %s: %s",
sid,
e,
exc_info=True,
)
break
return sid
current_lineage_root = (
_resolve_to_parent(current_session_id) if current_session_id else None
)
# Group by resolved (parent) session_id, dedup, skip the current
# session lineage. Compression and delegation create child sessions
# that still belong to the same active conversation.
seen_sessions = {}
for result in raw_results:
raw_sid = result["session_id"]
resolved_sid = _resolve_to_parent(raw_sid)
# Skip the current session lineage — the agent already has that
# context, even if older turns live in parent fragments.
if current_lineage_root and resolved_sid == current_lineage_root:
continue
if current_session_id and raw_sid == current_session_id:
continue
if resolved_sid not in seen_sessions:
result = dict(result)
result["session_id"] = resolved_sid
seen_sessions[resolved_sid] = result
if len(seen_sessions) >= limit:
break
# Prepare all sessions for parallel summarization
tasks = []
for session_id, match_info in seen_sessions.items():
try:
messages = db.get_messages_as_conversation(session_id)
if not messages:
continue
session_meta = db.get_session(session_id) or {}
conversation_text = _format_conversation(messages)
conversation_text = _truncate_around_matches(conversation_text, query)
tasks.append((session_id, match_info, conversation_text, session_meta))
except Exception as e:
logging.warning(
"Failed to prepare session %s: %s",
session_id,
e,
exc_info=True,
)
# Summarize all sessions in parallel
async def _summarize_all() -> List[Union[str, Exception]]:
"""Summarize all sessions in parallel."""
coros = [
_summarize_session(text, query, meta)
for _, _, text, meta in tasks
]
return await asyncio.gather(*coros, return_exceptions=True)
try:
asyncio.get_running_loop()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
results = pool.submit(lambda: asyncio.run(_summarize_all())).result(timeout=60)
except RuntimeError:
# No event loop running, create a new one
results = asyncio.run(_summarize_all())
except concurrent.futures.TimeoutError:
logging.warning(
"Session summarization timed out after 60 seconds",
exc_info=True,
)
return json.dumps({
"success": False,
"error": "Session summarization timed out. Try a more specific query or reduce the limit.",
}, ensure_ascii=False)
summaries = []
for (session_id, match_info, _, _), result in zip(tasks, results):
if isinstance(result, Exception):
logging.warning(
"Failed to summarize session %s: %s",
session_id,
result,
exc_info=True,
)
continue
if result:
summaries.append({
"session_id": session_id,
"when": _format_timestamp(match_info.get("session_started")),
"source": match_info.get("source", "unknown"),
"model": match_info.get("model"),
"summary": result,
})
return json.dumps({
"success": True,
"query": query,
"results": summaries,
"count": len(summaries),
"sessions_searched": len(seen_sessions),
}, ensure_ascii=False)
except Exception as e:
logging.error("Session search failed: %s", e, exc_info=True)
return json.dumps({"success": False, "error": f"Search failed: {str(e)}"}, ensure_ascii=False)
def check_session_search_requirements() -> bool:
"""Requires SQLite state database and an auxiliary text model."""
try:
from hermes_state import DEFAULT_DB_PATH
return DEFAULT_DB_PATH.parent.exists()
except ImportError:
return False
SESSION_SEARCH_SCHEMA = {
"name": "session_search",
"description": (
"Search your long-term memory of past conversations. This is your recall -- "
"every past session is searchable, and this tool summarizes what happened.\n\n"
"USE THIS PROACTIVELY when:\n"
"- The user says 'we did this before', 'remember when', 'last time', 'as I mentioned'\n"
"- The user asks about a topic you worked on before but don't have in current context\n"
"- The user references a project, person, or concept that seems familiar but isn't in memory\n"
"- You want to check if you've solved a similar problem before\n"
"- The user asks 'what did we do about X?' or 'how did we fix Y?'\n\n"
"Don't hesitate to search when it is actually cross-session -- it's fast and cheap. "
"Better to search and confirm than to guess or ask the user to repeat themselves.\n\n"
"Search syntax: keywords joined with OR for broad recall (elevenlabs OR baseten OR funding), "
"phrases for exact match (\"docker networking\"), boolean (python NOT java), prefix (deploy*). "
"IMPORTANT: Use OR between keywords for best results — FTS5 defaults to AND which misses "
"sessions that only mention some terms. If a broad OR query returns nothing, try individual "
"keyword searches in parallel. Returns summaries of the top matching sessions."
),
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query — keywords, phrases, or boolean expressions to find in past sessions.",
},
"role_filter": {
"type": "string",
"description": "Optional: only search messages from specific roles (comma-separated). E.g. 'user,assistant' to skip tool outputs.",
},
"limit": {
"type": "integer",
"description": "Max sessions to summarize (default: 3, max: 5).",
"default": 3,
},
},
"required": ["query"],
},
}
# --- Registry ---
from tools.registry import registry
registry.register(
name="session_search",
toolset="session_search",
schema=SESSION_SEARCH_SCHEMA,
handler=lambda args, **kw: session_search(
query=args.get("query", ""),
role_filter=args.get("role_filter"),
limit=args.get("limit", 3),
db=kw.get("db"),
current_session_id=kw.get("current_session_id")),
check_fn=check_session_search_requirements,
emoji="🔍",
)

View file

@ -0,0 +1,664 @@
#!/usr/bin/env python3
"""
Skill Manager Tool -- Agent-Managed Skill Creation & Editing
Allows the agent to create, update, and delete skills, turning successful
approaches into reusable procedural knowledge. New skills are created in
~/.hermes/skills/. Existing skills (bundled, hub-installed, or user-created)
can be modified or deleted wherever they live.
Skills are the agent's procedural memory: they capture *how to do a specific
type of task* based on proven experience. General memory (MEMORY.md, USER.md) is
broad and declarative. Skills are narrow and actionable.
Actions:
create -- Create a new skill (SKILL.md + directory structure)
edit -- Replace the SKILL.md content of a user skill (full rewrite)
patch -- Targeted find-and-replace within SKILL.md or any supporting file
delete -- Remove a user skill entirely
write_file -- Add/overwrite a supporting file (reference, template, script, asset)
remove_file-- Remove a supporting file from a user skill
Directory layout for user skills:
~/.hermes/skills/
my-skill/
SKILL.md
references/
templates/
scripts/
assets/
category-name/
another-skill/
SKILL.md
"""
import json
import logging
import os
import re
import shutil
import tempfile
from pathlib import Path
from typing import Dict, Any, Optional
logger = logging.getLogger(__name__)
# Import security scanner — agent-created skills get the same scrutiny as
# community hub installs.
try:
from tools.skills_guard import scan_skill, should_allow_install, format_scan_report
_GUARD_AVAILABLE = True
except ImportError:
_GUARD_AVAILABLE = False
def _security_scan_skill(skill_dir: Path) -> Optional[str]:
"""Scan a skill directory after write. Returns error string if blocked, else None."""
if not _GUARD_AVAILABLE:
return None
try:
result = scan_skill(skill_dir, source="agent-created")
allowed, reason = should_allow_install(result)
if allowed is False:
report = format_scan_report(result)
return f"Security scan blocked this skill ({reason}):\n{report}"
if allowed is None:
# "ask" — allow but include the warning so the user sees the findings
report = format_scan_report(result)
logger.warning("Agent-created skill has security findings: %s", reason)
# Don't block — return None to allow, but log the warning
return None
except Exception as e:
logger.warning("Security scan failed for %s: %s", skill_dir, e, exc_info=True)
return None
import yaml
# All skills live in ~/.hermes/skills/ (single source of truth)
HERMES_HOME = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
SKILLS_DIR = HERMES_HOME / "skills"
MAX_NAME_LENGTH = 64
MAX_DESCRIPTION_LENGTH = 1024
# Characters allowed in skill names (filesystem-safe, URL-friendly)
VALID_NAME_RE = re.compile(r'^[a-z0-9][a-z0-9._-]*$')
# Subdirectories allowed for write_file/remove_file
ALLOWED_SUBDIRS = {"references", "templates", "scripts", "assets"}
def check_skill_manage_requirements() -> bool:
"""Skill management has no external requirements -- always available."""
return True
# =============================================================================
# Validation helpers
# =============================================================================
def _validate_name(name: str) -> Optional[str]:
"""Validate a skill name. Returns error message or None if valid."""
if not name:
return "Skill name is required."
if len(name) > MAX_NAME_LENGTH:
return f"Skill name exceeds {MAX_NAME_LENGTH} characters."
if not VALID_NAME_RE.match(name):
return (
f"Invalid skill name '{name}'. Use lowercase letters, numbers, "
f"hyphens, dots, and underscores. Must start with a letter or digit."
)
return None
def _validate_frontmatter(content: str) -> Optional[str]:
"""
Validate that SKILL.md content has proper frontmatter with required fields.
Returns error message or None if valid.
"""
if not content.strip():
return "Content cannot be empty."
if not content.startswith("---"):
return "SKILL.md must start with YAML frontmatter (---). See existing skills for format."
end_match = re.search(r'\n---\s*\n', content[3:])
if not end_match:
return "SKILL.md frontmatter is not closed. Ensure you have a closing '---' line."
yaml_content = content[3:end_match.start() + 3]
try:
parsed = yaml.safe_load(yaml_content)
except yaml.YAMLError as e:
return f"YAML frontmatter parse error: {e}"
if not isinstance(parsed, dict):
return "Frontmatter must be a YAML mapping (key: value pairs)."
if "name" not in parsed:
return "Frontmatter must include 'name' field."
if "description" not in parsed:
return "Frontmatter must include 'description' field."
if len(str(parsed["description"])) > MAX_DESCRIPTION_LENGTH:
return f"Description exceeds {MAX_DESCRIPTION_LENGTH} characters."
body = content[end_match.end() + 3:].strip()
if not body:
return "SKILL.md must have content after the frontmatter (instructions, procedures, etc.)."
return None
def _resolve_skill_dir(name: str, category: str = None) -> Path:
"""Build the directory path for a new skill, optionally under a category."""
if category:
return SKILLS_DIR / category / name
return SKILLS_DIR / name
def _find_skill(name: str) -> Optional[Dict[str, Any]]:
"""
Find a skill by name in ~/.hermes/skills/.
Returns {"path": Path} or None.
"""
if not SKILLS_DIR.exists():
return None
for skill_md in SKILLS_DIR.rglob("SKILL.md"):
if skill_md.parent.name == name:
return {"path": skill_md.parent}
return None
def _validate_file_path(file_path: str) -> Optional[str]:
"""
Validate a file path for write_file/remove_file.
Must be under an allowed subdirectory and not escape the skill dir.
"""
if not file_path:
return "file_path is required."
normalized = Path(file_path)
# Prevent path traversal
if ".." in normalized.parts:
return "Path traversal ('..') is not allowed."
# Must be under an allowed subdirectory
if not normalized.parts or normalized.parts[0] not in ALLOWED_SUBDIRS:
allowed = ", ".join(sorted(ALLOWED_SUBDIRS))
return f"File must be under one of: {allowed}. Got: '{file_path}'"
# Must have a filename (not just a directory)
if len(normalized.parts) < 2:
return f"Provide a file path, not just a directory. Example: '{normalized.parts[0]}/myfile.md'"
return None
def _atomic_write_text(file_path: Path, content: str, encoding: str = "utf-8") -> None:
"""
Atomically write text content to a file.
Uses a temporary file in the same directory and os.replace() to ensure
the target file is never left in a partially-written state if the process
crashes or is interrupted.
Args:
file_path: Target file path
content: Content to write
encoding: Text encoding (default: utf-8)
"""
file_path.parent.mkdir(parents=True, exist_ok=True)
fd, temp_path = tempfile.mkstemp(
dir=str(file_path.parent),
prefix=f".{file_path.name}.tmp.",
suffix="",
)
try:
with os.fdopen(fd, "w", encoding=encoding) as f:
f.write(content)
os.replace(temp_path, file_path)
except Exception:
# Clean up temp file on error
try:
os.unlink(temp_path)
except OSError:
logger.error("Failed to remove temporary file %s during atomic write", temp_path, exc_info=True)
raise
# =============================================================================
# Core actions
# =============================================================================
def _create_skill(name: str, content: str, category: str = None) -> Dict[str, Any]:
"""Create a new user skill with SKILL.md content."""
# Validate name
err = _validate_name(name)
if err:
return {"success": False, "error": err}
# Validate content
err = _validate_frontmatter(content)
if err:
return {"success": False, "error": err}
# Check for name collisions across all directories
existing = _find_skill(name)
if existing:
return {
"success": False,
"error": f"A skill named '{name}' already exists at {existing['path']}."
}
# Create the skill directory
skill_dir = _resolve_skill_dir(name, category)
skill_dir.mkdir(parents=True, exist_ok=True)
# Write SKILL.md atomically
skill_md = skill_dir / "SKILL.md"
_atomic_write_text(skill_md, content)
# Security scan — roll back on block
scan_error = _security_scan_skill(skill_dir)
if scan_error:
shutil.rmtree(skill_dir, ignore_errors=True)
return {"success": False, "error": scan_error}
result = {
"success": True,
"message": f"Skill '{name}' created.",
"path": str(skill_dir.relative_to(SKILLS_DIR)),
"skill_md": str(skill_md),
}
if category:
result["category"] = category
result["hint"] = (
"To add reference files, templates, or scripts, use "
"skill_manage(action='write_file', name='{}', file_path='references/example.md', file_content='...')".format(name)
)
return result
def _edit_skill(name: str, content: str) -> Dict[str, Any]:
"""Replace the SKILL.md of any existing skill (full rewrite)."""
err = _validate_frontmatter(content)
if err:
return {"success": False, "error": err}
existing = _find_skill(name)
if not existing:
return {"success": False, "error": f"Skill '{name}' not found. Use skills_list() to see available skills."}
skill_md = existing["path"] / "SKILL.md"
# Back up original content for rollback
original_content = skill_md.read_text(encoding="utf-8") if skill_md.exists() else None
_atomic_write_text(skill_md, content)
# Security scan — roll back on block
scan_error = _security_scan_skill(existing["path"])
if scan_error:
if original_content is not None:
_atomic_write_text(skill_md, original_content)
return {"success": False, "error": scan_error}
return {
"success": True,
"message": f"Skill '{name}' updated.",
"path": str(existing["path"]),
}
def _patch_skill(
name: str,
old_string: str,
new_string: str,
file_path: str = None,
replace_all: bool = False,
) -> Dict[str, Any]:
"""Targeted find-and-replace within a skill file.
Defaults to SKILL.md. Use file_path to patch a supporting file instead.
Requires a unique match unless replace_all is True.
"""
if not old_string:
return {"success": False, "error": "old_string is required for 'patch'."}
if new_string is None:
return {"success": False, "error": "new_string is required for 'patch'. Use an empty string to delete matched text."}
existing = _find_skill(name)
if not existing:
return {"success": False, "error": f"Skill '{name}' not found."}
skill_dir = existing["path"]
if file_path:
# Patching a supporting file
err = _validate_file_path(file_path)
if err:
return {"success": False, "error": err}
target = skill_dir / file_path
else:
# Patching SKILL.md
target = skill_dir / "SKILL.md"
if not target.exists():
return {"success": False, "error": f"File not found: {target.relative_to(skill_dir)}"}
content = target.read_text(encoding="utf-8")
count = content.count(old_string)
if count == 0:
# Show a short preview of the file so the model can self-correct
preview = content[:500] + ("..." if len(content) > 500 else "")
return {
"success": False,
"error": "old_string not found in the file.",
"file_preview": preview,
}
if count > 1 and not replace_all:
return {
"success": False,
"error": (
f"old_string matched {count} times. Provide more surrounding context "
f"to make the match unique, or set replace_all=true to replace all occurrences."
),
"match_count": count,
}
new_content = content.replace(old_string, new_string) if replace_all else content.replace(old_string, new_string, 1)
# If patching SKILL.md, validate frontmatter is still intact
if not file_path:
err = _validate_frontmatter(new_content)
if err:
return {
"success": False,
"error": f"Patch would break SKILL.md structure: {err}",
}
original_content = content # for rollback
_atomic_write_text(target, new_content)
# Security scan — roll back on block
scan_error = _security_scan_skill(skill_dir)
if scan_error:
_atomic_write_text(target, original_content)
return {"success": False, "error": scan_error}
replacements = count if replace_all else 1
return {
"success": True,
"message": f"Patched {'SKILL.md' if not file_path else file_path} in skill '{name}' ({replacements} replacement{'s' if replacements > 1 else ''}).",
}
def _delete_skill(name: str) -> Dict[str, Any]:
"""Delete a skill."""
existing = _find_skill(name)
if not existing:
return {"success": False, "error": f"Skill '{name}' not found."}
skill_dir = existing["path"]
shutil.rmtree(skill_dir)
# Clean up empty category directories (don't remove SKILLS_DIR itself)
parent = skill_dir.parent
if parent != SKILLS_DIR and parent.exists() and not any(parent.iterdir()):
parent.rmdir()
return {
"success": True,
"message": f"Skill '{name}' deleted.",
}
def _write_file(name: str, file_path: str, file_content: str) -> Dict[str, Any]:
"""Add or overwrite a supporting file within any skill directory."""
err = _validate_file_path(file_path)
if err:
return {"success": False, "error": err}
if not file_content and file_content != "":
return {"success": False, "error": "file_content is required."}
existing = _find_skill(name)
if not existing:
return {"success": False, "error": f"Skill '{name}' not found. Create it first with action='create'."}
target = existing["path"] / file_path
target.parent.mkdir(parents=True, exist_ok=True)
# Back up for rollback
original_content = target.read_text(encoding="utf-8") if target.exists() else None
_atomic_write_text(target, file_content)
# Security scan — roll back on block
scan_error = _security_scan_skill(existing["path"])
if scan_error:
if original_content is not None:
_atomic_write_text(target, original_content)
else:
target.unlink(missing_ok=True)
return {"success": False, "error": scan_error}
return {
"success": True,
"message": f"File '{file_path}' written to skill '{name}'.",
"path": str(target),
}
def _remove_file(name: str, file_path: str) -> Dict[str, Any]:
"""Remove a supporting file from any skill directory."""
err = _validate_file_path(file_path)
if err:
return {"success": False, "error": err}
existing = _find_skill(name)
if not existing:
return {"success": False, "error": f"Skill '{name}' not found."}
skill_dir = existing["path"]
target = skill_dir / file_path
if not target.exists():
# List what's actually there for the model to see
available = []
for subdir in ALLOWED_SUBDIRS:
d = skill_dir / subdir
if d.exists():
for f in d.rglob("*"):
if f.is_file():
available.append(str(f.relative_to(skill_dir)))
return {
"success": False,
"error": f"File '{file_path}' not found in skill '{name}'.",
"available_files": available if available else None,
}
target.unlink()
# Clean up empty subdirectories
parent = target.parent
if parent != skill_dir and parent.exists() and not any(parent.iterdir()):
parent.rmdir()
return {
"success": True,
"message": f"File '{file_path}' removed from skill '{name}'.",
}
# =============================================================================
# Main entry point
# =============================================================================
def skill_manage(
action: str,
name: str,
content: str = None,
category: str = None,
file_path: str = None,
file_content: str = None,
old_string: str = None,
new_string: str = None,
replace_all: bool = False,
) -> str:
"""
Manage user-created skills. Dispatches to the appropriate action handler.
Returns JSON string with results.
"""
if action == "create":
if not content:
return json.dumps({"success": False, "error": "content is required for 'create'. Provide the full SKILL.md text (frontmatter + body)."}, ensure_ascii=False)
result = _create_skill(name, content, category)
elif action == "edit":
if not content:
return json.dumps({"success": False, "error": "content is required for 'edit'. Provide the full updated SKILL.md text."}, ensure_ascii=False)
result = _edit_skill(name, content)
elif action == "patch":
if not old_string:
return json.dumps({"success": False, "error": "old_string is required for 'patch'. Provide the text to find."}, ensure_ascii=False)
if new_string is None:
return json.dumps({"success": False, "error": "new_string is required for 'patch'. Use empty string to delete matched text."}, ensure_ascii=False)
result = _patch_skill(name, old_string, new_string, file_path, replace_all)
elif action == "delete":
result = _delete_skill(name)
elif action == "write_file":
if not file_path:
return json.dumps({"success": False, "error": "file_path is required for 'write_file'. Example: 'references/api-guide.md'"}, ensure_ascii=False)
if file_content is None:
return json.dumps({"success": False, "error": "file_content is required for 'write_file'."}, ensure_ascii=False)
result = _write_file(name, file_path, file_content)
elif action == "remove_file":
if not file_path:
return json.dumps({"success": False, "error": "file_path is required for 'remove_file'."}, ensure_ascii=False)
result = _remove_file(name, file_path)
else:
result = {"success": False, "error": f"Unknown action '{action}'. Use: create, edit, patch, delete, write_file, remove_file"}
return json.dumps(result, ensure_ascii=False)
# =============================================================================
# OpenAI Function-Calling Schema
# =============================================================================
SKILL_MANAGE_SCHEMA = {
"name": "skill_manage",
"description": (
"Manage skills (create, update, delete). Skills are your procedural "
"memory — reusable approaches for recurring task types. "
"New skills go to ~/.hermes/skills/; existing skills can be modified wherever they live.\n\n"
"Actions: create (full SKILL.md + optional category), "
"patch (old_string/new_string — preferred for fixes), "
"edit (full SKILL.md rewrite — major overhauls only), "
"delete, write_file, remove_file.\n\n"
"Create when: complex task succeeded (5+ calls), errors overcome, "
"user-corrected approach worked, non-trivial workflow discovered, "
"or user asks you to remember a procedure.\n"
"Update when: instructions stale/wrong, OS-specific failures, "
"missing steps or pitfalls found during use. "
"If you used a skill and hit issues not covered by it, patch it immediately.\n\n"
"After difficult/iterative tasks, offer to save as a skill. "
"Skip for simple one-offs. Confirm with user before creating/deleting.\n\n"
"Good skills: trigger conditions, numbered steps with exact commands, "
"pitfalls section, verification steps. Use skill_view() to see format examples."
),
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": ["create", "patch", "edit", "delete", "write_file", "remove_file"],
"description": "The action to perform."
},
"name": {
"type": "string",
"description": (
"Skill name (lowercase, hyphens/underscores, max 64 chars). "
"Must match an existing skill for patch/edit/delete/write_file/remove_file."
)
},
"content": {
"type": "string",
"description": (
"Full SKILL.md content (YAML frontmatter + markdown body). "
"Required for 'create' and 'edit'. For 'edit', read the skill "
"first with skill_view() and provide the complete updated text."
)
},
"old_string": {
"type": "string",
"description": (
"Text to find in the file (required for 'patch'). Must be unique "
"unless replace_all=true. Include enough surrounding context to "
"ensure uniqueness."
)
},
"new_string": {
"type": "string",
"description": (
"Replacement text (required for 'patch'). Can be empty string "
"to delete the matched text."
)
},
"replace_all": {
"type": "boolean",
"description": "For 'patch': replace all occurrences instead of requiring a unique match (default: false)."
},
"category": {
"type": "string",
"description": (
"Optional category/domain for organizing the skill (e.g., 'devops', "
"'data-science', 'mlops'). Creates a subdirectory grouping. "
"Only used with 'create'."
)
},
"file_path": {
"type": "string",
"description": (
"Path to a supporting file within the skill directory. "
"For 'write_file'/'remove_file': required, must be under references/, "
"templates/, scripts/, or assets/. "
"For 'patch': optional, defaults to SKILL.md if omitted."
)
},
"file_content": {
"type": "string",
"description": "Content for the file. Required for 'write_file'."
},
},
"required": ["action", "name"],
},
}
# --- Registry ---
from tools.registry import registry
registry.register(
name="skill_manage",
toolset="skills",
schema=SKILL_MANAGE_SCHEMA,
handler=lambda args, **kw: skill_manage(
action=args.get("action", ""),
name=args.get("name", ""),
content=args.get("content"),
category=args.get("category"),
file_path=args.get("file_path"),
file_content=args.get("file_content"),
old_string=args.get("old_string"),
new_string=args.get("new_string"),
replace_all=args.get("replace_all", False)),
emoji="📝",
)

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,287 @@
#!/usr/bin/env python3
"""
Skills Sync -- Manifest-based seeding and updating of bundled skills.
Copies bundled skills from the repo's skills/ directory into ~/.hermes/skills/
and uses a manifest to track which skills have been synced and their origin hash.
Manifest format (v2): each line is "skill_name:origin_hash" where origin_hash
is the MD5 of the bundled skill at the time it was last synced to the user dir.
Old v1 manifests (plain names without hashes) are auto-migrated.
Update logic:
- NEW skills (not in manifest): copied to user dir, origin hash recorded.
- EXISTING skills (in manifest, present in user dir):
* If user copy matches origin hash: user hasn't modified it → safe to
update from bundled if bundled changed. New origin hash recorded.
* If user copy differs from origin hash: user customized it SKIP.
- DELETED by user (in manifest, absent from user dir): respected, not re-added.
- REMOVED from bundled (in manifest, gone from repo): cleaned from manifest.
The manifest lives at ~/.hermes/skills/.bundled_manifest.
"""
import hashlib
import logging
import os
import shutil
from pathlib import Path
from typing import Dict, List, Tuple
logger = logging.getLogger(__name__)
HERMES_HOME = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
SKILLS_DIR = HERMES_HOME / "skills"
MANIFEST_FILE = SKILLS_DIR / ".bundled_manifest"
def _get_bundled_dir() -> Path:
"""Locate the bundled skills/ directory in the repo."""
return Path(__file__).parent.parent / "skills"
def _read_manifest() -> Dict[str, str]:
"""
Read the manifest as a dict of {skill_name: origin_hash}.
Handles both v1 (plain names) and v2 (name:hash) formats.
v1 entries get an empty hash string which triggers migration on next sync.
"""
if not MANIFEST_FILE.exists():
return {}
try:
result = {}
for line in MANIFEST_FILE.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line:
continue
if ":" in line:
# v2 format: name:hash
name, _, hash_val = line.partition(":")
result[name.strip()] = hash_val.strip()
else:
# v1 format: plain name — empty hash triggers migration
result[line] = ""
return result
except (OSError, IOError):
return {}
def _write_manifest(entries: Dict[str, str]):
"""Write the manifest file atomically in v2 format (name:hash).
Uses a temp file + os.replace() to avoid corruption if the process
crashes or is interrupted mid-write.
"""
import tempfile
MANIFEST_FILE.parent.mkdir(parents=True, exist_ok=True)
data = "\n".join(f"{name}:{hash_val}" for name, hash_val in sorted(entries.items())) + "\n"
try:
fd, tmp_path = tempfile.mkstemp(
dir=str(MANIFEST_FILE.parent),
prefix=".bundled_manifest_",
suffix=".tmp",
)
try:
with os.fdopen(fd, "w", encoding="utf-8") as f:
f.write(data)
f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, MANIFEST_FILE)
except BaseException:
try:
os.unlink(tmp_path)
except OSError:
pass
raise
except Exception as e:
logger.debug("Failed to write skills manifest %s: %s", MANIFEST_FILE, e, exc_info=True)
def _discover_bundled_skills(bundled_dir: Path) -> List[Tuple[str, Path]]:
"""
Find all SKILL.md files in the bundled directory.
Returns list of (skill_name, skill_directory_path) tuples.
"""
skills = []
if not bundled_dir.exists():
return skills
for skill_md in bundled_dir.rglob("SKILL.md"):
path_str = str(skill_md)
if "/.git/" in path_str or "/.github/" in path_str or "/.hub/" in path_str:
continue
skill_dir = skill_md.parent
skill_name = skill_dir.name
skills.append((skill_name, skill_dir))
return skills
def _compute_relative_dest(skill_dir: Path, bundled_dir: Path) -> Path:
"""
Compute the destination path in SKILLS_DIR preserving the category structure.
e.g., bundled/skills/mlops/axolotl -> ~/.hermes/skills/mlops/axolotl
"""
rel = skill_dir.relative_to(bundled_dir)
return SKILLS_DIR / rel
def _dir_hash(directory: Path) -> str:
"""Compute a hash of all file contents in a directory for change detection."""
hasher = hashlib.md5()
try:
for fpath in sorted(directory.rglob("*")):
if fpath.is_file():
rel = fpath.relative_to(directory)
hasher.update(str(rel).encode("utf-8"))
hasher.update(fpath.read_bytes())
except (OSError, IOError):
pass
return hasher.hexdigest()
def sync_skills(quiet: bool = False) -> dict:
"""
Sync bundled skills into ~/.hermes/skills/ using the manifest.
Returns:
dict with keys: copied (list), updated (list), skipped (int),
user_modified (list), cleaned (list), total_bundled (int)
"""
bundled_dir = _get_bundled_dir()
if not bundled_dir.exists():
return {
"copied": [], "updated": [], "skipped": 0,
"user_modified": [], "cleaned": [], "total_bundled": 0,
}
SKILLS_DIR.mkdir(parents=True, exist_ok=True)
manifest = _read_manifest()
bundled_skills = _discover_bundled_skills(bundled_dir)
bundled_names = {name for name, _ in bundled_skills}
copied = []
updated = []
user_modified = []
skipped = 0
for skill_name, skill_src in bundled_skills:
dest = _compute_relative_dest(skill_src, bundled_dir)
bundled_hash = _dir_hash(skill_src)
if skill_name not in manifest:
# ── New skill — never offered before ──
try:
if dest.exists():
# User already has a skill with the same name — don't overwrite
skipped += 1
manifest[skill_name] = bundled_hash
else:
dest.parent.mkdir(parents=True, exist_ok=True)
shutil.copytree(skill_src, dest)
copied.append(skill_name)
manifest[skill_name] = bundled_hash
if not quiet:
print(f" + {skill_name}")
except (OSError, IOError) as e:
if not quiet:
print(f" ! Failed to copy {skill_name}: {e}")
# Do NOT add to manifest — next sync should retry
elif dest.exists():
# ── Existing skill — in manifest AND on disk ──
origin_hash = manifest.get(skill_name, "")
user_hash = _dir_hash(dest)
if not origin_hash:
# v1 migration: no origin hash recorded. Set baseline from
# user's current copy so future syncs can detect modifications.
manifest[skill_name] = user_hash
if user_hash == bundled_hash:
skipped += 1 # already in sync
else:
# Can't tell if user modified or bundled changed — be safe
skipped += 1
continue
if user_hash != origin_hash:
# User modified this skill — don't overwrite their changes
user_modified.append(skill_name)
if not quiet:
print(f" ~ {skill_name} (user-modified, skipping)")
continue
# User copy matches origin — check if bundled has a newer version
if bundled_hash != origin_hash:
try:
# Move old copy to a backup so we can restore on failure
backup = dest.with_suffix(".bak")
shutil.move(str(dest), str(backup))
try:
shutil.copytree(skill_src, dest)
manifest[skill_name] = bundled_hash
updated.append(skill_name)
if not quiet:
print(f"{skill_name} (updated)")
# Remove backup after successful copy
shutil.rmtree(backup, ignore_errors=True)
except (OSError, IOError):
# Restore from backup
if backup.exists() and not dest.exists():
shutil.move(str(backup), str(dest))
raise
except (OSError, IOError) as e:
if not quiet:
print(f" ! Failed to update {skill_name}: {e}")
else:
skipped += 1 # bundled unchanged, user unchanged
else:
# ── In manifest but not on disk — user deleted it ──
skipped += 1
# Clean stale manifest entries (skills removed from bundled dir)
cleaned = sorted(set(manifest.keys()) - bundled_names)
for name in cleaned:
del manifest[name]
# Also copy DESCRIPTION.md files for categories (if not already present)
for desc_md in bundled_dir.rglob("DESCRIPTION.md"):
rel = desc_md.relative_to(bundled_dir)
dest_desc = SKILLS_DIR / rel
if not dest_desc.exists():
try:
dest_desc.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(desc_md, dest_desc)
except (OSError, IOError) as e:
logger.debug("Could not copy %s: %s", desc_md, e)
_write_manifest(manifest)
return {
"copied": copied,
"updated": updated,
"skipped": skipped,
"user_modified": user_modified,
"cleaned": cleaned,
"total_bundled": len(bundled_skills),
}
if __name__ == "__main__":
print("Syncing bundled skills into ~/.hermes/skills/ ...")
result = sync_skills(quiet=False)
parts = [
f"{len(result['copied'])} new",
f"{len(result['updated'])} updated",
f"{result['skipped']} unchanged",
]
if result["user_modified"]:
parts.append(f"{len(result['user_modified'])} user-modified (kept)")
if result["cleaned"]:
parts.append(f"{len(result['cleaned'])} cleaned from manifest")
print(f"\nDone: {', '.join(parts)}. {result['total_bundled']} total bundled.")

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,674 @@
"""Tirith pre-exec security scanning wrapper.
Runs the tirith binary as a subprocess to scan commands for content-level
threats (homograph URLs, pipe-to-interpreter, terminal injection, etc.).
Exit code is the verdict source of truth:
0 = allow, 1 = block, 2 = warn
JSON stdout enriches findings/summary but never overrides the verdict.
Operational failures (spawn error, timeout, unknown exit code) respect
the fail_open config setting. Programming errors propagate.
Auto-install: if tirith is not found on PATH or at the configured path,
it is automatically downloaded from GitHub releases to $HERMES_HOME/bin/tirith.
The download always verifies SHA-256 checksums. When cosign is available on
PATH, provenance verification (GitHub Actions workflow signature) is also
performed. If cosign is not installed, the download proceeds with SHA-256
verification only still secure via HTTPS + checksum, just without supply
chain provenance proof. Installation runs in a background thread so startup
never blocks.
"""
import hashlib
import json
import logging
import os
import platform
import shutil
import stat
import subprocess
import tarfile
import tempfile
import threading
import time
import urllib.request
logger = logging.getLogger(__name__)
_REPO = "sheeki03/tirith"
# Cosign provenance verification — pinned to the specific release workflow
_COSIGN_IDENTITY_REGEXP = f"^https://github.com/{_REPO}/\\.github/workflows/release\\.yml@refs/tags/v"
_COSIGN_ISSUER = "https://token.actions.githubusercontent.com"
# ---------------------------------------------------------------------------
# Config helpers
# ---------------------------------------------------------------------------
def _env_bool(key: str, default: bool) -> bool:
val = os.getenv(key)
if val is None:
return default
return val.lower() in ("1", "true", "yes")
def _env_int(key: str, default: int) -> int:
val = os.getenv(key)
if val is None:
return default
try:
return int(val)
except ValueError:
return default
def _load_security_config() -> dict:
"""Load security settings from config.yaml, with env var overrides."""
defaults = {
"tirith_enabled": True,
"tirith_path": "tirith",
"tirith_timeout": 5,
"tirith_fail_open": True,
}
try:
from hermes_cli.config import load_config
cfg = load_config().get("security", {}) or {}
except Exception:
cfg = {}
return {
"tirith_enabled": _env_bool("TIRITH_ENABLED", cfg.get("tirith_enabled", defaults["tirith_enabled"])),
"tirith_path": os.getenv("TIRITH_BIN", cfg.get("tirith_path", defaults["tirith_path"])),
"tirith_timeout": _env_int("TIRITH_TIMEOUT", cfg.get("tirith_timeout", defaults["tirith_timeout"])),
"tirith_fail_open": _env_bool("TIRITH_FAIL_OPEN", cfg.get("tirith_fail_open", defaults["tirith_fail_open"])),
}
# ---------------------------------------------------------------------------
# Auto-install
# ---------------------------------------------------------------------------
# Cached path after first resolution (avoids repeated shutil.which per command).
# _INSTALL_FAILED means "we tried and failed" — prevents retry on every command.
_resolved_path: str | None | bool = None
_INSTALL_FAILED = False # sentinel: distinct from "not yet tried"
_install_failure_reason: str = "" # reason tag when _resolved_path is _INSTALL_FAILED
# Background install thread coordination
_install_lock = threading.Lock()
_install_thread: threading.Thread | None = None
# Disk-persistent failure marker — avoids retry across process restarts
_MARKER_TTL = 86400 # 24 hours
def _get_hermes_home() -> str:
"""Return the Hermes home directory, respecting HERMES_HOME env var.
Matches the convention used throughout the codebase (hermes_cli.config,
cli.py, gateway/run.py, etc.) so tirith state stays inside the active
profile and tests get automatic isolation via conftest's HERMES_HOME
monkeypatch.
"""
return os.getenv("HERMES_HOME") or os.path.join(os.path.expanduser("~"), ".hermes")
def _failure_marker_path() -> str:
"""Return the path to the install-failure marker file."""
return os.path.join(_get_hermes_home(), ".tirith-install-failed")
def _read_failure_reason() -> str | None:
"""Read the failure reason from the disk marker.
Returns the reason string, or None if the marker doesn't exist or is
older than _MARKER_TTL.
"""
try:
p = _failure_marker_path()
mtime = os.path.getmtime(p)
if (time.time() - mtime) >= _MARKER_TTL:
return None
with open(p, "r") as f:
return f.read().strip()
except OSError:
return None
def _is_install_failed_on_disk() -> bool:
"""Check if a recent install failure was persisted to disk.
Returns False (allowing retry) when:
- No marker exists
- Marker is older than _MARKER_TTL (24h)
- Marker reason is 'cosign_missing' and cosign is now on PATH
"""
reason = _read_failure_reason()
if reason is None:
return False
if reason == "cosign_missing" and shutil.which("cosign"):
_clear_install_failed()
return False
return True
def _mark_install_failed(reason: str = ""):
"""Persist install failure to disk to avoid retry on next process.
Args:
reason: Short tag identifying the failure cause. Use "cosign_missing"
when cosign is not on PATH so the marker can be auto-cleared
once cosign becomes available.
"""
try:
p = _failure_marker_path()
os.makedirs(os.path.dirname(p), exist_ok=True)
with open(p, "w") as f:
f.write(reason)
except OSError:
pass
def _clear_install_failed():
"""Remove the failure marker after successful install."""
try:
os.unlink(_failure_marker_path())
except OSError:
pass
def _hermes_bin_dir() -> str:
"""Return $HERMES_HOME/bin, creating it if needed."""
d = os.path.join(_get_hermes_home(), "bin")
os.makedirs(d, exist_ok=True)
return d
def _detect_target() -> str | None:
"""Return the Rust target triple for the current platform, or None."""
system = platform.system()
machine = platform.machine().lower()
if system == "Darwin":
plat = "apple-darwin"
elif system == "Linux":
plat = "unknown-linux-gnu"
else:
return None
if machine in ("x86_64", "amd64"):
arch = "x86_64"
elif machine in ("aarch64", "arm64"):
arch = "aarch64"
else:
return None
return f"{arch}-{plat}"
def _download_file(url: str, dest: str, timeout: int = 10):
"""Download a URL to a local file."""
req = urllib.request.Request(url)
token = os.getenv("GITHUB_TOKEN")
if token:
req.add_header("Authorization", f"token {token}")
with urllib.request.urlopen(req, timeout=timeout) as resp, open(dest, "wb") as f:
shutil.copyfileobj(resp, f)
def _verify_cosign(checksums_path: str, sig_path: str, cert_path: str) -> bool | None:
"""Verify cosign provenance signature on checksums.txt.
Returns:
True cosign verified successfully
False cosign found but verification failed
None cosign not available (not on PATH, or execution failed)
The caller treats both False and None as "abort auto-install" only
True allows the install to proceed.
"""
cosign = shutil.which("cosign")
if not cosign:
logger.info("cosign not found on PATH")
return None
try:
result = subprocess.run(
[cosign, "verify-blob",
"--certificate", cert_path,
"--signature", sig_path,
"--certificate-identity-regexp", _COSIGN_IDENTITY_REGEXP,
"--certificate-oidc-issuer", _COSIGN_ISSUER,
checksums_path],
capture_output=True,
text=True,
timeout=15,
)
if result.returncode == 0:
logger.info("cosign provenance verification passed")
return True
else:
logger.warning("cosign verification failed (exit %d): %s",
result.returncode, result.stderr.strip())
return False
except (OSError, subprocess.TimeoutExpired) as exc:
logger.warning("cosign execution failed: %s", exc)
return None
def _verify_checksum(archive_path: str, checksums_path: str, archive_name: str) -> bool:
"""Verify SHA-256 of the archive against checksums.txt."""
expected = None
with open(checksums_path) as f:
for line in f:
# Format: "<hash> <filename>"
parts = line.strip().split(" ", 1)
if len(parts) == 2 and parts[1] == archive_name:
expected = parts[0]
break
if not expected:
logger.warning("No checksum entry for %s", archive_name)
return False
sha = hashlib.sha256()
with open(archive_path, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
sha.update(chunk)
actual = sha.hexdigest()
if actual != expected:
logger.warning("Checksum mismatch: expected %s, got %s", expected, actual)
return False
return True
def _install_tirith(*, log_failures: bool = True) -> tuple[str | None, str]:
"""Download and install tirith to $HERMES_HOME/bin/tirith.
Verifies provenance via cosign and SHA-256 checksum.
Returns (installed_path, failure_reason). On success failure_reason is "".
failure_reason is a short tag used by the disk marker to decide if the
failure is retryable (e.g. "cosign_missing" clears when cosign appears).
"""
log = logger.warning if log_failures else logger.debug
target = _detect_target()
if not target:
logger.info("tirith auto-install: unsupported platform %s/%s",
platform.system(), platform.machine())
return None, "unsupported_platform"
archive_name = f"tirith-{target}.tar.gz"
base_url = f"https://github.com/{_REPO}/releases/latest/download"
tmpdir = tempfile.mkdtemp(prefix="tirith-install-")
try:
archive_path = os.path.join(tmpdir, archive_name)
checksums_path = os.path.join(tmpdir, "checksums.txt")
sig_path = os.path.join(tmpdir, "checksums.txt.sig")
cert_path = os.path.join(tmpdir, "checksums.txt.pem")
logger.info("tirith not found — downloading latest release for %s...", target)
try:
_download_file(f"{base_url}/{archive_name}", archive_path)
_download_file(f"{base_url}/checksums.txt", checksums_path)
except Exception as exc:
log("tirith download failed: %s", exc)
return None, "download_failed"
# Cosign provenance verification — preferred but not mandatory.
# When cosign is available, we verify that the release was produced
# by the expected GitHub Actions workflow (full supply chain proof).
# Without cosign, SHA-256 checksum + HTTPS still provides integrity
# and transport-level authenticity.
cosign_verified = False
if shutil.which("cosign"):
try:
_download_file(f"{base_url}/checksums.txt.sig", sig_path)
_download_file(f"{base_url}/checksums.txt.pem", cert_path)
except Exception as exc:
logger.info("cosign artifacts unavailable (%s), proceeding with SHA-256 only", exc)
else:
cosign_result = _verify_cosign(checksums_path, sig_path, cert_path)
if cosign_result is True:
cosign_verified = True
elif cosign_result is False:
# Verification explicitly rejected — abort, the release
# may have been tampered with.
log("tirith install aborted: cosign provenance verification failed")
return None, "cosign_verification_failed"
else:
# None = execution failure (timeout/OSError) — proceed
# with SHA-256 only since cosign itself is broken.
logger.info("cosign execution failed, proceeding with SHA-256 only")
else:
logger.info("cosign not on PATH — installing tirith with SHA-256 verification only "
"(install cosign for full supply chain verification)")
if not _verify_checksum(archive_path, checksums_path, archive_name):
return None, "checksum_failed"
with tarfile.open(archive_path, "r:gz") as tar:
# Extract only the tirith binary (safety: reject paths with ..)
for member in tar.getmembers():
if member.name == "tirith" or member.name.endswith("/tirith"):
if ".." in member.name:
continue
member.name = "tirith"
tar.extract(member, tmpdir)
break
else:
log("tirith binary not found in archive")
return None, "binary_not_in_archive"
src = os.path.join(tmpdir, "tirith")
dest = os.path.join(_hermes_bin_dir(), "tirith")
shutil.move(src, dest)
os.chmod(dest, os.stat(dest).st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
verification = "cosign + SHA-256" if cosign_verified else "SHA-256 only"
logger.info("tirith installed to %s (%s)", dest, verification)
return dest, ""
finally:
shutil.rmtree(tmpdir, ignore_errors=True)
def _is_explicit_path(configured_path: str) -> bool:
"""Return True if the user explicitly configured a non-default tirith path."""
return configured_path != "tirith"
def _resolve_tirith_path(configured_path: str) -> str:
"""Resolve the tirith binary path, auto-installing if necessary.
If the user explicitly set a path (anything other than the bare "tirith"
default), that path is authoritative we never fall through to
auto-download a different binary.
For the default "tirith":
1. PATH lookup via shutil.which
2. $HERMES_HOME/bin/tirith (previously auto-installed)
3. Auto-install from GitHub releases $HERMES_HOME/bin/tirith
Failed installs are cached for the process lifetime (and persisted to
disk for 24h) to avoid repeated network attempts.
"""
global _resolved_path, _install_failure_reason
# Fast path: successfully resolved on a previous call.
if _resolved_path is not None and _resolved_path is not _INSTALL_FAILED:
return _resolved_path
expanded = os.path.expanduser(configured_path)
explicit = _is_explicit_path(configured_path)
install_failed = _resolved_path is _INSTALL_FAILED
# Explicit path: check it and stop. Never auto-download a replacement.
if explicit:
if os.path.isfile(expanded) and os.access(expanded, os.X_OK):
_resolved_path = expanded
return expanded
# Also try shutil.which in case it's a bare name on PATH
found = shutil.which(expanded)
if found:
_resolved_path = found
return found
logger.warning("Configured tirith path %r not found; scanning disabled", configured_path)
_resolved_path = _INSTALL_FAILED
_install_failure_reason = "explicit_path_missing"
return expanded
# Default "tirith" — always re-run cheap local checks so a manual
# install is picked up even after a previous network failure (P2 fix:
# long-lived gateway/CLI recovers without restart).
found = shutil.which("tirith")
if found:
_resolved_path = found
_install_failure_reason = ""
_clear_install_failed()
return found
hermes_bin = os.path.join(_hermes_bin_dir(), "tirith")
if os.path.isfile(hermes_bin) and os.access(hermes_bin, os.X_OK):
_resolved_path = hermes_bin
_install_failure_reason = ""
_clear_install_failed()
return hermes_bin
# Local checks failed. If a previous install attempt already failed,
# skip the network retry — UNLESS the failure was "cosign_missing" and
# cosign is now available (retryable cause resolved in-process).
if install_failed:
if _install_failure_reason == "cosign_missing" and shutil.which("cosign"):
# Retryable cause resolved — clear sentinel and fall through to retry
_resolved_path = None
_install_failure_reason = ""
_clear_install_failed()
install_failed = False
else:
return expanded
# If a background install thread is running, don't start a parallel one —
# return the configured path; the OSError handler in check_command_security
# will apply fail_open until the thread finishes.
if _install_thread is not None and _install_thread.is_alive():
return expanded
# Check disk failure marker before attempting network download.
# Preserve the marker's real reason so in-memory retry logic can
# detect retryable causes (e.g. cosign_missing) without restart.
disk_reason = _read_failure_reason()
if disk_reason is not None and _is_install_failed_on_disk():
_resolved_path = _INSTALL_FAILED
_install_failure_reason = disk_reason
return expanded
installed, reason = _install_tirith()
if installed:
_resolved_path = installed
_install_failure_reason = ""
_clear_install_failed()
return installed
# Install failed — cache the miss and persist reason to disk
_resolved_path = _INSTALL_FAILED
_install_failure_reason = reason
_mark_install_failed(reason)
return expanded
def _background_install(*, log_failures: bool = True):
"""Background thread target: download and install tirith."""
global _resolved_path, _install_failure_reason
with _install_lock:
# Double-check after acquiring lock (another thread may have resolved)
if _resolved_path is not None:
return
# Re-check local paths (may have been installed by another process)
found = shutil.which("tirith")
if found:
_resolved_path = found
_install_failure_reason = ""
return
hermes_bin = os.path.join(_hermes_bin_dir(), "tirith")
if os.path.isfile(hermes_bin) and os.access(hermes_bin, os.X_OK):
_resolved_path = hermes_bin
_install_failure_reason = ""
return
installed, reason = _install_tirith(log_failures=log_failures)
if installed:
_resolved_path = installed
_install_failure_reason = ""
_clear_install_failed()
else:
_resolved_path = _INSTALL_FAILED
_install_failure_reason = reason
_mark_install_failed(reason)
def ensure_installed(*, log_failures: bool = True):
"""Ensure tirith is available, downloading in background if needed.
Quick PATH/local checks are synchronous; network download runs in a
daemon thread so startup never blocks. Safe to call multiple times.
Returns the resolved path immediately if available, or None.
"""
global _resolved_path, _install_thread, _install_failure_reason
cfg = _load_security_config()
if not cfg["tirith_enabled"]:
return None
# Already resolved from a previous call
if _resolved_path is not None and _resolved_path is not _INSTALL_FAILED:
path = _resolved_path
if os.path.isfile(path) and os.access(path, os.X_OK):
return path
return None
configured_path = cfg["tirith_path"]
explicit = _is_explicit_path(configured_path)
expanded = os.path.expanduser(configured_path)
# Explicit path: synchronous check only, no download
if explicit:
if os.path.isfile(expanded) and os.access(expanded, os.X_OK):
_resolved_path = expanded
return expanded
found = shutil.which(expanded)
if found:
_resolved_path = found
return found
_resolved_path = _INSTALL_FAILED
_install_failure_reason = "explicit_path_missing"
return None
# Default "tirith" — quick local checks first (no network)
found = shutil.which("tirith")
if found:
_resolved_path = found
_install_failure_reason = ""
_clear_install_failed()
return found
hermes_bin = os.path.join(_hermes_bin_dir(), "tirith")
if os.path.isfile(hermes_bin) and os.access(hermes_bin, os.X_OK):
_resolved_path = hermes_bin
_install_failure_reason = ""
_clear_install_failed()
return hermes_bin
# If previously failed in-memory, check if the cause is now resolved
if _resolved_path is _INSTALL_FAILED:
if _install_failure_reason == "cosign_missing" and shutil.which("cosign"):
_resolved_path = None
_install_failure_reason = ""
_clear_install_failed()
else:
return None
# Check disk failure marker (skip network attempt for 24h, unless
# the cosign_missing reason was resolved — handled by _is_install_failed_on_disk).
# Preserve the marker's real reason for in-memory retry logic.
disk_reason = _read_failure_reason()
if disk_reason is not None and _is_install_failed_on_disk():
_resolved_path = _INSTALL_FAILED
_install_failure_reason = disk_reason
return None
# Need to download — launch background thread so startup doesn't block
if _install_thread is None or not _install_thread.is_alive():
_install_thread = threading.Thread(
target=_background_install,
kwargs={"log_failures": log_failures},
daemon=True,
)
_install_thread.start()
return None # Not available yet; commands will fail-open until ready
# ---------------------------------------------------------------------------
# Main API
# ---------------------------------------------------------------------------
_MAX_FINDINGS = 50
_MAX_SUMMARY_LEN = 500
def check_command_security(command: str) -> dict:
"""Run tirith security scan on a command.
Exit code determines action (0=allow, 1=block, 2=warn). JSON enriches
findings/summary. Spawn failures and timeouts respect fail_open config.
Programming errors propagate.
Returns:
{"action": "allow"|"warn"|"block", "findings": [...], "summary": str}
"""
cfg = _load_security_config()
if not cfg["tirith_enabled"]:
return {"action": "allow", "findings": [], "summary": ""}
tirith_path = _resolve_tirith_path(cfg["tirith_path"])
timeout = cfg["tirith_timeout"]
fail_open = cfg["tirith_fail_open"]
try:
result = subprocess.run(
[tirith_path, "check", "--json", "--non-interactive",
"--shell", "posix", "--", command],
capture_output=True,
text=True,
timeout=timeout,
)
except OSError as exc:
# Covers FileNotFoundError, PermissionError, exec format error
logger.warning("tirith spawn failed: %s", exc)
if fail_open:
return {"action": "allow", "findings": [], "summary": f"tirith unavailable: {exc}"}
return {"action": "block", "findings": [], "summary": f"tirith spawn failed (fail-closed): {exc}"}
except subprocess.TimeoutExpired:
logger.warning("tirith timed out after %ds", timeout)
if fail_open:
return {"action": "allow", "findings": [], "summary": f"tirith timed out ({timeout}s)"}
return {"action": "block", "findings": [], "summary": f"tirith timed out (fail-closed)"}
# Map exit code to action
exit_code = result.returncode
if exit_code == 0:
action = "allow"
elif exit_code == 1:
action = "block"
elif exit_code == 2:
action = "warn"
else:
# Unknown exit code — respect fail_open
logger.warning("tirith returned unexpected exit code %d", exit_code)
if fail_open:
return {"action": "allow", "findings": [], "summary": f"tirith exit code {exit_code} (fail-open)"}
return {"action": "block", "findings": [], "summary": f"tirith exit code {exit_code} (fail-closed)"}
# Parse JSON for enrichment (never overrides the exit code verdict)
findings = []
summary = ""
try:
data = json.loads(result.stdout) if result.stdout.strip() else {}
raw_findings = data.get("findings", [])
findings = raw_findings[:_MAX_FINDINGS]
summary = (data.get("summary", "") or "")[:_MAX_SUMMARY_LEN]
except (json.JSONDecodeError, AttributeError):
# JSON parse failure degrades findings/summary, not the verdict
logger.debug("tirith JSON parse failed, using exit code only")
if action == "block":
summary = "security issue detected (details unavailable)"
elif action == "warn":
summary = "security warning detected (details unavailable)"
return {"action": action, "findings": findings, "summary": summary}

View file

@ -0,0 +1,268 @@
#!/usr/bin/env python3
"""
Todo Tool Module - Planning & Task Management
Provides an in-memory task list the agent uses to decompose complex tasks,
track progress, and maintain focus across long conversations. The state
lives on the AIAgent instance (one per session) and is re-injected into
the conversation after context compression events.
Design:
- Single `todo` tool: provide `todos` param to write, omit to read
- Every call returns the full current list
- No system prompt mutation, no tool response modification
- Behavioral guidance lives entirely in the tool schema description
"""
import json
from typing import Dict, Any, List, Optional
# Valid status values for todo items
VALID_STATUSES = {"pending", "in_progress", "completed", "cancelled"}
class TodoStore:
"""
In-memory todo list. One instance per AIAgent (one per session).
Items are ordered -- list position is priority. Each item has:
- id: unique string identifier (agent-chosen)
- content: task description
- status: pending | in_progress | completed | cancelled
"""
def __init__(self):
self._items: List[Dict[str, str]] = []
def write(self, todos: List[Dict[str, Any]], merge: bool = False) -> List[Dict[str, str]]:
"""
Write todos. Returns the full current list after writing.
Args:
todos: list of {id, content, status} dicts
merge: if False, replace the entire list. If True, update
existing items by id and append new ones.
"""
if not merge:
# Replace mode: new list entirely
self._items = [self._validate(t) for t in todos]
else:
# Merge mode: update existing items by id, append new ones
existing = {item["id"]: item for item in self._items}
for t in todos:
item_id = str(t.get("id", "")).strip()
if not item_id:
continue # Can't merge without an id
if item_id in existing:
# Update only the fields the LLM actually provided
if "content" in t and t["content"]:
existing[item_id]["content"] = str(t["content"]).strip()
if "status" in t and t["status"]:
status = str(t["status"]).strip().lower()
if status in VALID_STATUSES:
existing[item_id]["status"] = status
else:
# New item -- validate fully and append to end
validated = self._validate(t)
existing[validated["id"]] = validated
self._items.append(validated)
# Rebuild _items preserving order for existing items
seen = set()
rebuilt = []
for item in self._items:
current = existing.get(item["id"], item)
if current["id"] not in seen:
rebuilt.append(current)
seen.add(current["id"])
self._items = rebuilt
return self.read()
def read(self) -> List[Dict[str, str]]:
"""Return a copy of the current list."""
return [item.copy() for item in self._items]
def has_items(self) -> bool:
"""Check if there are any items in the list."""
return len(self._items) > 0
def format_for_injection(self) -> Optional[str]:
"""
Render the todo list for post-compression injection.
Returns a human-readable string to append to the compressed
message history, or None if the list is empty.
"""
if not self._items:
return None
# Status markers for compact display
markers = {
"completed": "[x]",
"in_progress": "[>]",
"pending": "[ ]",
"cancelled": "[~]",
}
# Only inject pending/in_progress items — completed/cancelled ones
# cause the model to re-do finished work after compression.
active_items = [
item for item in self._items
if item["status"] in ("pending", "in_progress")
]
if not active_items:
return None
lines = ["[Your active task list was preserved across context compression]"]
for item in active_items:
marker = markers.get(item["status"], "[?]")
lines.append(f"- {marker} {item['id']}. {item['content']} ({item['status']})")
return "\n".join(lines)
@staticmethod
def _validate(item: Dict[str, Any]) -> Dict[str, str]:
"""
Validate and normalize a todo item.
Ensures required fields exist and status is valid.
Returns a clean dict with only {id, content, status}.
"""
item_id = str(item.get("id", "")).strip()
if not item_id:
item_id = "?"
content = str(item.get("content", "")).strip()
if not content:
content = "(no description)"
status = str(item.get("status", "pending")).strip().lower()
if status not in VALID_STATUSES:
status = "pending"
return {"id": item_id, "content": content, "status": status}
def todo_tool(
todos: Optional[List[Dict[str, Any]]] = None,
merge: bool = False,
store: Optional[TodoStore] = None,
) -> str:
"""
Single entry point for the todo tool. Reads or writes depending on params.
Args:
todos: if provided, write these items. If None, read current list.
merge: if True, update by id. If False (default), replace entire list.
store: the TodoStore instance from the AIAgent.
Returns:
JSON string with the full current list and summary metadata.
"""
if store is None:
return json.dumps({"error": "TodoStore not initialized"}, ensure_ascii=False)
if todos is not None:
items = store.write(todos, merge)
else:
items = store.read()
# Build summary counts
pending = sum(1 for i in items if i["status"] == "pending")
in_progress = sum(1 for i in items if i["status"] == "in_progress")
completed = sum(1 for i in items if i["status"] == "completed")
cancelled = sum(1 for i in items if i["status"] == "cancelled")
return json.dumps({
"todos": items,
"summary": {
"total": len(items),
"pending": pending,
"in_progress": in_progress,
"completed": completed,
"cancelled": cancelled,
},
}, ensure_ascii=False)
def check_todo_requirements() -> bool:
"""Todo tool has no external requirements -- always available."""
return True
# =============================================================================
# OpenAI Function-Calling Schema
# =============================================================================
# Behavioral guidance is baked into the description so it's part of the
# static tool schema (cached, never changes mid-conversation).
TODO_SCHEMA = {
"name": "todo",
"description": (
"Manage your task list for the current session. Use for complex tasks "
"with 3+ steps or when the user provides multiple tasks. "
"Call with no parameters to read the current list.\n\n"
"Writing:\n"
"- Provide 'todos' array to create/update items\n"
"- merge=false (default): replace the entire list with a fresh plan\n"
"- merge=true: update existing items by id, add any new ones\n\n"
"Each item: {id: string, content: string, "
"status: pending|in_progress|completed|cancelled}\n"
"List order is priority. Only ONE item in_progress at a time.\n"
"Mark items completed immediately when done. If something fails, "
"cancel it and add a revised item.\n\n"
"Always returns the full current list."
),
"parameters": {
"type": "object",
"properties": {
"todos": {
"type": "array",
"description": "Task items to write. Omit to read current list.",
"items": {
"type": "object",
"properties": {
"id": {
"type": "string",
"description": "Unique item identifier"
},
"content": {
"type": "string",
"description": "Task description"
},
"status": {
"type": "string",
"enum": ["pending", "in_progress", "completed", "cancelled"],
"description": "Current status"
}
},
"required": ["id", "content", "status"]
}
},
"merge": {
"type": "boolean",
"description": (
"true: update existing items by id, add new ones. "
"false (default): replace the entire list."
),
"default": False
}
},
"required": []
}
}
# --- Registry ---
from tools.registry import registry
registry.register(
name="todo",
toolset="todo",
schema=TODO_SCHEMA,
handler=lambda args, **kw: todo_tool(
todos=args.get("todos"), merge=args.get("merge", False), store=kw.get("store")),
check_fn=check_todo_requirements,
emoji="📋",
)

View file

@ -0,0 +1,554 @@
#!/usr/bin/env python3
"""
Transcription Tools Module
Provides speech-to-text transcription with three providers:
- **local** (default, free) faster-whisper running locally, no API key needed.
Auto-downloads the model (~150 MB for ``base``) on first use.
- **groq** (free tier) Groq Whisper API, requires ``GROQ_API_KEY``.
- **openai** (paid) OpenAI Whisper API, requires ``VOICE_TOOLS_OPENAI_KEY``.
Used by the messaging gateway to automatically transcribe voice messages
sent by users on Telegram, Discord, WhatsApp, Slack, and Signal.
Supported input formats: mp3, mp4, mpeg, mpga, m4a, wav, webm, ogg
Usage::
from tools.transcription_tools import transcribe_audio
result = transcribe_audio("/path/to/audio.ogg")
if result["success"]:
print(result["transcript"])
"""
import logging
import os
import shlex
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Optional, Dict, Any
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Optional imports — graceful degradation
# ---------------------------------------------------------------------------
import importlib.util as _ilu
_HAS_FASTER_WHISPER = _ilu.find_spec("faster_whisper") is not None
_HAS_OPENAI = _ilu.find_spec("openai") is not None
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
DEFAULT_PROVIDER = "local"
DEFAULT_LOCAL_MODEL = "base"
DEFAULT_LOCAL_STT_LANGUAGE = "en"
DEFAULT_STT_MODEL = os.getenv("STT_OPENAI_MODEL", "whisper-1")
DEFAULT_GROQ_STT_MODEL = os.getenv("STT_GROQ_MODEL", "whisper-large-v3-turbo")
LOCAL_STT_COMMAND_ENV = "HERMES_LOCAL_STT_COMMAND"
LOCAL_STT_LANGUAGE_ENV = "HERMES_LOCAL_STT_LANGUAGE"
COMMON_LOCAL_BIN_DIRS = ("/opt/homebrew/bin", "/usr/local/bin")
GROQ_BASE_URL = os.getenv("GROQ_BASE_URL", "https://api.groq.com/openai/v1")
OPENAI_BASE_URL = os.getenv("STT_OPENAI_BASE_URL", "https://api.openai.com/v1")
SUPPORTED_FORMATS = {".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm", ".ogg"}
LOCAL_NATIVE_AUDIO_FORMATS = {".wav", ".aiff", ".aif"}
MAX_FILE_SIZE = 25 * 1024 * 1024 # 25 MB
# Known model sets for auto-correction
OPENAI_MODELS = {"whisper-1", "gpt-4o-mini-transcribe", "gpt-4o-transcribe"}
GROQ_MODELS = {"whisper-large-v3", "whisper-large-v3-turbo", "distil-whisper-large-v3-en"}
# Singleton for the local model — loaded once, reused across calls
_local_model: Optional[object] = None
_local_model_name: Optional[str] = None
# ---------------------------------------------------------------------------
# Config helpers
# ---------------------------------------------------------------------------
def get_stt_model_from_config() -> Optional[str]:
"""Read the STT model name from ~/.hermes/config.yaml.
Returns the value of ``stt.model`` if present, otherwise ``None``.
Silently returns ``None`` on any error (missing file, bad YAML, etc.).
"""
try:
import yaml
cfg_path = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "config.yaml"
if cfg_path.exists():
with open(cfg_path) as f:
data = yaml.safe_load(f) or {}
return data.get("stt", {}).get("model")
except Exception:
pass
return None
def _load_stt_config() -> dict:
"""Load the ``stt`` section from user config, falling back to defaults."""
try:
from hermes_cli.config import load_config
return load_config().get("stt", {})
except Exception:
return {}
def is_stt_enabled(stt_config: Optional[dict] = None) -> bool:
"""Return whether STT is enabled in config."""
if stt_config is None:
stt_config = _load_stt_config()
enabled = stt_config.get("enabled", True)
if isinstance(enabled, str):
return enabled.strip().lower() in ("true", "1", "yes", "on")
if enabled is None:
return True
return bool(enabled)
def _resolve_openai_api_key() -> str:
"""Prefer the voice-tools key, but fall back to the normal OpenAI key."""
return os.getenv("VOICE_TOOLS_OPENAI_KEY", "") or os.getenv("OPENAI_API_KEY", "")
def _find_binary(binary_name: str) -> Optional[str]:
"""Find a local binary, checking common Homebrew/local prefixes as well as PATH."""
for directory in COMMON_LOCAL_BIN_DIRS:
candidate = Path(directory) / binary_name
if candidate.exists() and os.access(candidate, os.X_OK):
return str(candidate)
return shutil.which(binary_name)
def _find_ffmpeg_binary() -> Optional[str]:
return _find_binary("ffmpeg")
def _find_whisper_binary() -> Optional[str]:
return _find_binary("whisper")
def _get_local_command_template() -> Optional[str]:
configured = os.getenv(LOCAL_STT_COMMAND_ENV, "").strip()
if configured:
return configured
whisper_binary = _find_whisper_binary()
if whisper_binary:
quoted_binary = shlex.quote(whisper_binary)
return (
f"{quoted_binary} {{input_path}} --model {{model}} --output_format txt "
"--output_dir {output_dir} --language {language}"
)
return None
def _has_local_command() -> bool:
return _get_local_command_template() is not None
def _normalize_local_command_model(model_name: Optional[str]) -> str:
if not model_name or model_name in OPENAI_MODELS or model_name in GROQ_MODELS:
return DEFAULT_LOCAL_MODEL
return model_name
def _get_provider(stt_config: dict) -> str:
"""Determine which STT provider to use.
When ``stt.provider`` is explicitly set in config, that choice is
honoured no silent cloud fallback. When no provider is configured,
auto-detect tries: local > groq (free) > openai (paid).
"""
if not is_stt_enabled(stt_config):
return "none"
explicit = "provider" in stt_config
provider = stt_config.get("provider", DEFAULT_PROVIDER)
# --- Explicit provider: respect the user's choice ----------------------
if explicit:
if provider == "local":
if _HAS_FASTER_WHISPER:
return "local"
if _has_local_command():
return "local_command"
logger.warning(
"STT provider 'local' configured but unavailable "
"(install faster-whisper or set HERMES_LOCAL_STT_COMMAND)"
)
return "none"
if provider == "local_command":
if _has_local_command():
return "local_command"
if _HAS_FASTER_WHISPER:
logger.info("Local STT command unavailable, using local faster-whisper")
return "local"
logger.warning(
"STT provider 'local_command' configured but unavailable"
)
return "none"
if provider == "groq":
if _HAS_OPENAI and os.getenv("GROQ_API_KEY"):
return "groq"
logger.warning(
"STT provider 'groq' configured but GROQ_API_KEY not set"
)
return "none"
if provider == "openai":
if _HAS_OPENAI and _resolve_openai_api_key():
return "openai"
logger.warning(
"STT provider 'openai' configured but no API key available"
)
return "none"
return provider # Unknown — let it fail downstream
# --- Auto-detect (no explicit provider): local > groq > openai ---------
if _HAS_FASTER_WHISPER:
return "local"
if _has_local_command():
return "local_command"
if _HAS_OPENAI and os.getenv("GROQ_API_KEY"):
logger.info("No local STT available, using Groq Whisper API")
return "groq"
if _HAS_OPENAI and _resolve_openai_api_key():
logger.info("No local STT available, using OpenAI Whisper API")
return "openai"
return "none"
# ---------------------------------------------------------------------------
# Shared validation
# ---------------------------------------------------------------------------
def _validate_audio_file(file_path: str) -> Optional[Dict[str, Any]]:
"""Validate the audio file. Returns an error dict or None if OK."""
audio_path = Path(file_path)
if not audio_path.exists():
return {"success": False, "transcript": "", "error": f"Audio file not found: {file_path}"}
if not audio_path.is_file():
return {"success": False, "transcript": "", "error": f"Path is not a file: {file_path}"}
if audio_path.suffix.lower() not in SUPPORTED_FORMATS:
return {
"success": False,
"transcript": "",
"error": f"Unsupported format: {audio_path.suffix}. Supported: {', '.join(sorted(SUPPORTED_FORMATS))}",
}
try:
file_size = audio_path.stat().st_size
if file_size > MAX_FILE_SIZE:
return {
"success": False,
"transcript": "",
"error": f"File too large: {file_size / (1024*1024):.1f}MB (max {MAX_FILE_SIZE / (1024*1024):.0f}MB)",
}
except OSError as e:
return {"success": False, "transcript": "", "error": f"Failed to access file: {e}"}
return None
# ---------------------------------------------------------------------------
# Provider: local (faster-whisper)
# ---------------------------------------------------------------------------
def _transcribe_local(file_path: str, model_name: str) -> Dict[str, Any]:
"""Transcribe using faster-whisper (local, free)."""
global _local_model, _local_model_name
if not _HAS_FASTER_WHISPER:
return {"success": False, "transcript": "", "error": "faster-whisper not installed"}
try:
from faster_whisper import WhisperModel
# Lazy-load the model (downloads on first use, ~150 MB for 'base')
if _local_model is None or _local_model_name != model_name:
logger.info("Loading faster-whisper model '%s' (first load downloads the model)...", model_name)
_local_model = WhisperModel(model_name, device="auto", compute_type="auto")
_local_model_name = model_name
segments, info = _local_model.transcribe(file_path, beam_size=5)
transcript = " ".join(segment.text.strip() for segment in segments)
logger.info(
"Transcribed %s via local whisper (%s, lang=%s, %.1fs audio)",
Path(file_path).name, model_name, info.language, info.duration,
)
return {"success": True, "transcript": transcript, "provider": "local"}
except Exception as e:
logger.error("Local transcription failed: %s", e, exc_info=True)
return {"success": False, "transcript": "", "error": f"Local transcription failed: {e}"}
def _prepare_local_audio(file_path: str, work_dir: str) -> tuple[Optional[str], Optional[str]]:
"""Normalize audio for local CLI STT when needed."""
audio_path = Path(file_path)
if audio_path.suffix.lower() in LOCAL_NATIVE_AUDIO_FORMATS:
return file_path, None
ffmpeg = _find_ffmpeg_binary()
if not ffmpeg:
return None, "Local STT fallback requires ffmpeg for non-WAV inputs, but ffmpeg was not found"
converted_path = os.path.join(work_dir, f"{audio_path.stem}.wav")
command = [ffmpeg, "-y", "-i", file_path, converted_path]
try:
subprocess.run(command, check=True, capture_output=True, text=True)
return converted_path, None
except subprocess.CalledProcessError as e:
details = e.stderr.strip() or e.stdout.strip() or str(e)
logger.error("ffmpeg conversion failed for %s: %s", file_path, details)
return None, f"Failed to convert audio for local STT: {details}"
def _transcribe_local_command(file_path: str, model_name: str) -> Dict[str, Any]:
"""Run the configured local STT command template and read back a .txt transcript."""
command_template = _get_local_command_template()
if not command_template:
return {
"success": False,
"transcript": "",
"error": (
f"{LOCAL_STT_COMMAND_ENV} not configured and no local whisper binary was found"
),
}
language = os.getenv(LOCAL_STT_LANGUAGE_ENV, DEFAULT_LOCAL_STT_LANGUAGE)
normalized_model = _normalize_local_command_model(model_name)
try:
with tempfile.TemporaryDirectory(prefix="hermes-local-stt-") as output_dir:
prepared_input, prep_error = _prepare_local_audio(file_path, output_dir)
if prep_error:
return {"success": False, "transcript": "", "error": prep_error}
command = command_template.format(
input_path=shlex.quote(prepared_input),
output_dir=shlex.quote(output_dir),
language=shlex.quote(language),
model=shlex.quote(normalized_model),
)
subprocess.run(command, shell=True, check=True, capture_output=True, text=True)
txt_files = sorted(Path(output_dir).glob("*.txt"))
if not txt_files:
return {
"success": False,
"transcript": "",
"error": "Local STT command completed but did not produce a .txt transcript",
}
transcript_text = txt_files[0].read_text(encoding="utf-8").strip()
logger.info(
"Transcribed %s via local STT command (%s, %d chars)",
Path(file_path).name,
normalized_model,
len(transcript_text),
)
return {"success": True, "transcript": transcript_text, "provider": "local_command"}
except KeyError as e:
return {
"success": False,
"transcript": "",
"error": f"Invalid {LOCAL_STT_COMMAND_ENV} template, missing placeholder: {e}",
}
except subprocess.CalledProcessError as e:
details = e.stderr.strip() or e.stdout.strip() or str(e)
logger.error("Local STT command failed for %s: %s", file_path, details)
return {"success": False, "transcript": "", "error": f"Local STT failed: {details}"}
except Exception as e:
logger.error("Unexpected error during local command transcription: %s", e, exc_info=True)
return {"success": False, "transcript": "", "error": f"Local transcription failed: {e}"}
# ---------------------------------------------------------------------------
# Provider: groq (Whisper API — free tier)
# ---------------------------------------------------------------------------
def _transcribe_groq(file_path: str, model_name: str) -> Dict[str, Any]:
"""Transcribe using Groq Whisper API (free tier available)."""
api_key = os.getenv("GROQ_API_KEY")
if not api_key:
return {"success": False, "transcript": "", "error": "GROQ_API_KEY not set"}
if not _HAS_OPENAI:
return {"success": False, "transcript": "", "error": "openai package not installed"}
# Auto-correct model if caller passed an OpenAI-only model
if model_name in OPENAI_MODELS:
logger.info("Model %s not available on Groq, using %s", model_name, DEFAULT_GROQ_STT_MODEL)
model_name = DEFAULT_GROQ_STT_MODEL
try:
from openai import OpenAI, APIError, APIConnectionError, APITimeoutError
client = OpenAI(api_key=api_key, base_url=GROQ_BASE_URL, timeout=30, max_retries=0)
with open(file_path, "rb") as audio_file:
transcription = client.audio.transcriptions.create(
model=model_name,
file=audio_file,
response_format="text",
)
transcript_text = str(transcription).strip()
logger.info("Transcribed %s via Groq API (%s, %d chars)",
Path(file_path).name, model_name, len(transcript_text))
return {"success": True, "transcript": transcript_text, "provider": "groq"}
except PermissionError:
return {"success": False, "transcript": "", "error": f"Permission denied: {file_path}"}
except APIConnectionError as e:
return {"success": False, "transcript": "", "error": f"Connection error: {e}"}
except APITimeoutError as e:
return {"success": False, "transcript": "", "error": f"Request timeout: {e}"}
except APIError as e:
return {"success": False, "transcript": "", "error": f"API error: {e}"}
except Exception as e:
logger.error("Groq transcription failed: %s", e, exc_info=True)
return {"success": False, "transcript": "", "error": f"Transcription failed: {e}"}
# ---------------------------------------------------------------------------
# Provider: openai (Whisper API)
# ---------------------------------------------------------------------------
def _transcribe_openai(file_path: str, model_name: str) -> Dict[str, Any]:
"""Transcribe using OpenAI Whisper API (paid)."""
api_key = _resolve_openai_api_key()
if not api_key:
return {
"success": False,
"transcript": "",
"error": "Neither VOICE_TOOLS_OPENAI_KEY nor OPENAI_API_KEY is set",
}
if not _HAS_OPENAI:
return {"success": False, "transcript": "", "error": "openai package not installed"}
# Auto-correct model if caller passed a Groq-only model
if model_name in GROQ_MODELS:
logger.info("Model %s not available on OpenAI, using %s", model_name, DEFAULT_STT_MODEL)
model_name = DEFAULT_STT_MODEL
try:
from openai import OpenAI, APIError, APIConnectionError, APITimeoutError
client = OpenAI(api_key=api_key, base_url=OPENAI_BASE_URL, timeout=30, max_retries=0)
with open(file_path, "rb") as audio_file:
transcription = client.audio.transcriptions.create(
model=model_name,
file=audio_file,
response_format="text",
)
transcript_text = str(transcription).strip()
logger.info("Transcribed %s via OpenAI API (%s, %d chars)",
Path(file_path).name, model_name, len(transcript_text))
return {"success": True, "transcript": transcript_text, "provider": "openai"}
except PermissionError:
return {"success": False, "transcript": "", "error": f"Permission denied: {file_path}"}
except APIConnectionError as e:
return {"success": False, "transcript": "", "error": f"Connection error: {e}"}
except APITimeoutError as e:
return {"success": False, "transcript": "", "error": f"Request timeout: {e}"}
except APIError as e:
return {"success": False, "transcript": "", "error": f"API error: {e}"}
except Exception as e:
logger.error("OpenAI transcription failed: %s", e, exc_info=True)
return {"success": False, "transcript": "", "error": f"Transcription failed: {e}"}
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, Any]:
"""
Transcribe an audio file using the configured STT provider.
Provider priority:
1. User config (``stt.provider`` in config.yaml)
2. Auto-detect: local faster-whisper (free) > Groq (free tier) > OpenAI (paid)
Args:
file_path: Absolute path to the audio file to transcribe.
model: Override the model. If None, uses config or provider default.
Returns:
dict with keys:
- "success" (bool): Whether transcription succeeded
- "transcript" (str): The transcribed text (empty on failure)
- "error" (str, optional): Error message if success is False
- "provider" (str, optional): Which provider was used
"""
# Validate input
error = _validate_audio_file(file_path)
if error:
return error
# Load config and determine provider
stt_config = _load_stt_config()
if not is_stt_enabled(stt_config):
return {
"success": False,
"transcript": "",
"error": "STT is disabled in config.yaml (stt.enabled: false).",
}
provider = _get_provider(stt_config)
if provider == "local":
local_cfg = stt_config.get("local", {})
model_name = model or local_cfg.get("model", DEFAULT_LOCAL_MODEL)
return _transcribe_local(file_path, model_name)
if provider == "local_command":
local_cfg = stt_config.get("local", {})
model_name = _normalize_local_command_model(
model or local_cfg.get("model", DEFAULT_LOCAL_MODEL)
)
return _transcribe_local_command(file_path, model_name)
if provider == "groq":
model_name = model or DEFAULT_GROQ_STT_MODEL
return _transcribe_groq(file_path, model_name)
if provider == "openai":
openai_cfg = stt_config.get("openai", {})
model_name = model or openai_cfg.get("model", DEFAULT_STT_MODEL)
return _transcribe_openai(file_path, model_name)
# No provider available
return {
"success": False,
"transcript": "",
"error": (
"No STT provider available. Install faster-whisper for free local "
f"transcription, configure {LOCAL_STT_COMMAND_ENV} or install a local whisper CLI, "
"set GROQ_API_KEY for free Groq Whisper, or set VOICE_TOOLS_OPENAI_KEY "
"or OPENAI_API_KEY for the OpenAI Whisper API."
),
}

View file

@ -0,0 +1,847 @@
#!/usr/bin/env python3
"""
Text-to-Speech Tool Module
Supports four TTS providers:
- Edge TTS (default, free, no API key): Microsoft Edge neural voices
- ElevenLabs (premium): High-quality voices, needs ELEVENLABS_API_KEY
- OpenAI TTS: Good quality, needs OPENAI_API_KEY
- NeuTTS (local, free, no API key): On-device TTS via neutts_cli, needs neutts installed
Output formats:
- Opus (.ogg) for Telegram voice bubbles (requires ffmpeg for Edge TTS)
- MP3 (.mp3) for everything else (CLI, Discord, WhatsApp)
Configuration is loaded from ~/.hermes/config.yaml under the 'tts:' key.
The user chooses the provider and voice; the model just sends text.
Usage:
from tools.tts_tool import text_to_speech_tool, check_tts_requirements
result = text_to_speech_tool(text="Hello world")
"""
import asyncio
import datetime
import json
import logging
import os
import queue
import re
import shutil
import subprocess
import tempfile
import threading
from pathlib import Path
from typing import Callable, Dict, Any, Optional
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Lazy imports -- providers are imported only when actually used to avoid
# crashing in headless environments (SSH, Docker, WSL, no PortAudio).
# ---------------------------------------------------------------------------
def _import_edge_tts():
"""Lazy import edge_tts. Returns the module or raises ImportError."""
import edge_tts
return edge_tts
def _import_elevenlabs():
"""Lazy import ElevenLabs client. Returns the class or raises ImportError."""
from elevenlabs.client import ElevenLabs
return ElevenLabs
def _import_openai_client():
"""Lazy import OpenAI client. Returns the class or raises ImportError."""
from openai import OpenAI as OpenAIClient
return OpenAIClient
def _import_sounddevice():
"""Lazy import sounddevice. Returns the module or raises ImportError/OSError."""
import sounddevice as sd
return sd
# ===========================================================================
# Defaults
# ===========================================================================
DEFAULT_PROVIDER = "edge"
DEFAULT_EDGE_VOICE = "en-US-AriaNeural"
DEFAULT_ELEVENLABS_VOICE_ID = "pNInz6obpgDQGcFmaJgB" # Adam
DEFAULT_ELEVENLABS_MODEL_ID = "eleven_multilingual_v2"
DEFAULT_ELEVENLABS_STREAMING_MODEL_ID = "eleven_flash_v2_5"
DEFAULT_OPENAI_MODEL = "gpt-4o-mini-tts"
DEFAULT_OPENAI_VOICE = "alloy"
DEFAULT_OUTPUT_DIR = str(Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "audio_cache")
MAX_TEXT_LENGTH = 4000
# ===========================================================================
# Config loader -- reads tts: section from ~/.hermes/config.yaml
# ===========================================================================
def _load_tts_config() -> Dict[str, Any]:
"""
Load TTS configuration from ~/.hermes/config.yaml.
Returns a dict with provider settings. Falls back to defaults
for any missing fields.
"""
try:
from hermes_cli.config import load_config
config = load_config()
return config.get("tts", {})
except ImportError:
logger.debug("hermes_cli.config not available, using default TTS config")
return {}
except Exception as e:
logger.warning("Failed to load TTS config: %s", e, exc_info=True)
return {}
def _get_provider(tts_config: Dict[str, Any]) -> str:
"""Get the configured TTS provider name."""
return tts_config.get("provider", DEFAULT_PROVIDER).lower().strip()
# ===========================================================================
# ffmpeg Opus conversion (Edge TTS MP3 -> OGG Opus for Telegram)
# ===========================================================================
def _has_ffmpeg() -> bool:
"""Check if ffmpeg is available on the system."""
return shutil.which("ffmpeg") is not None
def _convert_to_opus(mp3_path: str) -> Optional[str]:
"""
Convert an MP3 file to OGG Opus format for Telegram voice bubbles.
Args:
mp3_path: Path to the input MP3 file.
Returns:
Path to the .ogg file, or None if conversion fails.
"""
if not _has_ffmpeg():
return None
ogg_path = mp3_path.rsplit(".", 1)[0] + ".ogg"
try:
result = subprocess.run(
["ffmpeg", "-i", mp3_path, "-acodec", "libopus",
"-ac", "1", "-b:a", "64k", "-vbr", "off", ogg_path, "-y"],
capture_output=True, timeout=30,
)
if result.returncode != 0:
logger.warning("ffmpeg conversion failed with return code %d: %s",
result.returncode, result.stderr.decode('utf-8', errors='ignore')[:200])
return None
if os.path.exists(ogg_path) and os.path.getsize(ogg_path) > 0:
return ogg_path
except subprocess.TimeoutExpired:
logger.warning("ffmpeg OGG conversion timed out after 30s")
except FileNotFoundError:
logger.warning("ffmpeg not found in PATH")
except Exception as e:
logger.warning("ffmpeg OGG conversion failed: %s", e, exc_info=True)
return None
# ===========================================================================
# Provider: Edge TTS (free)
# ===========================================================================
async def _generate_edge_tts(text: str, output_path: str, tts_config: Dict[str, Any]) -> str:
"""
Generate audio using Edge TTS.
Args:
text: Text to convert.
output_path: Where to save the MP3 file.
tts_config: TTS config dict.
Returns:
Path to the saved audio file.
"""
_edge_tts = _import_edge_tts()
edge_config = tts_config.get("edge", {})
voice = edge_config.get("voice", DEFAULT_EDGE_VOICE)
communicate = _edge_tts.Communicate(text, voice)
await communicate.save(output_path)
return output_path
# ===========================================================================
# Provider: ElevenLabs (premium)
# ===========================================================================
def _generate_elevenlabs(text: str, output_path: str, tts_config: Dict[str, Any]) -> str:
"""
Generate audio using ElevenLabs.
Args:
text: Text to convert.
output_path: Where to save the audio file.
tts_config: TTS config dict.
Returns:
Path to the saved audio file.
"""
api_key = os.getenv("ELEVENLABS_API_KEY", "")
if not api_key:
raise ValueError("ELEVENLABS_API_KEY not set. Get one at https://elevenlabs.io/")
el_config = tts_config.get("elevenlabs", {})
voice_id = el_config.get("voice_id", DEFAULT_ELEVENLABS_VOICE_ID)
model_id = el_config.get("model_id", DEFAULT_ELEVENLABS_MODEL_ID)
# Determine output format based on file extension
if output_path.endswith(".ogg"):
output_format = "opus_48000_64"
else:
output_format = "mp3_44100_128"
ElevenLabs = _import_elevenlabs()
client = ElevenLabs(api_key=api_key)
audio_generator = client.text_to_speech.convert(
text=text,
voice_id=voice_id,
model_id=model_id,
output_format=output_format,
)
# audio_generator yields chunks -- write them all
with open(output_path, "wb") as f:
for chunk in audio_generator:
f.write(chunk)
return output_path
# ===========================================================================
# Provider: OpenAI TTS
# ===========================================================================
def _generate_openai_tts(text: str, output_path: str, tts_config: Dict[str, Any]) -> str:
"""
Generate audio using OpenAI TTS.
Args:
text: Text to convert.
output_path: Where to save the audio file.
tts_config: TTS config dict.
Returns:
Path to the saved audio file.
"""
api_key = os.getenv("VOICE_TOOLS_OPENAI_KEY", "")
if not api_key:
raise ValueError("VOICE_TOOLS_OPENAI_KEY not set. Get one at https://platform.openai.com/api-keys")
oai_config = tts_config.get("openai", {})
model = oai_config.get("model", DEFAULT_OPENAI_MODEL)
voice = oai_config.get("voice", DEFAULT_OPENAI_VOICE)
base_url = oai_config.get("base_url", "https://api.openai.com/v1")
# Determine response format from extension
if output_path.endswith(".ogg"):
response_format = "opus"
else:
response_format = "mp3"
OpenAIClient = _import_openai_client()
client = OpenAIClient(api_key=api_key, base_url=base_url)
response = client.audio.speech.create(
model=model,
voice=voice,
input=text,
response_format=response_format,
)
response.stream_to_file(output_path)
return output_path
# ===========================================================================
# NeuTTS (local, on-device TTS via neutts_cli)
# ===========================================================================
def _check_neutts_available() -> bool:
"""Check if the neutts engine is importable (installed locally)."""
try:
import importlib.util
return importlib.util.find_spec("neutts") is not None
except Exception:
return False
def _default_neutts_ref_audio() -> str:
"""Return path to the bundled default voice reference audio."""
return str(Path(__file__).parent / "neutts_samples" / "jo.wav")
def _default_neutts_ref_text() -> str:
"""Return path to the bundled default voice reference transcript."""
return str(Path(__file__).parent / "neutts_samples" / "jo.txt")
def _generate_neutts(text: str, output_path: str, tts_config: Dict[str, Any]) -> str:
"""Generate speech using the local NeuTTS engine.
Runs synthesis in a subprocess via tools/neutts_synth.py to keep the
~500MB model in a separate process that exits after synthesis.
Outputs WAV; the caller handles conversion for Telegram if needed.
"""
import sys
neutts_config = tts_config.get("neutts", {})
ref_audio = neutts_config.get("ref_audio", "") or _default_neutts_ref_audio()
ref_text = neutts_config.get("ref_text", "") or _default_neutts_ref_text()
model = neutts_config.get("model", "neuphonic/neutts-air-q4-gguf")
device = neutts_config.get("device", "cpu")
# NeuTTS outputs WAV natively — use a .wav path for generation,
# let the caller convert to the final format afterward.
wav_path = output_path
if not output_path.endswith(".wav"):
wav_path = output_path.rsplit(".", 1)[0] + ".wav"
synth_script = str(Path(__file__).parent / "neutts_synth.py")
cmd = [
sys.executable, synth_script,
"--text", text,
"--out", wav_path,
"--ref-audio", ref_audio,
"--ref-text", ref_text,
"--model", model,
"--device", device,
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
if result.returncode != 0:
stderr = result.stderr.strip()
# Filter out the "OK:" line from stderr
error_lines = [l for l in stderr.splitlines() if not l.startswith("OK:")]
raise RuntimeError(f"NeuTTS synthesis failed: {chr(10).join(error_lines) or 'unknown error'}")
# If the caller wanted .mp3 or .ogg, convert from WAV
if wav_path != output_path:
ffmpeg = shutil.which("ffmpeg")
if ffmpeg:
conv_cmd = [ffmpeg, "-i", wav_path, "-y", "-loglevel", "error", output_path]
subprocess.run(conv_cmd, check=True, timeout=30)
os.remove(wav_path)
else:
# No ffmpeg — just rename the WAV to the expected path
os.rename(wav_path, output_path)
return output_path
# ===========================================================================
# Main tool function
# ===========================================================================
def text_to_speech_tool(
text: str,
output_path: Optional[str] = None,
) -> str:
"""
Convert text to speech audio.
Reads provider/voice config from ~/.hermes/config.yaml (tts: section).
The model sends text; the user configures voice and provider.
On messaging platforms, the returned MEDIA:<path> tag is intercepted
by the send pipeline and delivered as a native voice message.
In CLI mode, the file is saved to ~/voice-memos/.
Args:
text: The text to convert to speech.
output_path: Optional custom save path. Defaults to ~/voice-memos/<timestamp>.mp3
Returns:
str: JSON result with success, file_path, and optionally MEDIA tag.
"""
if not text or not text.strip():
return json.dumps({"success": False, "error": "Text is required"}, ensure_ascii=False)
# Truncate very long text with a warning
if len(text) > MAX_TEXT_LENGTH:
logger.warning("TTS text too long (%d chars), truncating to %d", len(text), MAX_TEXT_LENGTH)
text = text[:MAX_TEXT_LENGTH]
tts_config = _load_tts_config()
provider = _get_provider(tts_config)
# Detect platform from gateway env var to choose the best output format.
# Telegram voice bubbles require Opus (.ogg); OpenAI and ElevenLabs can
# produce Opus natively (no ffmpeg needed). Edge TTS always outputs MP3
# and needs ffmpeg for conversion.
platform = os.getenv("HERMES_SESSION_PLATFORM", "").lower()
want_opus = (platform == "telegram")
# Determine output path
if output_path:
file_path = Path(output_path).expanduser()
else:
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
out_dir = Path(DEFAULT_OUTPUT_DIR)
out_dir.mkdir(parents=True, exist_ok=True)
# Use .ogg for Telegram with providers that support native Opus output,
# otherwise fall back to .mp3 (Edge TTS will attempt ffmpeg conversion later).
if want_opus and provider in ("openai", "elevenlabs"):
file_path = out_dir / f"tts_{timestamp}.ogg"
else:
file_path = out_dir / f"tts_{timestamp}.mp3"
# Ensure parent directory exists
file_path.parent.mkdir(parents=True, exist_ok=True)
file_str = str(file_path)
try:
# Generate audio with the configured provider
if provider == "elevenlabs":
try:
_import_elevenlabs()
except ImportError:
return json.dumps({
"success": False,
"error": "ElevenLabs provider selected but 'elevenlabs' package not installed. Run: pip install elevenlabs"
}, ensure_ascii=False)
logger.info("Generating speech with ElevenLabs...")
_generate_elevenlabs(text, file_str, tts_config)
elif provider == "openai":
try:
_import_openai_client()
except ImportError:
return json.dumps({
"success": False,
"error": "OpenAI provider selected but 'openai' package not installed."
}, ensure_ascii=False)
logger.info("Generating speech with OpenAI TTS...")
_generate_openai_tts(text, file_str, tts_config)
elif provider == "neutts":
if not _check_neutts_available():
return json.dumps({
"success": False,
"error": "NeuTTS provider selected but neutts is not installed. "
"Run hermes setup and choose NeuTTS, or install espeak-ng and run python -m pip install -U neutts[all]."
}, ensure_ascii=False)
logger.info("Generating speech with NeuTTS (local)...")
_generate_neutts(text, file_str, tts_config)
else:
# Default: Edge TTS (free), with NeuTTS as local fallback
edge_available = True
try:
_import_edge_tts()
except ImportError:
edge_available = False
if edge_available:
logger.info("Generating speech with Edge TTS...")
try:
loop = asyncio.get_running_loop()
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
pool.submit(
lambda: asyncio.run(_generate_edge_tts(text, file_str, tts_config))
).result(timeout=60)
except RuntimeError:
asyncio.run(_generate_edge_tts(text, file_str, tts_config))
elif _check_neutts_available():
logger.info("Edge TTS not available, falling back to NeuTTS (local)...")
provider = "neutts"
_generate_neutts(text, file_str, tts_config)
else:
return json.dumps({
"success": False,
"error": "No TTS provider available. Install edge-tts (pip install edge-tts) "
"or set up NeuTTS for local synthesis."
}, ensure_ascii=False)
# Check the file was actually created
if not os.path.exists(file_str) or os.path.getsize(file_str) == 0:
return json.dumps({
"success": False,
"error": f"TTS generation produced no output (provider: {provider})"
}, ensure_ascii=False)
# Try Opus conversion for Telegram compatibility
# Edge TTS outputs MP3, NeuTTS outputs WAV — both need ffmpeg conversion
voice_compatible = False
if provider in ("edge", "neutts") and not file_str.endswith(".ogg"):
opus_path = _convert_to_opus(file_str)
if opus_path:
file_str = opus_path
voice_compatible = True
elif provider in ("elevenlabs", "openai"):
# These providers can output Opus natively if the path ends in .ogg
voice_compatible = file_str.endswith(".ogg")
file_size = os.path.getsize(file_str)
logger.info("TTS audio saved: %s (%s bytes, provider: %s)", file_str, f"{file_size:,}", provider)
# Build response with MEDIA tag for platform delivery
media_tag = f"MEDIA:{file_str}"
if voice_compatible:
media_tag = f"[[audio_as_voice]]\n{media_tag}"
return json.dumps({
"success": True,
"file_path": file_str,
"media_tag": media_tag,
"provider": provider,
"voice_compatible": voice_compatible,
}, ensure_ascii=False)
except ValueError as e:
# Configuration errors (missing API keys, etc.)
error_msg = f"TTS configuration error ({provider}): {e}"
logger.error("%s", error_msg)
return json.dumps({"success": False, "error": error_msg}, ensure_ascii=False)
except FileNotFoundError as e:
# Missing dependencies or files
error_msg = f"TTS dependency missing ({provider}): {e}"
logger.error("%s", error_msg, exc_info=True)
return json.dumps({"success": False, "error": error_msg}, ensure_ascii=False)
except Exception as e:
# Unexpected errors
error_msg = f"TTS generation failed ({provider}): {e}"
logger.error("%s", error_msg, exc_info=True)
return json.dumps({"success": False, "error": error_msg}, ensure_ascii=False)
# ===========================================================================
# Requirements check
# ===========================================================================
def check_tts_requirements() -> bool:
"""
Check if at least one TTS provider is available.
Edge TTS needs no API key and is the default, so if the package
is installed, TTS is available.
Returns:
bool: True if at least one provider can work.
"""
try:
_import_edge_tts()
return True
except ImportError:
pass
try:
_import_elevenlabs()
if os.getenv("ELEVENLABS_API_KEY"):
return True
except ImportError:
pass
try:
_import_openai_client()
if os.getenv("VOICE_TOOLS_OPENAI_KEY"):
return True
except ImportError:
pass
if _check_neutts_available():
return True
return False
# ===========================================================================
# Streaming TTS: sentence-by-sentence pipeline for ElevenLabs
# ===========================================================================
# Sentence boundary pattern: punctuation followed by space or newline
_SENTENCE_BOUNDARY_RE = re.compile(r'(?<=[.!?])(?:\s|\n)|(?:\n\n)')
# Markdown stripping patterns (same as cli.py _voice_speak_response)
_MD_CODE_BLOCK = re.compile(r'```[\s\S]*?```')
_MD_LINK = re.compile(r'\[([^\]]+)\]\([^)]+\)')
_MD_URL = re.compile(r'https?://\S+')
_MD_BOLD = re.compile(r'\*\*(.+?)\*\*')
_MD_ITALIC = re.compile(r'\*(.+?)\*')
_MD_INLINE_CODE = re.compile(r'`(.+?)`')
_MD_HEADER = re.compile(r'^#+\s*', flags=re.MULTILINE)
_MD_LIST_ITEM = re.compile(r'^\s*[-*]\s+', flags=re.MULTILINE)
_MD_HR = re.compile(r'---+')
_MD_EXCESS_NL = re.compile(r'\n{3,}')
def _strip_markdown_for_tts(text: str) -> str:
"""Remove markdown formatting that shouldn't be spoken aloud."""
text = _MD_CODE_BLOCK.sub(' ', text)
text = _MD_LINK.sub(r'\1', text)
text = _MD_URL.sub('', text)
text = _MD_BOLD.sub(r'\1', text)
text = _MD_ITALIC.sub(r'\1', text)
text = _MD_INLINE_CODE.sub(r'\1', text)
text = _MD_HEADER.sub('', text)
text = _MD_LIST_ITEM.sub('', text)
text = _MD_HR.sub('', text)
text = _MD_EXCESS_NL.sub('\n\n', text)
return text.strip()
def stream_tts_to_speaker(
text_queue: queue.Queue,
stop_event: threading.Event,
tts_done_event: threading.Event,
display_callback: Optional[Callable[[str], None]] = None,
):
"""Consume text deltas from *text_queue*, buffer them into sentences,
and stream each sentence through ElevenLabs TTS to the speaker in
real-time.
Protocol:
* The producer puts ``str`` deltas onto *text_queue*.
* A ``None`` sentinel signals end-of-text (flush remaining buffer).
* *stop_event* can be set to abort early (e.g. user interrupt).
* *tts_done_event* is **set** in the ``finally`` block so callers
waiting on it (continuous voice mode) know playback is finished.
"""
tts_done_event.clear()
try:
# --- TTS client setup (optional -- display_callback works without it) ---
client = None
output_stream = None
voice_id = DEFAULT_ELEVENLABS_VOICE_ID
model_id = DEFAULT_ELEVENLABS_STREAMING_MODEL_ID
tts_config = _load_tts_config()
el_config = tts_config.get("elevenlabs", {})
voice_id = el_config.get("voice_id", voice_id)
model_id = el_config.get("streaming_model_id",
el_config.get("model_id", model_id))
api_key = os.getenv("ELEVENLABS_API_KEY", "")
if not api_key:
logger.warning("ELEVENLABS_API_KEY not set; streaming TTS audio disabled")
else:
try:
ElevenLabs = _import_elevenlabs()
client = ElevenLabs(api_key=api_key)
except ImportError:
logger.warning("elevenlabs package not installed; streaming TTS disabled")
# Open a single sounddevice output stream for the lifetime of
# this function. ElevenLabs pcm_24000 produces signed 16-bit
# little-endian mono PCM at 24 kHz.
if client is not None:
try:
sd = _import_sounddevice()
import numpy as _np
output_stream = sd.OutputStream(
samplerate=24000, channels=1, dtype="int16",
)
output_stream.start()
except (ImportError, OSError) as exc:
logger.debug("sounddevice not available: %s", exc)
output_stream = None
except Exception as exc:
logger.warning("sounddevice OutputStream failed: %s", exc)
output_stream = None
sentence_buf = ""
min_sentence_len = 20
long_flush_len = 100
queue_timeout = 0.5
_spoken_sentences: list[str] = [] # track spoken sentences to skip duplicates
# Regex to strip complete <think>...</think> blocks from buffer
_think_block_re = re.compile(r'<think[\s>].*?</think>', flags=re.DOTALL)
def _speak_sentence(sentence: str):
"""Display sentence and optionally generate + play audio."""
if stop_event.is_set():
return
cleaned = _strip_markdown_for_tts(sentence).strip()
if not cleaned:
return
# Skip duplicate/near-duplicate sentences (LLM repetition)
cleaned_lower = cleaned.lower().rstrip(".!,")
for prev in _spoken_sentences:
if prev.lower().rstrip(".!,") == cleaned_lower:
return
_spoken_sentences.append(cleaned)
# Display raw sentence on screen before TTS processing
if display_callback is not None:
display_callback(sentence)
# Skip audio generation if no TTS client available
if client is None:
return
# Truncate very long sentences
if len(cleaned) > MAX_TEXT_LENGTH:
cleaned = cleaned[:MAX_TEXT_LENGTH]
try:
audio_iter = client.text_to_speech.convert(
text=cleaned,
voice_id=voice_id,
model_id=model_id,
output_format="pcm_24000",
)
if output_stream is not None:
for chunk in audio_iter:
if stop_event.is_set():
break
import numpy as _np
audio_array = _np.frombuffer(chunk, dtype=_np.int16)
output_stream.write(audio_array.reshape(-1, 1))
else:
# Fallback: write chunks to temp file and play via system player
_play_via_tempfile(audio_iter, stop_event)
except Exception as exc:
logger.warning("Streaming TTS sentence failed: %s", exc)
def _play_via_tempfile(audio_iter, stop_evt):
"""Write PCM chunks to a temp WAV file and play it."""
tmp_path = None
try:
import wave
tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
tmp_path = tmp.name
with wave.open(tmp, "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2) # 16-bit
wf.setframerate(24000)
for chunk in audio_iter:
if stop_evt.is_set():
break
wf.writeframes(chunk)
from tools.voice_mode import play_audio_file
play_audio_file(tmp_path)
except Exception as exc:
logger.warning("Temp-file TTS fallback failed: %s", exc)
finally:
if tmp_path:
try:
os.unlink(tmp_path)
except OSError:
pass
while not stop_event.is_set():
# Read next delta from queue
try:
delta = text_queue.get(timeout=queue_timeout)
except queue.Empty:
# Timeout: if we have accumulated a long buffer, flush it
if len(sentence_buf) > long_flush_len:
_speak_sentence(sentence_buf)
sentence_buf = ""
continue
if delta is None:
# End-of-text sentinel: strip any remaining think blocks, flush
sentence_buf = _think_block_re.sub('', sentence_buf)
if sentence_buf.strip():
_speak_sentence(sentence_buf)
break
sentence_buf += delta
# --- Think block filtering ---
# Strip complete <think>...</think> blocks from buffer.
# Works correctly even when tags span multiple deltas.
sentence_buf = _think_block_re.sub('', sentence_buf)
# If an incomplete <think tag is at the end, wait for more data
# before extracting sentences (the closing tag may arrive next).
if '<think' in sentence_buf and '</think>' not in sentence_buf:
continue
# Check for sentence boundaries
while True:
m = _SENTENCE_BOUNDARY_RE.search(sentence_buf)
if m is None:
break
end_pos = m.end()
sentence = sentence_buf[:end_pos]
sentence_buf = sentence_buf[end_pos:]
# Merge short fragments into the next sentence
if len(sentence.strip()) < min_sentence_len:
sentence_buf = sentence + sentence_buf
break
_speak_sentence(sentence)
# Drain any remaining items from the queue
while True:
try:
text_queue.get_nowait()
except queue.Empty:
break
# output_stream is closed in the finally block below
except Exception as exc:
logger.warning("Streaming TTS pipeline error: %s", exc)
finally:
# Always close the audio output stream to avoid locking the device
if output_stream is not None:
try:
output_stream.stop()
output_stream.close()
except Exception:
pass
tts_done_event.set()
# ===========================================================================
# Main -- quick diagnostics
# ===========================================================================
if __name__ == "__main__":
print("🔊 Text-to-Speech Tool Module")
print("=" * 50)
def _check(importer, label):
try:
importer()
return True
except ImportError:
return False
print(f"\nProvider availability:")
print(f" Edge TTS: {'installed' if _check(_import_edge_tts, 'edge') else 'not installed (pip install edge-tts)'}")
print(f" ElevenLabs: {'installed' if _check(_import_elevenlabs, 'el') else 'not installed (pip install elevenlabs)'}")
print(f" API Key: {'set' if os.getenv('ELEVENLABS_API_KEY') else 'not set'}")
print(f" OpenAI: {'installed' if _check(_import_openai_client, 'oai') else 'not installed'}")
print(f" API Key: {'set' if os.getenv('VOICE_TOOLS_OPENAI_KEY') else 'not set (VOICE_TOOLS_OPENAI_KEY)'}")
print(f" ffmpeg: {'✅ found' if _has_ffmpeg() else '❌ not found (needed for Telegram Opus)'}")
print(f"\n Output dir: {DEFAULT_OUTPUT_DIR}")
config = _load_tts_config()
provider = _get_provider(config)
print(f" Configured provider: {provider}")
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
from tools.registry import registry
TTS_SCHEMA = {
"name": "text_to_speech",
"description": "Convert text to speech audio. Returns a MEDIA: path that the platform delivers as a voice message. On Telegram it plays as a voice bubble, on Discord/WhatsApp as an audio attachment. In CLI mode, saves to ~/voice-memos/. Voice and provider are user-configured, not model-selected.",
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "The text to convert to speech. Keep under 4000 characters."
},
"output_path": {
"type": "string",
"description": "Optional custom file path to save the audio. Defaults to ~/.hermes/audio_cache/<timestamp>.mp3"
}
},
"required": ["text"]
}
}
registry.register(
name="text_to_speech",
toolset="tts",
schema=TTS_SCHEMA,
handler=lambda args, **kw: text_to_speech_tool(
text=args.get("text", ""),
output_path=args.get("output_path")),
check_fn=check_tts_requirements,
emoji="🔊",
)

View file

@ -0,0 +1,96 @@
"""URL safety checks — blocks requests to private/internal network addresses.
Prevents SSRF (Server-Side Request Forgery) where a malicious prompt or
skill could trick the agent into fetching internal resources like cloud
metadata endpoints (169.254.169.254), localhost services, or private
network hosts.
Limitations (documented, not fixable at pre-flight level):
- DNS rebinding (TOCTOU): an attacker-controlled DNS server with TTL=0
can return a public IP for the check, then a private IP for the actual
connection. Fixing this requires connection-level validation (e.g.
Python's Champion library or an egress proxy like Stripe's Smokescreen).
- Redirect-based bypass in vision_tools is mitigated by an httpx event
hook that re-validates each redirect target. Web tools use third-party
SDKs (Firecrawl/Tavily) where redirect handling is on their servers.
"""
import ipaddress
import logging
import socket
from urllib.parse import urlparse
logger = logging.getLogger(__name__)
# Hostnames that should always be blocked regardless of IP resolution
_BLOCKED_HOSTNAMES = frozenset({
"metadata.google.internal",
"metadata.goog",
})
# 100.64.0.0/10 (CGNAT / Shared Address Space, RFC 6598) is NOT covered by
# ipaddress.is_private — it returns False for both is_private and is_global.
# Must be blocked explicitly. Used by carrier-grade NAT, Tailscale/WireGuard
# VPNs, and some cloud internal networks.
_CGNAT_NETWORK = ipaddress.ip_network("100.64.0.0/10")
def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
"""Return True if the IP should be blocked for SSRF protection."""
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
return True
if ip.is_multicast or ip.is_unspecified:
return True
# CGNAT range not covered by is_private
if ip in _CGNAT_NETWORK:
return True
return False
def is_safe_url(url: str) -> bool:
"""Return True if the URL target is not a private/internal address.
Resolves the hostname to an IP and checks against private ranges.
Fails closed: DNS errors and unexpected exceptions block the request.
"""
try:
parsed = urlparse(url)
hostname = (parsed.hostname or "").strip().lower()
if not hostname:
return False
# Block known internal hostnames
if hostname in _BLOCKED_HOSTNAMES:
logger.warning("Blocked request to internal hostname: %s", hostname)
return False
# Try to resolve and check IP
try:
addr_info = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
except socket.gaierror:
# DNS resolution failed — fail closed. If DNS can't resolve it,
# the HTTP client will also fail, so blocking loses nothing.
logger.warning("Blocked request — DNS resolution failed for: %s", hostname)
return False
for family, _, _, _, sockaddr in addr_info:
ip_str = sockaddr[0]
try:
ip = ipaddress.ip_address(ip_str)
except ValueError:
continue
if _is_blocked_ip(ip):
logger.warning(
"Blocked request to private/internal address: %s -> %s",
hostname, ip_str,
)
return False
return True
except Exception as exc:
# Fail closed on unexpected errors — don't let parsing edge cases
# become SSRF bypass vectors
logger.warning("Blocked request — URL safety check error for %s: %s", url, exc)
return False

View file

@ -0,0 +1,541 @@
#!/usr/bin/env python3
"""
Vision Tools Module
This module provides vision analysis tools that work with image URLs.
Uses the centralized auxiliary vision router, which can select OpenRouter,
Nous, Codex, native Anthropic, or a custom OpenAI-compatible endpoint.
Available tools:
- vision_analyze_tool: Analyze images from URLs with custom prompts
Features:
- Downloads images from URLs and converts to base64 for API compatibility
- Comprehensive image description
- Context-aware analysis based on user queries
- Automatic temporary file cleanup
- Proper error handling and validation
- Debug logging support
Usage:
from vision_tools import vision_analyze_tool
import asyncio
# Analyze an image
result = await vision_analyze_tool(
image_url="https://example.com/image.jpg",
user_prompt="What architectural style is this building?"
)
"""
import asyncio
import base64
import json
import logging
import os
import uuid
from pathlib import Path
from typing import Any, Awaitable, Dict, Optional
from urllib.parse import urlparse
import httpx
from agent.auxiliary_client import async_call_llm
from tools.debug_helpers import DebugSession
logger = logging.getLogger(__name__)
_debug = DebugSession("vision_tools", env_var="VISION_TOOLS_DEBUG")
def _validate_image_url(url: str) -> bool:
"""
Basic validation of image URL format.
Args:
url (str): The URL to validate
Returns:
bool: True if URL appears to be valid, False otherwise
"""
if not url or not isinstance(url, str):
return False
# Basic HTTP/HTTPS URL check
if not (url.startswith("http://") or url.startswith("https://")):
return False
# Parse to ensure we at least have a network location; still allow URLs
# without file extensions (e.g. CDN endpoints that redirect to images).
parsed = urlparse(url)
if not parsed.netloc:
return False
# Block private/internal addresses to prevent SSRF
from tools.url_safety import is_safe_url
if not is_safe_url(url):
return False
return True
async def _download_image(image_url: str, destination: Path, max_retries: int = 3) -> Path:
"""
Download an image from a URL to a local destination (async) with retry logic.
Args:
image_url (str): The URL of the image to download
destination (Path): The path where the image should be saved
max_retries (int): Maximum number of retry attempts (default: 3)
Returns:
Path: The path to the downloaded image
Raises:
Exception: If download fails after all retries
"""
import asyncio
# Create parent directories if they don't exist
destination.parent.mkdir(parents=True, exist_ok=True)
async def _ssrf_redirect_guard(response):
"""Re-validate each redirect target to prevent redirect-based SSRF.
Without this, an attacker can host a public URL that 302-redirects
to http://169.254.169.254/ and bypass the pre-flight is_safe_url check.
Must be async because httpx.AsyncClient awaits event hooks.
"""
if response.is_redirect and response.next_request:
redirect_url = str(response.next_request.url)
from tools.url_safety import is_safe_url
if not is_safe_url(redirect_url):
raise ValueError(
f"Blocked redirect to private/internal address: {redirect_url}"
)
last_error = None
for attempt in range(max_retries):
try:
# Download the image with appropriate headers using async httpx
# Enable follow_redirects to handle image CDNs that redirect (e.g., Imgur, Picsum)
# SSRF: event_hooks validates each redirect target against private IP ranges
async with httpx.AsyncClient(
timeout=30.0,
follow_redirects=True,
event_hooks={"response": [_ssrf_redirect_guard]},
) as client:
response = await client.get(
image_url,
headers={
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
"Accept": "image/*,*/*;q=0.8",
},
)
response.raise_for_status()
# Save the image content
destination.write_bytes(response.content)
return destination
except Exception as e:
last_error = e
if attempt < max_retries - 1:
wait_time = 2 ** (attempt + 1) # 2s, 4s, 8s
logger.warning("Image download failed (attempt %s/%s): %s", attempt + 1, max_retries, str(e)[:50])
logger.warning("Retrying in %ss...", wait_time)
await asyncio.sleep(wait_time)
else:
logger.error(
"Image download failed after %s attempts: %s",
max_retries,
str(e)[:100],
exc_info=True,
)
raise last_error
def _determine_mime_type(image_path: Path) -> str:
"""
Determine the MIME type of an image based on its file extension.
Args:
image_path (Path): Path to the image file
Returns:
str: The MIME type (defaults to image/jpeg if unknown)
"""
extension = image_path.suffix.lower()
mime_types = {
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.png': 'image/png',
'.gif': 'image/gif',
'.bmp': 'image/bmp',
'.webp': 'image/webp',
'.svg': 'image/svg+xml'
}
return mime_types.get(extension, 'image/jpeg')
def _image_to_base64_data_url(image_path: Path, mime_type: Optional[str] = None) -> str:
"""
Convert an image file to a base64-encoded data URL.
Args:
image_path (Path): Path to the image file
mime_type (Optional[str]): MIME type of the image (auto-detected if None)
Returns:
str: Base64-encoded data URL (e.g., "data:image/jpeg;base64,...")
"""
# Read the image as bytes
data = image_path.read_bytes()
# Encode to base64
encoded = base64.b64encode(data).decode("ascii")
# Determine MIME type
mime = mime_type or _determine_mime_type(image_path)
# Create data URL
data_url = f"data:{mime};base64,{encoded}"
return data_url
async def vision_analyze_tool(
image_url: str,
user_prompt: str,
model: str = None,
) -> str:
"""
Analyze an image from a URL or local file path using vision AI.
This tool accepts either an HTTP/HTTPS URL or a local file path. For URLs,
it downloads the image first. In both cases, the image is converted to base64
and processed using Gemini 3 Flash Preview via OpenRouter API.
The user_prompt parameter is expected to be pre-formatted by the calling
function (typically model_tools.py) to include both full description
requests and specific questions.
Args:
image_url (str): The URL or local file path of the image to analyze.
Accepts http://, https:// URLs or absolute/relative file paths.
user_prompt (str): The pre-formatted prompt for the vision model
model (str): The vision model to use (default: google/gemini-3-flash-preview)
Returns:
str: JSON string containing the analysis results with the following structure:
{
"success": bool,
"analysis": str (defaults to error message if None)
}
Raises:
Exception: If download fails, analysis fails, or API key is not set
Note:
- For URLs, temporary images are stored in ./temp_vision_images/ and cleaned up
- For local file paths, the file is used directly and NOT deleted
- Supports common image formats (JPEG, PNG, GIF, WebP, etc.)
"""
debug_call_data = {
"parameters": {
"image_url": image_url,
"user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt,
"model": model
},
"error": None,
"success": False,
"analysis_length": 0,
"model_used": model,
"image_size_bytes": 0
}
temp_image_path = None
# Track whether we should clean up the file after processing.
# Local files (e.g. from the image cache) should NOT be deleted.
should_cleanup = True
try:
from tools.interrupt import is_interrupted
if is_interrupted():
return json.dumps({"success": False, "error": "Interrupted"})
logger.info("Analyzing image: %s", image_url[:60])
logger.info("User prompt: %s", user_prompt[:100])
# Determine if this is a local file path or a remote URL
local_path = Path(os.path.expanduser(image_url))
if local_path.is_file():
# Local file path (e.g. from platform image cache) -- skip download
logger.info("Using local image file: %s", image_url)
temp_image_path = local_path
should_cleanup = False # Don't delete cached/local files
elif _validate_image_url(image_url):
# Remote URL -- download to a temporary location
logger.info("Downloading image from URL...")
temp_dir = Path("./temp_vision_images")
temp_image_path = temp_dir / f"temp_image_{uuid.uuid4()}.jpg"
await _download_image(image_url, temp_image_path)
should_cleanup = True
else:
raise ValueError(
"Invalid image source. Provide an HTTP/HTTPS URL or a valid local file path."
)
# Get image file size for logging
image_size_bytes = temp_image_path.stat().st_size
image_size_kb = image_size_bytes / 1024
logger.info("Image ready (%.1f KB)", image_size_kb)
# Convert image to base64 data URL
logger.info("Converting image to base64...")
image_data_url = _image_to_base64_data_url(temp_image_path)
# Calculate size in KB for better readability
data_size_kb = len(image_data_url) / 1024
logger.info("Image converted to base64 (%.1f KB)", data_size_kb)
debug_call_data["image_size_bytes"] = image_size_bytes
# Use the prompt as provided (model_tools.py now handles full description formatting)
comprehensive_prompt = user_prompt
# Prepare the message with base64-encoded image
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": comprehensive_prompt
},
{
"type": "image_url",
"image_url": {
"url": image_data_url
}
}
]
}
]
logger.info("Processing image with vision model...")
# Call the vision API via centralized router.
# Read timeout from config.yaml (auxiliary.vision.timeout), default 30s.
vision_timeout = 30.0
try:
from hermes_cli.config import load_config
_cfg = load_config()
_vt = _cfg.get("auxiliary", {}).get("vision", {}).get("timeout")
if _vt is not None:
vision_timeout = float(_vt)
except Exception:
pass
call_kwargs = {
"task": "vision",
"messages": messages,
"temperature": 0.1,
"max_tokens": 2000,
"timeout": vision_timeout,
}
if model:
call_kwargs["model"] = model
response = await async_call_llm(**call_kwargs)
# Extract the analysis
analysis = response.choices[0].message.content.strip()
analysis_length = len(analysis)
logger.info("Image analysis completed (%s characters)", analysis_length)
# Prepare successful response
result = {
"success": True,
"analysis": analysis or "There was a problem with the request and the image could not be analyzed."
}
debug_call_data["success"] = True
debug_call_data["analysis_length"] = analysis_length
# Log debug information
_debug.log_call("vision_analyze_tool", debug_call_data)
_debug.save()
return json.dumps(result, indent=2, ensure_ascii=False)
except Exception as e:
error_msg = f"Error analyzing image: {str(e)}"
logger.error("%s", error_msg, exc_info=True)
# Detect vision capability errors — give the model a clear message
# so it can inform the user instead of a cryptic API error.
err_str = str(e).lower()
if any(hint in err_str for hint in (
"402", "insufficient", "payment required", "credits", "billing",
)):
analysis = (
"Insufficient credits or payment required. Please top up your "
f"API provider account and try again. Error: {e}"
)
elif any(hint in err_str for hint in (
"does not support", "not support image", "invalid_request",
"content_policy", "image_url", "multimodal",
"unrecognized request argument", "image input",
)):
analysis = (
f"{model} does not support vision or our request was not "
f"accepted by the server. Error: {e}"
)
else:
analysis = (
"There was a problem with the request and the image could not "
f"be analyzed. Error: {e}"
)
# Prepare error response
result = {
"success": False,
"error": error_msg,
"analysis": analysis,
}
debug_call_data["error"] = error_msg
_debug.log_call("vision_analyze_tool", debug_call_data)
_debug.save()
return json.dumps(result, indent=2, ensure_ascii=False)
finally:
# Clean up temporary image file (but NOT local/cached files)
if should_cleanup and temp_image_path and temp_image_path.exists():
try:
temp_image_path.unlink()
logger.debug("Cleaned up temporary image file")
except Exception as cleanup_error:
logger.warning(
"Could not delete temporary file: %s", cleanup_error, exc_info=True
)
def check_vision_requirements() -> bool:
"""Check if the configured runtime vision path can resolve a client."""
try:
from agent.auxiliary_client import resolve_vision_provider_client
_provider, client, _model = resolve_vision_provider_client()
return client is not None
except Exception:
return False
def get_debug_session_info() -> Dict[str, Any]:
"""
Get information about the current debug session.
Returns:
Dict[str, Any]: Dictionary containing debug session information
"""
return _debug.get_session_info()
if __name__ == "__main__":
"""
Simple test/demo when run directly
"""
print("👁️ Vision Tools Module")
print("=" * 40)
# Check if vision model is available
api_available = check_vision_requirements()
if not api_available:
print("❌ No auxiliary vision model available")
print("Configure a supported multimodal backend (OpenRouter, Nous, Codex, Anthropic, or a custom OpenAI-compatible endpoint).")
exit(1)
else:
print("✅ Vision model available")
print("🛠️ Vision tools ready for use!")
# Show debug mode status
if _debug.active:
print(f"🐛 Debug mode ENABLED - Session ID: {_debug.session_id}")
print(f" Debug logs will be saved to: ./logs/vision_tools_debug_{_debug.session_id}.json")
else:
print("🐛 Debug mode disabled (set VISION_TOOLS_DEBUG=true to enable)")
print("\nBasic usage:")
print(" from vision_tools import vision_analyze_tool")
print(" import asyncio")
print("")
print(" async def main():")
print(" result = await vision_analyze_tool(")
print(" image_url='https://example.com/image.jpg',")
print(" user_prompt='What do you see in this image?'")
print(" )")
print(" print(result)")
print(" asyncio.run(main())")
print("\nExample prompts:")
print(" - 'What architectural style is this building?'")
print(" - 'Describe the emotions and mood in this image'")
print(" - 'What text can you read in this image?'")
print(" - 'Identify any safety hazards visible'")
print(" - 'What products or brands are shown?'")
print("\nDebug mode:")
print(" # Enable debug logging")
print(" export VISION_TOOLS_DEBUG=true")
print(" # Debug logs capture all vision analysis calls and results")
print(" # Logs saved to: ./logs/vision_tools_debug_UUID.json")
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
from tools.registry import registry
VISION_ANALYZE_SCHEMA = {
"name": "vision_analyze",
"description": "Analyze images using AI vision. Provides a comprehensive description and answers a specific question about the image content.",
"parameters": {
"type": "object",
"properties": {
"image_url": {
"type": "string",
"description": "Image URL (http/https) or local file path to analyze."
},
"question": {
"type": "string",
"description": "Your specific question or request about the image to resolve. The AI will automatically provide a complete image description AND answer your specific question."
}
},
"required": ["image_url", "question"]
}
}
def _handle_vision_analyze(args: Dict[str, Any], **kw: Any) -> Awaitable[str]:
image_url = args.get("image_url", "")
question = args.get("question", "")
full_prompt = (
"Fully describe and explain everything about this image, then answer the "
f"following question:\n\n{question}"
)
model = os.getenv("AUXILIARY_VISION_MODEL", "").strip() or None
return vision_analyze_tool(image_url, full_prompt, model)
registry.register(
name="vision_analyze",
toolset="vision",
schema=VISION_ANALYZE_SCHEMA,
handler=_handle_vision_analyze,
check_fn=check_vision_requirements,
is_async=True,
emoji="👁️",
)

View file

@ -0,0 +1,793 @@
"""Voice Mode -- Push-to-talk audio recording and playback for the CLI.
Provides audio capture via sounddevice, WAV encoding via stdlib wave,
STT dispatch via tools.transcription_tools, and TTS playback via
sounddevice or system audio players.
Dependencies (optional):
pip install sounddevice numpy
or: pip install hermes-agent[voice]
"""
import logging
import os
import platform
import re
import shutil
import subprocess
import tempfile
import threading
import time
import wave
from pathlib import Path
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Lazy audio imports -- never imported at module level to avoid crashing
# in headless environments (SSH, Docker, WSL, no PortAudio).
# ---------------------------------------------------------------------------
def _import_audio():
"""Lazy-import sounddevice and numpy. Returns (sd, np).
Raises ImportError or OSError if the libraries are not available
(e.g. PortAudio missing on headless servers).
"""
import sounddevice as sd
import numpy as np
return sd, np
def _audio_available() -> bool:
"""Return True if audio libraries can be imported."""
try:
_import_audio()
return True
except (ImportError, OSError):
return False
def detect_audio_environment() -> dict:
"""Detect if the current environment supports audio I/O.
Returns dict with 'available' (bool) and 'warnings' (list of strings).
"""
warnings = []
# SSH detection
if any(os.environ.get(v) for v in ('SSH_CLIENT', 'SSH_TTY', 'SSH_CONNECTION')):
warnings.append("Running over SSH -- no audio devices available")
# Docker detection
if os.path.exists('/.dockerenv'):
warnings.append("Running inside Docker container -- no audio devices")
# WSL detection
try:
with open('/proc/version', 'r') as f:
if 'microsoft' in f.read().lower():
warnings.append("Running in WSL -- audio requires PulseAudio bridge to Windows")
except (FileNotFoundError, PermissionError, OSError):
pass
# Check audio libraries
try:
sd, _ = _import_audio()
try:
devices = sd.query_devices()
if not devices:
warnings.append("No audio input/output devices detected")
except Exception:
warnings.append("Audio subsystem error (PortAudio cannot query devices)")
except ImportError:
warnings.append("Audio libraries not installed (pip install sounddevice numpy)")
except OSError:
warnings.append(
"PortAudio system library not found -- install it first:\n"
" Linux: sudo apt-get install libportaudio2\n"
" macOS: brew install portaudio\n"
"Then retry /voice on."
)
return {
"available": len(warnings) == 0,
"warnings": warnings,
}
# ---------------------------------------------------------------------------
# Recording parameters
# ---------------------------------------------------------------------------
SAMPLE_RATE = 16000 # Whisper native rate
CHANNELS = 1 # Mono
DTYPE = "int16" # 16-bit PCM
SAMPLE_WIDTH = 2 # bytes per sample (int16)
MAX_RECORDING_SECONDS = 120 # Safety cap
# Silence detection defaults
SILENCE_RMS_THRESHOLD = 200 # RMS below this = silence (int16 range 0-32767)
SILENCE_DURATION_SECONDS = 3.0 # Seconds of continuous silence before auto-stop
# Temp directory for voice recordings
_TEMP_DIR = os.path.join(tempfile.gettempdir(), "hermes_voice")
# ============================================================================
# Audio cues (beep tones)
# ============================================================================
def play_beep(frequency: int = 880, duration: float = 0.12, count: int = 1) -> None:
"""Play a short beep tone using numpy + sounddevice.
Args:
frequency: Tone frequency in Hz (default 880 = A5).
duration: Duration of each beep in seconds.
count: Number of beeps to play (with short gap between).
"""
try:
sd, np = _import_audio()
except (ImportError, OSError):
return
try:
gap = 0.06 # seconds between beeps
samples_per_beep = int(SAMPLE_RATE * duration)
samples_per_gap = int(SAMPLE_RATE * gap)
parts = []
for i in range(count):
t = np.linspace(0, duration, samples_per_beep, endpoint=False)
# Apply fade in/out to avoid click artifacts
tone = np.sin(2 * np.pi * frequency * t)
fade_len = min(int(SAMPLE_RATE * 0.01), samples_per_beep // 4)
tone[:fade_len] *= np.linspace(0, 1, fade_len)
tone[-fade_len:] *= np.linspace(1, 0, fade_len)
parts.append((tone * 0.3 * 32767).astype(np.int16))
if i < count - 1:
parts.append(np.zeros(samples_per_gap, dtype=np.int16))
audio = np.concatenate(parts)
sd.play(audio, samplerate=SAMPLE_RATE)
# sd.wait() calls Event.wait() without timeout — hangs forever if the
# audio device stalls. Poll with a 2s ceiling and force-stop.
deadline = time.monotonic() + 2.0
while sd.get_stream() and sd.get_stream().active and time.monotonic() < deadline:
time.sleep(0.01)
sd.stop()
except Exception as e:
logger.debug("Beep playback failed: %s", e)
# ============================================================================
# AudioRecorder
# ============================================================================
class AudioRecorder:
"""Thread-safe audio recorder using sounddevice.InputStream.
Usage::
recorder = AudioRecorder()
recorder.start(on_silence_stop=my_callback)
# ... user speaks ...
wav_path = recorder.stop() # returns path to WAV file
# or
recorder.cancel() # discard without saving
If ``on_silence_stop`` is provided, recording automatically stops when
the user is silent for ``silence_duration`` seconds and calls the callback.
"""
def __init__(self) -> None:
self._lock = threading.Lock()
self._stream: Any = None
self._frames: List[Any] = []
self._recording = False
self._start_time: float = 0.0
# Silence detection state
self._has_spoken = False
self._speech_start: float = 0.0 # When speech attempt began
self._dip_start: float = 0.0 # When current below-threshold dip began
self._min_speech_duration: float = 0.3 # Seconds of speech needed to confirm
self._max_dip_tolerance: float = 0.3 # Max dip duration before resetting speech
self._silence_start: float = 0.0
self._resume_start: float = 0.0 # Tracks sustained speech after silence starts
self._resume_dip_start: float = 0.0 # Dip tolerance tracker for resume detection
self._on_silence_stop = None
self._silence_threshold: int = SILENCE_RMS_THRESHOLD
self._silence_duration: float = SILENCE_DURATION_SECONDS
self._max_wait: float = 15.0 # Max seconds to wait for speech before auto-stop
# Peak RMS seen during recording (for speech presence check in stop())
self._peak_rms: int = 0
# Live audio level (read by UI for visual feedback)
self._current_rms: int = 0
# -- public properties ---------------------------------------------------
@property
def is_recording(self) -> bool:
return self._recording
@property
def elapsed_seconds(self) -> float:
if not self._recording:
return 0.0
return time.monotonic() - self._start_time
@property
def current_rms(self) -> int:
"""Current audio input RMS level (0-32767). Updated each audio chunk."""
return self._current_rms
# -- public methods ------------------------------------------------------
def _ensure_stream(self) -> None:
"""Create the audio InputStream once and keep it alive.
The stream stays open for the lifetime of the recorder. Between
recordings the callback simply discards audio chunks (``_recording``
is ``False``). This avoids the CoreAudio bug where closing and
re-opening an ``InputStream`` hangs indefinitely on macOS.
"""
if self._stream is not None:
return # already alive
sd, np = _import_audio()
def _callback(indata, frames, time_info, status): # noqa: ARG001
if status:
logger.debug("sounddevice status: %s", status)
# When not recording the stream is idle — discard audio.
if not self._recording:
return
self._frames.append(indata.copy())
# Compute RMS for level display and silence detection
rms = int(np.sqrt(np.mean(indata.astype(np.float64) ** 2)))
self._current_rms = rms
if rms > self._peak_rms:
self._peak_rms = rms
# Silence detection
if self._on_silence_stop is not None:
now = time.monotonic()
elapsed = now - self._start_time
if rms > self._silence_threshold:
# Audio is above threshold -- this is speech (or noise).
self._dip_start = 0.0 # Reset dip tracker
if self._speech_start == 0.0:
self._speech_start = now
elif not self._has_spoken and now - self._speech_start >= self._min_speech_duration:
self._has_spoken = True
logger.debug("Speech confirmed (%.2fs above threshold)",
now - self._speech_start)
# After speech is confirmed, only reset silence timer if
# speech is sustained (>0.3s above threshold). Brief
# spikes from ambient noise should NOT reset the timer.
if not self._has_spoken:
self._silence_start = 0.0
else:
# Track resumed speech with dip tolerance.
# Brief dips below threshold are normal during speech,
# so we mirror the initial speech detection pattern:
# start tracking, tolerate short dips, confirm after 0.3s.
self._resume_dip_start = 0.0 # Above threshold — no dip
if self._resume_start == 0.0:
self._resume_start = now
elif now - self._resume_start >= self._min_speech_duration:
self._silence_start = 0.0
self._resume_start = 0.0
elif self._has_spoken:
# Below threshold after speech confirmed.
# Use dip tolerance before resetting resume tracker —
# natural speech has brief dips below threshold.
if self._resume_start > 0:
if self._resume_dip_start == 0.0:
self._resume_dip_start = now
elif now - self._resume_dip_start >= self._max_dip_tolerance:
# Sustained dip — user actually stopped speaking
self._resume_start = 0.0
self._resume_dip_start = 0.0
elif self._speech_start > 0:
# We were in a speech attempt but RMS dipped.
# Tolerate brief dips (micro-pauses between syllables).
if self._dip_start == 0.0:
self._dip_start = now
elif now - self._dip_start >= self._max_dip_tolerance:
# Dip lasted too long -- genuine silence, reset
logger.debug("Speech attempt reset (dip lasted %.2fs)",
now - self._dip_start)
self._speech_start = 0.0
self._dip_start = 0.0
# Fire silence callback when:
# 1. User spoke then went silent for silence_duration, OR
# 2. No speech detected at all for max_wait seconds
should_fire = False
if self._has_spoken and rms <= self._silence_threshold:
# User was speaking and now is silent
if self._silence_start == 0.0:
self._silence_start = now
elif now - self._silence_start >= self._silence_duration:
logger.info("Silence detected (%.1fs), auto-stopping",
self._silence_duration)
should_fire = True
elif not self._has_spoken and elapsed >= self._max_wait:
logger.info("No speech within %.0fs, auto-stopping",
self._max_wait)
should_fire = True
if should_fire:
with self._lock:
cb = self._on_silence_stop
self._on_silence_stop = None # fire only once
if cb:
def _safe_cb():
try:
cb()
except Exception as e:
logger.error("Silence callback failed: %s", e, exc_info=True)
threading.Thread(target=_safe_cb, daemon=True).start()
# Create stream — may block on CoreAudio (first call only).
stream = None
try:
stream = sd.InputStream(
samplerate=SAMPLE_RATE,
channels=CHANNELS,
dtype=DTYPE,
callback=_callback,
)
stream.start()
except Exception as e:
if stream is not None:
try:
stream.close()
except Exception:
pass
raise RuntimeError(
f"Failed to open audio input stream: {e}. "
"Check that a microphone is connected and accessible."
) from e
self._stream = stream
def start(self, on_silence_stop=None) -> None:
"""Start capturing audio from the default input device.
The underlying InputStream is created once and kept alive across
recordings. Subsequent calls simply reset detection state and
toggle frame collection via ``_recording``.
Args:
on_silence_stop: Optional callback invoked (in a daemon thread) when
silence is detected after speech. The callback receives no arguments.
Use this to auto-stop recording and trigger transcription.
Raises ``RuntimeError`` if sounddevice/numpy are not installed
or if a recording is already in progress.
"""
try:
_import_audio()
except (ImportError, OSError) as e:
raise RuntimeError(
"Voice mode requires sounddevice and numpy.\n"
"Install with: pip install sounddevice numpy\n"
"Or: pip install hermes-agent[voice]"
) from e
with self._lock:
if self._recording:
return # already recording
self._frames = []
self._start_time = time.monotonic()
self._has_spoken = False
self._speech_start = 0.0
self._dip_start = 0.0
self._silence_start = 0.0
self._resume_start = 0.0
self._resume_dip_start = 0.0
self._peak_rms = 0
self._current_rms = 0
self._on_silence_stop = on_silence_stop
# Ensure the persistent stream is alive (no-op after first call).
self._ensure_stream()
with self._lock:
self._recording = True
logger.info("Voice recording started (rate=%d, channels=%d)", SAMPLE_RATE, CHANNELS)
def _close_stream_with_timeout(self, timeout: float = 3.0) -> None:
"""Close the audio stream with a timeout to prevent CoreAudio hangs."""
if self._stream is None:
return
stream = self._stream
self._stream = None
def _do_close():
try:
stream.stop()
stream.close()
except Exception:
pass
t = threading.Thread(target=_do_close, daemon=True)
t.start()
# Poll in short intervals so Ctrl+C is not blocked
deadline = __import__("time").monotonic() + timeout
while t.is_alive() and __import__("time").monotonic() < deadline:
t.join(timeout=0.1)
if t.is_alive():
logger.warning("Audio stream close timed out after %.1fs — forcing ahead", timeout)
def stop(self) -> Optional[str]:
"""Stop recording and write captured audio to a WAV file.
The underlying stream is kept alive for reuse only frame
collection is stopped.
Returns:
Path to the WAV file, or ``None`` if no audio was captured.
"""
with self._lock:
if not self._recording:
return None
self._recording = False
self._current_rms = 0
# Stream stays alive — no close needed.
if not self._frames:
return None
# Concatenate frames and write WAV
_, np = _import_audio()
audio_data = np.concatenate(self._frames, axis=0)
self._frames = []
elapsed = time.monotonic() - self._start_time
logger.info("Voice recording stopped (%.1fs, %d samples)", elapsed, len(audio_data))
# Skip very short recordings (< 0.3s of audio)
min_samples = int(SAMPLE_RATE * 0.3)
if len(audio_data) < min_samples:
logger.debug("Recording too short (%d samples), discarding", len(audio_data))
return None
# Skip silent recordings using peak RMS (not overall average, which
# gets diluted by silence at the end of the recording).
if self._peak_rms < SILENCE_RMS_THRESHOLD:
logger.info("Recording too quiet (peak RMS=%d < %d), discarding",
self._peak_rms, SILENCE_RMS_THRESHOLD)
return None
return self._write_wav(audio_data)
def cancel(self) -> None:
"""Stop recording and discard all captured audio.
The underlying stream is kept alive for reuse.
"""
with self._lock:
self._recording = False
self._frames = []
self._on_silence_stop = None
self._current_rms = 0
logger.info("Voice recording cancelled")
def shutdown(self) -> None:
"""Release the audio stream. Call when voice mode is disabled."""
with self._lock:
self._recording = False
self._frames = []
self._on_silence_stop = None
# Close stream OUTSIDE the lock to avoid deadlock with audio callback
self._close_stream_with_timeout()
logger.info("AudioRecorder shut down")
# -- private helpers -----------------------------------------------------
@staticmethod
def _write_wav(audio_data) -> str:
"""Write numpy int16 audio data to a WAV file.
Returns the file path.
"""
os.makedirs(_TEMP_DIR, exist_ok=True)
timestamp = time.strftime("%Y%m%d_%H%M%S")
wav_path = os.path.join(_TEMP_DIR, f"recording_{timestamp}.wav")
with wave.open(wav_path, "wb") as wf:
wf.setnchannels(CHANNELS)
wf.setsampwidth(SAMPLE_WIDTH)
wf.setframerate(SAMPLE_RATE)
wf.writeframes(audio_data.tobytes())
file_size = os.path.getsize(wav_path)
logger.info("WAV written: %s (%d bytes)", wav_path, file_size)
return wav_path
# ============================================================================
# Whisper hallucination filter
# ============================================================================
# Whisper commonly hallucinates these phrases on silent/near-silent audio.
WHISPER_HALLUCINATIONS = {
"thank you.",
"thank you",
"thanks for watching.",
"thanks for watching",
"subscribe to my channel.",
"subscribe to my channel",
"like and subscribe.",
"like and subscribe",
"please subscribe.",
"please subscribe",
"thank you for watching.",
"thank you for watching",
"bye.",
"bye",
"you",
"the end.",
"the end",
# Non-English hallucinations (common on silence)
"продолжение следует",
"продолжение следует...",
"sous-titres",
"sous-titres réalisés par la communauté d'amara.org",
"sottotitoli creati dalla comunità amara.org",
"untertitel von stephanie geiges",
"amara.org",
"www.mooji.org",
"ご視聴ありがとうございました",
}
# Regex patterns for repetitive hallucinations (e.g. "Thank you. Thank you. Thank you.")
_HALLUCINATION_REPEAT_RE = re.compile(
r'^(?:thank you|thanks|bye|you|ok|okay|the end|\.|\s|,|!)+$',
flags=re.IGNORECASE,
)
def is_whisper_hallucination(transcript: str) -> bool:
"""Check if a transcript is a known Whisper hallucination on silence."""
cleaned = transcript.strip().lower()
if not cleaned:
return True
# Exact match against known phrases
if cleaned.rstrip('.!') in WHISPER_HALLUCINATIONS or cleaned in WHISPER_HALLUCINATIONS:
return True
# Repetitive patterns (e.g. "Thank you. Thank you. Thank you. you")
if _HALLUCINATION_REPEAT_RE.match(cleaned):
return True
return False
# ============================================================================
# STT dispatch
# ============================================================================
def transcribe_recording(wav_path: str, model: Optional[str] = None) -> Dict[str, Any]:
"""Transcribe a WAV recording using the existing Whisper pipeline.
Delegates to ``tools.transcription_tools.transcribe_audio()``.
Filters out known Whisper hallucinations on silent audio.
Args:
wav_path: Path to the WAV file.
model: Whisper model name (default: from config or ``whisper-1``).
Returns:
Dict with ``success``, ``transcript``, and optionally ``error``.
"""
from tools.transcription_tools import transcribe_audio
result = transcribe_audio(wav_path, model=model)
# Filter out Whisper hallucinations (common on silent/near-silent audio)
if result.get("success") and is_whisper_hallucination(result.get("transcript", "")):
logger.info("Filtered Whisper hallucination: %r", result["transcript"])
return {"success": True, "transcript": "", "filtered": True}
return result
# ============================================================================
# Audio playback (interruptable)
# ============================================================================
# Global reference to the active playback process so it can be interrupted.
_active_playback: Optional[subprocess.Popen] = None
_playback_lock = threading.Lock()
def stop_playback() -> None:
"""Interrupt the currently playing audio (if any)."""
global _active_playback
with _playback_lock:
proc = _active_playback
_active_playback = None
if proc and proc.poll() is None:
try:
proc.terminate()
logger.info("Audio playback interrupted")
except Exception:
pass
# Also stop sounddevice playback if active
try:
sd, _ = _import_audio()
sd.stop()
except Exception:
pass
def play_audio_file(file_path: str) -> bool:
"""Play an audio file through the default output device.
Strategy:
1. WAV files via ``sounddevice.play()`` when available.
2. System commands: ``afplay`` (macOS), ``ffplay`` (cross-platform),
``aplay`` (Linux ALSA).
Playback can be interrupted by calling ``stop_playback()``.
Returns:
``True`` if playback succeeded, ``False`` otherwise.
"""
global _active_playback
if not os.path.isfile(file_path):
logger.warning("Audio file not found: %s", file_path)
return False
# Try sounddevice for WAV files
if file_path.endswith(".wav"):
try:
sd, np = _import_audio()
with wave.open(file_path, "rb") as wf:
frames = wf.readframes(wf.getnframes())
audio_data = np.frombuffer(frames, dtype=np.int16)
sample_rate = wf.getframerate()
sd.play(audio_data, samplerate=sample_rate)
# sd.wait() calls Event.wait() without timeout — hangs forever if
# the audio device stalls. Poll with a ceiling and force-stop.
duration_secs = len(audio_data) / sample_rate
deadline = time.monotonic() + duration_secs + 2.0
while sd.get_stream() and sd.get_stream().active and time.monotonic() < deadline:
time.sleep(0.01)
sd.stop()
return True
except (ImportError, OSError):
pass # audio libs not available, fall through to system players
except Exception as e:
logger.debug("sounddevice playback failed: %s", e)
# Fall back to system audio players (using Popen for interruptability)
system = platform.system()
players = []
if system == "Darwin":
players.append(["afplay", file_path])
players.append(["ffplay", "-nodisp", "-autoexit", "-loglevel", "quiet", file_path])
if system == "Linux":
players.append(["aplay", "-q", file_path])
for cmd in players:
exe = shutil.which(cmd[0])
if exe:
try:
proc = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
with _playback_lock:
_active_playback = proc
proc.wait(timeout=300)
with _playback_lock:
_active_playback = None
return True
except subprocess.TimeoutExpired:
logger.warning("System player %s timed out, killing process", cmd[0])
proc.kill()
proc.wait()
with _playback_lock:
_active_playback = None
except Exception as e:
logger.debug("System player %s failed: %s", cmd[0], e)
with _playback_lock:
_active_playback = None
logger.warning("No audio player available for %s", file_path)
return False
# ============================================================================
# Requirements check
# ============================================================================
def check_voice_requirements() -> Dict[str, Any]:
"""Check if all voice mode requirements are met.
Returns:
Dict with ``available``, ``audio_available``, ``stt_available``,
``missing_packages``, and ``details``.
"""
# Determine STT provider availability
from tools.transcription_tools import _get_provider, _load_stt_config, is_stt_enabled, _HAS_FASTER_WHISPER
stt_config = _load_stt_config()
stt_enabled = is_stt_enabled(stt_config)
stt_provider = _get_provider(stt_config)
stt_available = stt_enabled and stt_provider != "none"
missing: List[str] = []
has_audio = _audio_available()
if not has_audio:
missing.extend(["sounddevice", "numpy"])
# Environment detection
env_check = detect_audio_environment()
available = has_audio and stt_available and env_check["available"]
details_parts = []
if has_audio:
details_parts.append("Audio capture: OK")
else:
details_parts.append("Audio capture: MISSING (pip install sounddevice numpy)")
if not stt_enabled:
details_parts.append("STT provider: DISABLED in config (stt.enabled: false)")
elif stt_provider == "local":
details_parts.append("STT provider: OK (local faster-whisper)")
elif stt_provider == "groq":
details_parts.append("STT provider: OK (Groq)")
elif stt_provider == "openai":
details_parts.append("STT provider: OK (OpenAI)")
else:
details_parts.append(
"STT provider: MISSING (pip install faster-whisper, "
"or set GROQ_API_KEY / VOICE_TOOLS_OPENAI_KEY)"
)
for warning in env_check["warnings"]:
details_parts.append(f"Environment: {warning}")
return {
"available": available,
"audio_available": has_audio,
"stt_available": stt_available,
"missing_packages": missing,
"details": "\n".join(details_parts),
"environment": env_check,
}
# ============================================================================
# Temp file cleanup
# ============================================================================
def cleanup_temp_recordings(max_age_seconds: int = 3600) -> int:
"""Remove old temporary voice recording files.
Args:
max_age_seconds: Delete files older than this (default: 1 hour).
Returns:
Number of files deleted.
"""
if not os.path.isdir(_TEMP_DIR):
return 0
deleted = 0
now = time.time()
for entry in os.scandir(_TEMP_DIR):
if entry.is_file() and entry.name.startswith("recording_") and entry.name.endswith(".wav"):
try:
age = now - entry.stat().st_mtime
if age > max_age_seconds:
os.unlink(entry.path)
deleted += 1
except OSError:
pass
if deleted:
logger.debug("Cleaned up %d old voice recordings", deleted)
return deleted

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,285 @@
"""Website access policy helpers for URL-capable tools.
This module loads a user-managed website blocklist from ~/.hermes/config.yaml
and optional shared list files. It is intentionally lightweight so web/browser
tools can enforce URL policy without pulling in the heavier CLI config stack.
Policy is cached in memory with a short TTL so config changes take effect
quickly without re-reading the file on every URL check.
"""
from __future__ import annotations
import fnmatch
import logging
import os
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse
logger = logging.getLogger(__name__)
_DEFAULT_WEBSITE_BLOCKLIST = {
"enabled": False,
"domains": [],
"shared_files": [],
}
# Cache: parsed policy + timestamp. Avoids re-reading config.yaml on every
# URL check (a web_crawl with 50 pages would otherwise mean 51 YAML parses).
_CACHE_TTL_SECONDS = 30.0
_cache_lock = threading.Lock()
_cached_policy: Optional[Dict[str, Any]] = None
_cached_policy_path: Optional[str] = None
_cached_policy_time: float = 0.0
def _get_hermes_home() -> Path:
return Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
def _get_default_config_path() -> Path:
return _get_hermes_home() / "config.yaml"
class WebsitePolicyError(Exception):
"""Raised when a website policy file is malformed."""
def _normalize_host(host: str) -> str:
return (host or "").strip().lower().rstrip(".")
def _normalize_rule(rule: Any) -> Optional[str]:
if not isinstance(rule, str):
return None
value = rule.strip().lower()
if not value or value.startswith("#"):
return None
if "://" in value:
parsed = urlparse(value)
value = parsed.netloc or parsed.path
value = value.split("/", 1)[0].strip().rstrip(".")
if value.startswith("www."):
value = value[4:]
return value or None
def _iter_blocklist_file_rules(path: Path) -> List[str]:
"""Load rules from a shared blocklist file.
Missing or unreadable files log a warning and return an empty list
rather than raising a bad file path should not disable all web tools.
"""
try:
raw = path.read_text(encoding="utf-8")
except FileNotFoundError:
logger.warning("Shared blocklist file not found (skipping): %s", path)
return []
except (OSError, UnicodeDecodeError) as exc:
logger.warning("Failed to read shared blocklist file %s (skipping): %s", path, exc)
return []
rules: List[str] = []
for line in raw.splitlines():
stripped = line.strip()
if not stripped or stripped.startswith("#"):
continue
normalized = _normalize_rule(stripped)
if normalized:
rules.append(normalized)
return rules
def _load_policy_config(config_path: Optional[Path] = None) -> Dict[str, Any]:
config_path = config_path or _get_default_config_path()
if not config_path.exists():
return dict(_DEFAULT_WEBSITE_BLOCKLIST)
try:
import yaml
except ImportError:
logger.debug("PyYAML not installed — website blocklist disabled")
return dict(_DEFAULT_WEBSITE_BLOCKLIST)
try:
with open(config_path, encoding="utf-8") as f:
config = yaml.safe_load(f) or {}
except yaml.YAMLError as exc:
raise WebsitePolicyError(f"Invalid config YAML at {config_path}: {exc}") from exc
except OSError as exc:
raise WebsitePolicyError(f"Failed to read config file {config_path}: {exc}") from exc
if not isinstance(config, dict):
raise WebsitePolicyError("config root must be a mapping")
security = config.get("security", {})
if security is None:
security = {}
if not isinstance(security, dict):
raise WebsitePolicyError("security must be a mapping")
website_blocklist = security.get("website_blocklist", {})
if website_blocklist is None:
website_blocklist = {}
if not isinstance(website_blocklist, dict):
raise WebsitePolicyError("security.website_blocklist must be a mapping")
policy = dict(_DEFAULT_WEBSITE_BLOCKLIST)
policy.update(website_blocklist)
return policy
def load_website_blocklist(config_path: Optional[Path] = None) -> Dict[str, Any]:
"""Load and return the parsed website blocklist policy.
Results are cached for ``_CACHE_TTL_SECONDS`` to avoid re-reading
config.yaml on every URL check. Pass an explicit ``config_path``
to bypass the cache (used by tests).
"""
global _cached_policy, _cached_policy_path, _cached_policy_time
resolved_path = str(config_path) if config_path else "__default__"
now = time.monotonic()
# Return cached policy if still fresh and same path
if config_path is None:
with _cache_lock:
if (
_cached_policy is not None
and _cached_policy_path == resolved_path
and (now - _cached_policy_time) < _CACHE_TTL_SECONDS
):
return _cached_policy
config_path = config_path or _get_default_config_path()
policy = _load_policy_config(config_path)
raw_domains = policy.get("domains", []) or []
if not isinstance(raw_domains, list):
raise WebsitePolicyError("security.website_blocklist.domains must be a list")
raw_shared_files = policy.get("shared_files", []) or []
if not isinstance(raw_shared_files, list):
raise WebsitePolicyError("security.website_blocklist.shared_files must be a list")
enabled = policy.get("enabled", True)
if not isinstance(enabled, bool):
raise WebsitePolicyError("security.website_blocklist.enabled must be a boolean")
rules: List[Dict[str, str]] = []
seen: set[Tuple[str, str]] = set()
for raw_rule in raw_domains:
normalized = _normalize_rule(raw_rule)
if normalized and ("config", normalized) not in seen:
rules.append({"pattern": normalized, "source": "config"})
seen.add(("config", normalized))
for shared_file in raw_shared_files:
if not isinstance(shared_file, str) or not shared_file.strip():
continue
path = Path(shared_file).expanduser()
if not path.is_absolute():
path = (_get_hermes_home() / path).resolve()
for normalized in _iter_blocklist_file_rules(path):
key = (str(path), normalized)
if key in seen:
continue
rules.append({"pattern": normalized, "source": str(path)})
seen.add(key)
result = {"enabled": enabled, "rules": rules}
# Cache the result (only for the default path — explicit paths are tests)
if config_path == _get_default_config_path():
with _cache_lock:
_cached_policy = result
_cached_policy_path = "__default__"
_cached_policy_time = now
return result
def invalidate_cache() -> None:
"""Force the next ``check_website_access`` call to re-read config."""
global _cached_policy
with _cache_lock:
_cached_policy = None
def _match_host_against_rule(host: str, pattern: str) -> bool:
if not host or not pattern:
return False
if pattern.startswith("*."):
return fnmatch.fnmatch(host, pattern)
return host == pattern or host.endswith(f".{pattern}")
def _extract_host_from_urlish(url: str) -> str:
parsed = urlparse(url)
host = _normalize_host(parsed.hostname or parsed.netloc)
if host:
return host
if "://" not in url:
schemeless = urlparse(f"//{url}")
host = _normalize_host(schemeless.hostname or schemeless.netloc)
if host:
return host
return ""
def check_website_access(url: str, config_path: Optional[Path] = None) -> Optional[Dict[str, str]]:
"""Check whether a URL is allowed by the website blocklist policy.
Returns ``None`` if access is allowed, or a dict with block metadata
(``host``, ``rule``, ``source``, ``message``) if blocked.
Never raises on policy errors logs a warning and returns ``None``
(fail-open) so a config typo doesn't break all web tools. Pass
``config_path`` explicitly (tests) to get strict error propagation.
"""
# Fast path: if no explicit config_path and the cached policy is disabled
# or empty, skip all work (no YAML read, no host extraction).
if config_path is None:
with _cache_lock:
if _cached_policy is not None and not _cached_policy.get("enabled"):
return None
host = _extract_host_from_urlish(url)
if not host:
return None
try:
policy = load_website_blocklist(config_path)
except WebsitePolicyError as exc:
if config_path is not None:
raise # Tests pass explicit paths — let errors propagate
logger.warning("Website policy config error (failing open): %s", exc)
return None
except Exception as exc:
logger.warning("Unexpected error loading website policy (failing open): %s", exc)
return None
if not policy.get("enabled"):
return None
for rule in policy.get("rules", []):
pattern = rule.get("pattern", "")
if _match_host_against_rule(host, pattern):
logger.info("Blocked URL %s — matched rule '%s' from %s",
url, pattern, rule.get("source", "config"))
return {
"url": url,
"host": host,
"rule": pattern,
"source": rule.get("source", "config"),
"message": (
f"Blocked by website policy: '{host}' matched rule '{pattern}'"
f" from {rule.get('source', 'config')}"
),
}
return None