The architecture has been updated
This commit is contained in:
parent
805f7a017e
commit
a01257ead9
1119 changed files with 226 additions and 352 deletions
266
hermes_code/tools/__init__.py
Normal file
266
hermes_code/tools/__init__.py
Normal file
|
|
@ -0,0 +1,266 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tools Package
|
||||
|
||||
This package contains all the specific tool implementations for the Hermes Agent.
|
||||
Each module provides specialized functionality for different capabilities:
|
||||
|
||||
- web_tools: Web search, content extraction, and crawling
|
||||
- terminal_tool: Command execution (local/docker/modal/daytona/ssh/singularity backends)
|
||||
- vision_tools: Image analysis and understanding
|
||||
- mixture_of_agents_tool: Multi-model collaborative reasoning
|
||||
- image_generation_tool: Text-to-image generation with upscaling
|
||||
|
||||
The tools are imported into model_tools.py which provides a unified interface
|
||||
for the AI agent to access all capabilities.
|
||||
"""
|
||||
|
||||
# Export all tools for easy importing
|
||||
from .web_tools import (
|
||||
web_search_tool,
|
||||
web_extract_tool,
|
||||
web_crawl_tool,
|
||||
check_firecrawl_api_key
|
||||
)
|
||||
|
||||
# Primary terminal tool (local/docker/singularity/modal/daytona/ssh)
|
||||
from .terminal_tool import (
|
||||
terminal_tool,
|
||||
check_terminal_requirements,
|
||||
cleanup_vm,
|
||||
cleanup_all_environments,
|
||||
get_active_environments_info,
|
||||
register_task_env_overrides,
|
||||
clear_task_env_overrides,
|
||||
TERMINAL_TOOL_DESCRIPTION
|
||||
)
|
||||
|
||||
from .vision_tools import (
|
||||
vision_analyze_tool,
|
||||
check_vision_requirements
|
||||
)
|
||||
|
||||
from .mixture_of_agents_tool import (
|
||||
mixture_of_agents_tool,
|
||||
check_moa_requirements
|
||||
)
|
||||
|
||||
from .image_generation_tool import (
|
||||
image_generate_tool,
|
||||
check_image_generation_requirements
|
||||
)
|
||||
|
||||
from .skills_tool import (
|
||||
skills_list,
|
||||
skill_view,
|
||||
check_skills_requirements,
|
||||
SKILLS_TOOL_DESCRIPTION
|
||||
)
|
||||
|
||||
from .skill_manager_tool import (
|
||||
skill_manage,
|
||||
check_skill_manage_requirements,
|
||||
SKILL_MANAGE_SCHEMA
|
||||
)
|
||||
|
||||
# Browser automation tools (agent-browser + Browserbase)
|
||||
# from .browser_tool import (
|
||||
# browser_navigate,
|
||||
# browser_snapshot,
|
||||
# browser_click,
|
||||
# browser_type,
|
||||
# browser_scroll,
|
||||
# browser_back,
|
||||
# browser_press,
|
||||
# browser_close,
|
||||
# browser_get_images,
|
||||
# browser_vision,
|
||||
# cleanup_browser,
|
||||
# cleanup_all_browsers,
|
||||
# get_active_browser_sessions,
|
||||
# check_browser_requirements,
|
||||
# BROWSER_TOOL_SCHEMAS
|
||||
# )
|
||||
|
||||
from .browser_use_tool import run_browser_task
|
||||
|
||||
from .browser_tool import cleanup_browser, cleanup_all_browsers
|
||||
|
||||
# Cronjob management tools (CLI-only, hermes-cli toolset)
|
||||
from .cronjob_tools import (
|
||||
cronjob,
|
||||
schedule_cronjob,
|
||||
list_cronjobs,
|
||||
remove_cronjob,
|
||||
check_cronjob_requirements,
|
||||
get_cronjob_tool_definitions,
|
||||
CRONJOB_SCHEMA,
|
||||
)
|
||||
|
||||
# RL Training tools (Tinker-Atropos)
|
||||
from .rl_training_tool import (
|
||||
rl_list_environments,
|
||||
rl_select_environment,
|
||||
rl_get_current_config,
|
||||
rl_edit_config,
|
||||
rl_start_training,
|
||||
rl_check_status,
|
||||
rl_stop_training,
|
||||
rl_get_results,
|
||||
rl_list_runs,
|
||||
rl_test_inference,
|
||||
check_rl_api_keys,
|
||||
get_missing_keys,
|
||||
)
|
||||
|
||||
# File manipulation tools (read, write, patch, search)
|
||||
from .file_tools import (
|
||||
read_file_tool,
|
||||
write_file_tool,
|
||||
patch_tool,
|
||||
search_tool,
|
||||
get_file_tools,
|
||||
clear_file_ops_cache,
|
||||
)
|
||||
|
||||
# Text-to-speech tools (Edge TTS / ElevenLabs / OpenAI)
|
||||
from .tts_tool import (
|
||||
text_to_speech_tool,
|
||||
check_tts_requirements,
|
||||
)
|
||||
|
||||
# Planning & task management tool
|
||||
from .todo_tool import (
|
||||
todo_tool,
|
||||
check_todo_requirements,
|
||||
TODO_SCHEMA,
|
||||
TodoStore,
|
||||
)
|
||||
|
||||
# Clarifying questions tool (interactive Q&A with the user)
|
||||
from .clarify_tool import (
|
||||
clarify_tool,
|
||||
check_clarify_requirements,
|
||||
CLARIFY_SCHEMA,
|
||||
)
|
||||
|
||||
# Code execution sandbox (programmatic tool calling)
|
||||
from .code_execution_tool import (
|
||||
execute_code,
|
||||
check_sandbox_requirements,
|
||||
EXECUTE_CODE_SCHEMA,
|
||||
)
|
||||
|
||||
# Subagent delegation (spawn child agents with isolated context)
|
||||
from .delegate_tool import (
|
||||
delegate_task,
|
||||
check_delegate_requirements,
|
||||
DELEGATE_TASK_SCHEMA,
|
||||
)
|
||||
|
||||
# File tools have no external requirements - they use the terminal backend
|
||||
def check_file_requirements():
|
||||
"""File tools only require terminal backend to be available."""
|
||||
from .terminal_tool import check_terminal_requirements
|
||||
return check_terminal_requirements()
|
||||
|
||||
__all__ = [
|
||||
# Web tools
|
||||
'web_search_tool',
|
||||
'web_extract_tool',
|
||||
'web_crawl_tool',
|
||||
'check_firecrawl_api_key',
|
||||
# Terminal tools
|
||||
'terminal_tool',
|
||||
'check_terminal_requirements',
|
||||
'cleanup_vm',
|
||||
'cleanup_all_environments',
|
||||
'get_active_environments_info',
|
||||
'register_task_env_overrides',
|
||||
'clear_task_env_overrides',
|
||||
'TERMINAL_TOOL_DESCRIPTION',
|
||||
# Vision tools
|
||||
'vision_analyze_tool',
|
||||
'check_vision_requirements',
|
||||
# MoA tools
|
||||
'mixture_of_agents_tool',
|
||||
'check_moa_requirements',
|
||||
# Image generation tools
|
||||
'image_generate_tool',
|
||||
'check_image_generation_requirements',
|
||||
# Skills tools
|
||||
'skills_list',
|
||||
'skill_view',
|
||||
'check_skills_requirements',
|
||||
'SKILLS_TOOL_DESCRIPTION',
|
||||
# Skill management
|
||||
'skill_manage',
|
||||
'check_skill_manage_requirements',
|
||||
'SKILL_MANAGE_SCHEMA',
|
||||
# Browser automation tools
|
||||
'browser_navigate',
|
||||
'browser_snapshot',
|
||||
'browser_click',
|
||||
'browser_type',
|
||||
'browser_scroll',
|
||||
'browser_back',
|
||||
'browser_press',
|
||||
'browser_close',
|
||||
'browser_get_images',
|
||||
'browser_vision',
|
||||
'cleanup_browser',
|
||||
'cleanup_all_browsers',
|
||||
'get_active_browser_sessions',
|
||||
'check_browser_requirements',
|
||||
'BROWSER_TOOL_SCHEMAS',
|
||||
# Cronjob management tools (CLI-only)
|
||||
'cronjob',
|
||||
'schedule_cronjob',
|
||||
'list_cronjobs',
|
||||
'remove_cronjob',
|
||||
'check_cronjob_requirements',
|
||||
'get_cronjob_tool_definitions',
|
||||
'CRONJOB_SCHEMA',
|
||||
# RL Training tools
|
||||
'rl_list_environments',
|
||||
'rl_select_environment',
|
||||
'rl_get_current_config',
|
||||
'rl_edit_config',
|
||||
'rl_start_training',
|
||||
'rl_check_status',
|
||||
'rl_stop_training',
|
||||
'rl_get_results',
|
||||
'rl_list_runs',
|
||||
'rl_test_inference',
|
||||
'check_rl_api_keys',
|
||||
'get_missing_keys',
|
||||
# File manipulation tools
|
||||
'read_file_tool',
|
||||
'write_file_tool',
|
||||
'patch_tool',
|
||||
'search_tool',
|
||||
'get_file_tools',
|
||||
'clear_file_ops_cache',
|
||||
'check_file_requirements',
|
||||
# Text-to-speech tools
|
||||
'text_to_speech_tool',
|
||||
'check_tts_requirements',
|
||||
# Planning & task management tool
|
||||
'todo_tool',
|
||||
'check_todo_requirements',
|
||||
'TODO_SCHEMA',
|
||||
'TodoStore',
|
||||
# Clarifying questions tool
|
||||
'clarify_tool',
|
||||
'check_clarify_requirements',
|
||||
'CLARIFY_SCHEMA',
|
||||
# Code execution sandbox
|
||||
'execute_code',
|
||||
'check_sandbox_requirements',
|
||||
'EXECUTE_CODE_SCHEMA',
|
||||
# Subagent delegation
|
||||
'delegate_task',
|
||||
'check_delegate_requirements',
|
||||
'DELEGATE_TASK_SCHEMA',
|
||||
]
|
||||
|
||||
44
hermes_code/tools/ansi_strip.py
Normal file
44
hermes_code/tools/ansi_strip.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"""Strip ANSI escape sequences from subprocess output.
|
||||
|
||||
Used by terminal_tool, code_execution_tool, and process_registry to clean
|
||||
command output before returning it to the model. This prevents ANSI codes
|
||||
from entering the model's context — which is the root cause of models
|
||||
copying escape sequences into file writes.
|
||||
|
||||
Covers the full ECMA-48 spec: CSI (including private-mode ``?`` prefix,
|
||||
colon-separated params, intermediate bytes), OSC (BEL and ST terminators),
|
||||
DCS/SOS/PM/APC string sequences, nF multi-byte escapes, Fp/Fe/Fs
|
||||
single-byte escapes, and 8-bit C1 control characters.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
_ANSI_ESCAPE_RE = re.compile(
|
||||
r"\x1b"
|
||||
r"(?:"
|
||||
r"\[[\x30-\x3f]*[\x20-\x2f]*[\x40-\x7e]" # CSI sequence
|
||||
r"|\][\s\S]*?(?:\x07|\x1b\\)" # OSC (BEL or ST terminator)
|
||||
r"|[PX^_][\s\S]*?(?:\x1b\\)" # DCS/SOS/PM/APC strings
|
||||
r"|[\x20-\x2f]+[\x30-\x7e]" # nF escape sequences
|
||||
r"|[\x30-\x7e]" # Fp/Fe/Fs single-byte
|
||||
r")"
|
||||
r"|\x9b[\x30-\x3f]*[\x20-\x2f]*[\x40-\x7e]" # 8-bit CSI
|
||||
r"|\x9d[\s\S]*?(?:\x07|\x9c)" # 8-bit OSC
|
||||
r"|[\x80-\x9f]", # Other 8-bit C1 controls
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
# Fast-path check — skip full regex when no escape-like bytes are present.
|
||||
_HAS_ESCAPE = re.compile(r"[\x1b\x80-\x9f]")
|
||||
|
||||
|
||||
def strip_ansi(text: str) -> str:
|
||||
"""Remove ANSI escape sequences from text.
|
||||
|
||||
Returns the input unchanged (fast path) when no ESC or C1 bytes are
|
||||
present. Safe to call on any string — clean text passes through
|
||||
with negligible overhead.
|
||||
"""
|
||||
if not text or not _HAS_ESCAPE.search(text):
|
||||
return text
|
||||
return _ANSI_ESCAPE_RE.sub("", text)
|
||||
590
hermes_code/tools/approval.py
Normal file
590
hermes_code/tools/approval.py
Normal file
|
|
@ -0,0 +1,590 @@
|
|||
"""Dangerous command approval -- detection, prompting, and per-session state.
|
||||
|
||||
This module is the single source of truth for the dangerous command system:
|
||||
- Pattern detection (DANGEROUS_PATTERNS, detect_dangerous_command)
|
||||
- Per-session approval state (thread-safe, keyed by session_key)
|
||||
- Approval prompting (CLI interactive + gateway async)
|
||||
- Smart approval via auxiliary LLM (auto-approve low-risk commands)
|
||||
- Permanent allowlist persistence (config.yaml)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# =========================================================================
|
||||
# Dangerous command patterns
|
||||
# =========================================================================
|
||||
|
||||
DANGEROUS_PATTERNS = [
|
||||
(r'\brm\s+(-[^\s]*\s+)*/', "delete in root path"),
|
||||
(r'\brm\s+-[^\s]*r', "recursive delete"),
|
||||
(r'\brm\s+--recursive\b', "recursive delete (long flag)"),
|
||||
(r'\bchmod\s+(-[^\s]*\s+)*777\b', "world-writable permissions"),
|
||||
(r'\bchmod\s+--recursive\b.*777', "recursive world-writable (long flag)"),
|
||||
(r'\bchown\s+(-[^\s]*)?R\s+root', "recursive chown to root"),
|
||||
(r'\bchown\s+--recursive\b.*root', "recursive chown to root (long flag)"),
|
||||
(r'\bmkfs\b', "format filesystem"),
|
||||
(r'\bdd\s+.*if=', "disk copy"),
|
||||
(r'>\s*/dev/sd', "write to block device"),
|
||||
(r'\bDROP\s+(TABLE|DATABASE)\b', "SQL DROP"),
|
||||
(r'\bDELETE\s+FROM\b(?!.*\bWHERE\b)', "SQL DELETE without WHERE"),
|
||||
(r'\bTRUNCATE\s+(TABLE)?\s*\w', "SQL TRUNCATE"),
|
||||
(r'>\s*/etc/', "overwrite system config"),
|
||||
(r'\bsystemctl\s+(stop|disable|mask)\b', "stop/disable system service"),
|
||||
(r'\bkill\s+-9\s+-1\b', "kill all processes"),
|
||||
(r'\bpkill\s+-9\b', "force kill processes"),
|
||||
(r':\(\)\s*\{\s*:\s*\|\s*:\s*&\s*\}\s*;\s*:', "fork bomb"),
|
||||
# Any shell invocation via -c or combined flags like -lc, -ic, etc.
|
||||
(r'\b(bash|sh|zsh|ksh)\s+-[^\s]*c(\s+|$)', "shell command via -c/-lc flag"),
|
||||
(r'\b(python[23]?|perl|ruby|node)\s+-[ec]\s+', "script execution via -e/-c flag"),
|
||||
(r'\b(curl|wget)\b.*\|\s*(ba)?sh\b', "pipe remote content to shell"),
|
||||
(r'\b(bash|sh|zsh|ksh)\s+<\s*<?\s*\(\s*(curl|wget)\b', "execute remote script via process substitution"),
|
||||
(r'\btee\b.*(/etc/|/dev/sd|\.ssh/|\.hermes/\.env)', "overwrite system file via tee"),
|
||||
(r'\bxargs\s+.*\brm\b', "xargs with rm"),
|
||||
(r'\bfind\b.*-exec\s+(/\S*/)?rm\b', "find -exec rm"),
|
||||
(r'\bfind\b.*-delete\b', "find -delete"),
|
||||
# Gateway protection: never start gateway outside systemd management
|
||||
(r'gateway\s+run\b.*(&\s*$|&\s*;|\bdisown\b|\bsetsid\b)', "start gateway outside systemd (use 'systemctl --user restart hermes-gateway')"),
|
||||
(r'\bnohup\b.*gateway\s+run\b', "start gateway outside systemd (use 'systemctl --user restart hermes-gateway')"),
|
||||
]
|
||||
|
||||
|
||||
def _legacy_pattern_key(pattern: str) -> str:
|
||||
"""Reproduce the old regex-derived approval key for backwards compatibility."""
|
||||
return pattern.split(r'\b')[1] if r'\b' in pattern else pattern[:20]
|
||||
|
||||
|
||||
_PATTERN_KEY_ALIASES: dict[str, set[str]] = {}
|
||||
for _pattern, _description in DANGEROUS_PATTERNS:
|
||||
_legacy_key = _legacy_pattern_key(_pattern)
|
||||
_canonical_key = _description
|
||||
_PATTERN_KEY_ALIASES.setdefault(_canonical_key, set()).update({_canonical_key, _legacy_key})
|
||||
_PATTERN_KEY_ALIASES.setdefault(_legacy_key, set()).update({_legacy_key, _canonical_key})
|
||||
|
||||
|
||||
def _approval_key_aliases(pattern_key: str) -> set[str]:
|
||||
"""Return all approval keys that should match this pattern.
|
||||
|
||||
New approvals use the human-readable description string, but older
|
||||
command_allowlist entries and session approvals may still contain the
|
||||
historical regex-derived key.
|
||||
"""
|
||||
return _PATTERN_KEY_ALIASES.get(pattern_key, {pattern_key})
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Detection
|
||||
# =========================================================================
|
||||
|
||||
def detect_dangerous_command(command: str) -> tuple:
|
||||
"""Check if a command matches any dangerous patterns.
|
||||
|
||||
Returns:
|
||||
(is_dangerous, pattern_key, description) or (False, None, None)
|
||||
"""
|
||||
command_lower = command.lower()
|
||||
for pattern, description in DANGEROUS_PATTERNS:
|
||||
if re.search(pattern, command_lower, re.IGNORECASE | re.DOTALL):
|
||||
pattern_key = description
|
||||
return (True, pattern_key, description)
|
||||
return (False, None, None)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Per-session approval state (thread-safe)
|
||||
# =========================================================================
|
||||
|
||||
_lock = threading.Lock()
|
||||
_pending: dict[str, dict] = {}
|
||||
_session_approved: dict[str, set] = {}
|
||||
_permanent_approved: set = set()
|
||||
|
||||
|
||||
def submit_pending(session_key: str, approval: dict):
|
||||
"""Store a pending approval request for a session."""
|
||||
with _lock:
|
||||
_pending[session_key] = approval
|
||||
|
||||
|
||||
def pop_pending(session_key: str) -> Optional[dict]:
|
||||
"""Retrieve and remove a pending approval for a session."""
|
||||
with _lock:
|
||||
return _pending.pop(session_key, None)
|
||||
|
||||
|
||||
def has_pending(session_key: str) -> bool:
|
||||
"""Check if a session has a pending approval request."""
|
||||
with _lock:
|
||||
return session_key in _pending
|
||||
|
||||
|
||||
def approve_session(session_key: str, pattern_key: str):
|
||||
"""Approve a pattern for this session only."""
|
||||
with _lock:
|
||||
_session_approved.setdefault(session_key, set()).add(pattern_key)
|
||||
|
||||
|
||||
def is_approved(session_key: str, pattern_key: str) -> bool:
|
||||
"""Check if a pattern is approved (session-scoped or permanent).
|
||||
|
||||
Accept both the current canonical key and the legacy regex-derived key so
|
||||
existing command_allowlist entries continue to work after key migrations.
|
||||
"""
|
||||
aliases = _approval_key_aliases(pattern_key)
|
||||
with _lock:
|
||||
if any(alias in _permanent_approved for alias in aliases):
|
||||
return True
|
||||
session_approvals = _session_approved.get(session_key, set())
|
||||
return any(alias in session_approvals for alias in aliases)
|
||||
|
||||
|
||||
def approve_permanent(pattern_key: str):
|
||||
"""Add a pattern to the permanent allowlist."""
|
||||
with _lock:
|
||||
_permanent_approved.add(pattern_key)
|
||||
|
||||
|
||||
def load_permanent(patterns: set):
|
||||
"""Bulk-load permanent allowlist entries from config."""
|
||||
with _lock:
|
||||
_permanent_approved.update(patterns)
|
||||
|
||||
|
||||
def clear_session(session_key: str):
|
||||
"""Clear all approvals and pending requests for a session."""
|
||||
with _lock:
|
||||
_session_approved.pop(session_key, None)
|
||||
_pending.pop(session_key, None)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Config persistence for permanent allowlist
|
||||
# =========================================================================
|
||||
|
||||
def load_permanent_allowlist() -> set:
|
||||
"""Load permanently allowed command patterns from config.
|
||||
|
||||
Also syncs them into the approval module so is_approved() works for
|
||||
patterns added via 'always' in a previous session.
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
config = load_config()
|
||||
patterns = set(config.get("command_allowlist", []) or [])
|
||||
if patterns:
|
||||
load_permanent(patterns)
|
||||
return patterns
|
||||
except Exception:
|
||||
return set()
|
||||
|
||||
|
||||
def save_permanent_allowlist(patterns: set):
|
||||
"""Save permanently allowed command patterns to config."""
|
||||
try:
|
||||
from hermes_cli.config import load_config, save_config
|
||||
config = load_config()
|
||||
config["command_allowlist"] = list(patterns)
|
||||
save_config(config)
|
||||
except Exception as e:
|
||||
logger.warning("Could not save allowlist: %s", e)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Approval prompting + orchestration
|
||||
# =========================================================================
|
||||
|
||||
def prompt_dangerous_approval(command: str, description: str,
|
||||
timeout_seconds: int = 60,
|
||||
allow_permanent: bool = True,
|
||||
approval_callback=None) -> str:
|
||||
"""Prompt the user to approve a dangerous command (CLI only).
|
||||
|
||||
Args:
|
||||
allow_permanent: When False, hide the [a]lways option (used when
|
||||
tirith warnings are present, since broad permanent allowlisting
|
||||
is inappropriate for content-level security findings).
|
||||
approval_callback: Optional callback registered by the CLI for
|
||||
prompt_toolkit integration. Signature:
|
||||
(command, description, *, allow_permanent=True) -> str.
|
||||
|
||||
Returns: 'once', 'session', 'always', or 'deny'
|
||||
"""
|
||||
if approval_callback is not None:
|
||||
try:
|
||||
return approval_callback(command, description,
|
||||
allow_permanent=allow_permanent)
|
||||
except Exception:
|
||||
return "deny"
|
||||
|
||||
os.environ["HERMES_SPINNER_PAUSE"] = "1"
|
||||
try:
|
||||
while True:
|
||||
print()
|
||||
print(f" ⚠️ DANGEROUS COMMAND: {description}")
|
||||
print(f" {command}")
|
||||
print()
|
||||
if allow_permanent:
|
||||
print(" [o]nce | [s]ession | [a]lways | [d]eny")
|
||||
else:
|
||||
print(" [o]nce | [s]ession | [d]eny")
|
||||
print()
|
||||
sys.stdout.flush()
|
||||
|
||||
result = {"choice": ""}
|
||||
|
||||
def get_input():
|
||||
try:
|
||||
prompt = " Choice [o/s/a/D]: " if allow_permanent else " Choice [o/s/D]: "
|
||||
result["choice"] = input(prompt).strip().lower()
|
||||
except (EOFError, OSError):
|
||||
result["choice"] = ""
|
||||
|
||||
thread = threading.Thread(target=get_input, daemon=True)
|
||||
thread.start()
|
||||
thread.join(timeout=timeout_seconds)
|
||||
|
||||
if thread.is_alive():
|
||||
print("\n ⏱ Timeout - denying command")
|
||||
return "deny"
|
||||
|
||||
choice = result["choice"]
|
||||
if choice in ('o', 'once'):
|
||||
print(" ✓ Allowed once")
|
||||
return "once"
|
||||
elif choice in ('s', 'session'):
|
||||
print(" ✓ Allowed for this session")
|
||||
return "session"
|
||||
elif choice in ('a', 'always'):
|
||||
if not allow_permanent:
|
||||
print(" ✓ Allowed for this session")
|
||||
return "session"
|
||||
print(" ✓ Added to permanent allowlist")
|
||||
return "always"
|
||||
else:
|
||||
print(" ✗ Denied")
|
||||
return "deny"
|
||||
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print("\n ✗ Cancelled")
|
||||
return "deny"
|
||||
finally:
|
||||
if "HERMES_SPINNER_PAUSE" in os.environ:
|
||||
del os.environ["HERMES_SPINNER_PAUSE"]
|
||||
print()
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def _normalize_approval_mode(mode) -> str:
|
||||
"""Normalize approval mode values loaded from YAML/config.
|
||||
|
||||
YAML 1.1 treats bare words like `off` as booleans, so a config entry like
|
||||
`approvals:\n mode: off` is parsed as False unless quoted. Treat that as the
|
||||
intended string mode instead of falling back to manual approvals.
|
||||
"""
|
||||
if isinstance(mode, bool):
|
||||
return "off" if mode is False else "manual"
|
||||
if isinstance(mode, str):
|
||||
normalized = mode.strip().lower()
|
||||
return normalized or "manual"
|
||||
return "manual"
|
||||
|
||||
|
||||
def _get_approval_mode() -> str:
|
||||
"""Read the approval mode from config. Returns 'manual', 'smart', or 'off'."""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
config = load_config()
|
||||
mode = config.get("approvals", {}).get("mode", "manual")
|
||||
return _normalize_approval_mode(mode)
|
||||
except Exception:
|
||||
return "manual"
|
||||
|
||||
|
||||
def _smart_approve(command: str, description: str) -> str:
|
||||
"""Use the auxiliary LLM to assess risk and decide approval.
|
||||
|
||||
Returns 'approve' if the LLM determines the command is safe,
|
||||
'deny' if genuinely dangerous, or 'escalate' if uncertain.
|
||||
|
||||
Inspired by OpenAI Codex's Smart Approvals guardian subagent
|
||||
(openai/codex#13860).
|
||||
"""
|
||||
try:
|
||||
from agent.auxiliary_client import get_text_auxiliary_client, auxiliary_max_tokens_param
|
||||
|
||||
client, model = get_text_auxiliary_client(task="approval")
|
||||
if not client or not model:
|
||||
logger.debug("Smart approvals: no aux client available, escalating")
|
||||
return "escalate"
|
||||
|
||||
prompt = f"""You are a security reviewer for an AI coding agent. A terminal command was flagged by pattern matching as potentially dangerous.
|
||||
|
||||
Command: {command}
|
||||
Flagged reason: {description}
|
||||
|
||||
Assess the ACTUAL risk of this command. Many flagged commands are false positives — for example, `python -c "print('hello')"` is flagged as "script execution via -c flag" but is completely harmless.
|
||||
|
||||
Rules:
|
||||
- APPROVE if the command is clearly safe (benign script execution, safe file operations, development tools, package installs, git operations, etc.)
|
||||
- DENY if the command could genuinely damage the system (recursive delete of important paths, overwriting system files, fork bombs, wiping disks, dropping databases, etc.)
|
||||
- ESCALATE if you're uncertain
|
||||
|
||||
Respond with exactly one word: APPROVE, DENY, or ESCALATE"""
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
**auxiliary_max_tokens_param(16),
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
answer = (response.choices[0].message.content or "").strip().upper()
|
||||
|
||||
if "APPROVE" in answer:
|
||||
return "approve"
|
||||
elif "DENY" in answer:
|
||||
return "deny"
|
||||
else:
|
||||
return "escalate"
|
||||
|
||||
except Exception as e:
|
||||
logger.debug("Smart approvals: LLM call failed (%s), escalating", e)
|
||||
return "escalate"
|
||||
|
||||
|
||||
def check_dangerous_command(command: str, env_type: str,
|
||||
approval_callback=None) -> dict:
|
||||
"""Check if a command is dangerous and handle approval.
|
||||
|
||||
This is the main entry point called by terminal_tool before executing
|
||||
any command. It orchestrates detection, session checks, and prompting.
|
||||
|
||||
Args:
|
||||
command: The shell command to check.
|
||||
env_type: Terminal backend type ('local', 'ssh', 'docker', etc.).
|
||||
approval_callback: Optional CLI callback for interactive prompts.
|
||||
|
||||
Returns:
|
||||
{"approved": True/False, "message": str or None, ...}
|
||||
"""
|
||||
if env_type in ("docker", "singularity", "modal", "daytona"):
|
||||
return {"approved": True, "message": None}
|
||||
|
||||
# --yolo: bypass all approval prompts
|
||||
if os.getenv("HERMES_YOLO_MODE"):
|
||||
return {"approved": True, "message": None}
|
||||
|
||||
is_dangerous, pattern_key, description = detect_dangerous_command(command)
|
||||
if not is_dangerous:
|
||||
return {"approved": True, "message": None}
|
||||
|
||||
session_key = os.getenv("HERMES_SESSION_KEY", "default")
|
||||
if is_approved(session_key, pattern_key):
|
||||
return {"approved": True, "message": None}
|
||||
|
||||
is_cli = os.getenv("HERMES_INTERACTIVE")
|
||||
is_gateway = os.getenv("HERMES_GATEWAY_SESSION")
|
||||
|
||||
if not is_cli and not is_gateway:
|
||||
return {"approved": True, "message": None}
|
||||
|
||||
if is_gateway or os.getenv("HERMES_EXEC_ASK"):
|
||||
submit_pending(session_key, {
|
||||
"command": command,
|
||||
"pattern_key": pattern_key,
|
||||
"description": description,
|
||||
})
|
||||
return {
|
||||
"approved": False,
|
||||
"pattern_key": pattern_key,
|
||||
"status": "approval_required",
|
||||
"command": command,
|
||||
"description": description,
|
||||
"message": (
|
||||
f"⚠️ This command is potentially dangerous ({description}). "
|
||||
f"Asking the user for approval.\n\n**Command:**\n```\n{command}\n```"
|
||||
),
|
||||
}
|
||||
|
||||
choice = prompt_dangerous_approval(command, description,
|
||||
approval_callback=approval_callback)
|
||||
|
||||
if choice == "deny":
|
||||
return {
|
||||
"approved": False,
|
||||
"message": f"BLOCKED: User denied this potentially dangerous command (matched '{description}' pattern). Do NOT retry this command - the user has explicitly rejected it.",
|
||||
"pattern_key": pattern_key,
|
||||
"description": description,
|
||||
}
|
||||
|
||||
if choice == "session":
|
||||
approve_session(session_key, pattern_key)
|
||||
elif choice == "always":
|
||||
approve_session(session_key, pattern_key)
|
||||
approve_permanent(pattern_key)
|
||||
save_permanent_allowlist(_permanent_approved)
|
||||
|
||||
return {"approved": True, "message": None}
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Combined pre-exec guard (tirith + dangerous command detection)
|
||||
# =========================================================================
|
||||
|
||||
def check_all_command_guards(command: str, env_type: str,
|
||||
approval_callback=None) -> dict:
|
||||
"""Run all pre-exec security checks and return a single approval decision.
|
||||
|
||||
Gathers findings from tirith and dangerous-command detection, then
|
||||
presents them as a single combined approval request. This prevents
|
||||
a gateway force=True replay from bypassing one check when only the
|
||||
other was shown to the user.
|
||||
"""
|
||||
# Skip containers for both checks
|
||||
if env_type in ("docker", "singularity", "modal", "daytona"):
|
||||
return {"approved": True, "message": None}
|
||||
|
||||
# --yolo or approvals.mode=off: bypass all approval prompts
|
||||
approval_mode = _get_approval_mode()
|
||||
if os.getenv("HERMES_YOLO_MODE") or approval_mode == "off":
|
||||
return {"approved": True, "message": None}
|
||||
|
||||
is_cli = os.getenv("HERMES_INTERACTIVE")
|
||||
is_gateway = os.getenv("HERMES_GATEWAY_SESSION")
|
||||
is_ask = os.getenv("HERMES_EXEC_ASK")
|
||||
|
||||
# Preserve the existing non-interactive behavior: outside CLI/gateway/ask
|
||||
# flows, we do not block on approvals and we skip external guard work.
|
||||
if not is_cli and not is_gateway and not is_ask:
|
||||
return {"approved": True, "message": None}
|
||||
|
||||
# --- Phase 1: Gather findings from both checks ---
|
||||
|
||||
# Tirith check — wrapper guarantees no raise for expected failures.
|
||||
# Only catch ImportError (module not installed).
|
||||
tirith_result = {"action": "allow", "findings": [], "summary": ""}
|
||||
try:
|
||||
from tools.tirith_security import check_command_security
|
||||
tirith_result = check_command_security(command)
|
||||
except ImportError:
|
||||
pass # tirith module not installed — allow
|
||||
|
||||
# Dangerous command check (detection only, no approval)
|
||||
is_dangerous, pattern_key, description = detect_dangerous_command(command)
|
||||
|
||||
# --- Phase 2: Decide ---
|
||||
|
||||
# If tirith blocks, block immediately (no approval possible)
|
||||
if tirith_result["action"] == "block":
|
||||
summary = tirith_result.get("summary") or "security issue detected"
|
||||
return {
|
||||
"approved": False,
|
||||
"message": f"BLOCKED: Command blocked by security scan ({summary}). Do NOT retry.",
|
||||
}
|
||||
|
||||
# Collect warnings that need approval
|
||||
warnings = [] # list of (pattern_key, description, is_tirith)
|
||||
|
||||
session_key = os.getenv("HERMES_SESSION_KEY", "default")
|
||||
|
||||
if tirith_result["action"] == "warn":
|
||||
findings = tirith_result.get("findings") or []
|
||||
rule_id = findings[0].get("rule_id", "unknown") if findings else "unknown"
|
||||
tirith_key = f"tirith:{rule_id}"
|
||||
tirith_desc = f"Security scan: {tirith_result.get('summary') or 'security warning detected'}"
|
||||
if not is_approved(session_key, tirith_key):
|
||||
warnings.append((tirith_key, tirith_desc, True))
|
||||
|
||||
if is_dangerous:
|
||||
if not is_approved(session_key, pattern_key):
|
||||
warnings.append((pattern_key, description, False))
|
||||
|
||||
# Nothing to warn about
|
||||
if not warnings:
|
||||
return {"approved": True, "message": None}
|
||||
|
||||
# --- Phase 2.5: Smart approval (auxiliary LLM risk assessment) ---
|
||||
# When approvals.mode=smart, ask the aux LLM before prompting the user.
|
||||
# Inspired by OpenAI Codex's Smart Approvals guardian subagent
|
||||
# (openai/codex#13860).
|
||||
if approval_mode == "smart":
|
||||
combined_desc_for_llm = "; ".join(desc for _, desc, _ in warnings)
|
||||
verdict = _smart_approve(command, combined_desc_for_llm)
|
||||
if verdict == "approve":
|
||||
# Auto-approve and grant session-level approval for these patterns
|
||||
for key, _, _ in warnings:
|
||||
approve_session(session_key, key)
|
||||
logger.debug("Smart approval: auto-approved '%s' (%s)",
|
||||
command[:60], combined_desc_for_llm)
|
||||
return {"approved": True, "message": None,
|
||||
"smart_approved": True}
|
||||
elif verdict == "deny":
|
||||
combined_desc_for_llm = "; ".join(desc for _, desc, _ in warnings)
|
||||
return {
|
||||
"approved": False,
|
||||
"message": f"BLOCKED by smart approval: {combined_desc_for_llm}. "
|
||||
"The command was assessed as genuinely dangerous. Do NOT retry.",
|
||||
"smart_denied": True,
|
||||
}
|
||||
# verdict == "escalate" → fall through to manual prompt
|
||||
|
||||
# --- Phase 3: Approval ---
|
||||
|
||||
# Combine descriptions for a single approval prompt
|
||||
combined_desc = "; ".join(desc for _, desc, _ in warnings)
|
||||
primary_key = warnings[0][0]
|
||||
all_keys = [key for key, _, _ in warnings]
|
||||
has_tirith = any(is_t for _, _, is_t in warnings)
|
||||
|
||||
# Gateway/async: single approval_required with combined description
|
||||
# Store all pattern keys so gateway replay approves all of them
|
||||
if is_gateway or is_ask:
|
||||
submit_pending(session_key, {
|
||||
"command": command,
|
||||
"pattern_key": primary_key, # backward compat
|
||||
"pattern_keys": all_keys, # all keys for replay
|
||||
"description": combined_desc,
|
||||
})
|
||||
return {
|
||||
"approved": False,
|
||||
"pattern_key": primary_key,
|
||||
"status": "approval_required",
|
||||
"command": command,
|
||||
"description": combined_desc,
|
||||
"message": (
|
||||
f"⚠️ {combined_desc}. Asking the user for approval.\n\n**Command:**\n```\n{command}\n```"
|
||||
),
|
||||
}
|
||||
|
||||
# CLI interactive: single combined prompt
|
||||
# Hide [a]lways when any tirith warning is present
|
||||
choice = prompt_dangerous_approval(command, combined_desc,
|
||||
allow_permanent=not has_tirith,
|
||||
approval_callback=approval_callback)
|
||||
|
||||
if choice == "deny":
|
||||
return {
|
||||
"approved": False,
|
||||
"message": "BLOCKED: User denied. Do NOT retry.",
|
||||
"pattern_key": primary_key,
|
||||
"description": combined_desc,
|
||||
}
|
||||
|
||||
# Persist approval for each warning individually
|
||||
for key, _, is_tirith in warnings:
|
||||
if choice == "session" or (choice == "always" and is_tirith):
|
||||
# tirith: session only (no permanent broad allowlisting)
|
||||
approve_session(session_key, key)
|
||||
elif choice == "always":
|
||||
# dangerous patterns: permanent allowed
|
||||
approve_session(session_key, key)
|
||||
approve_permanent(key)
|
||||
save_permanent_allowlist(_permanent_approved)
|
||||
|
||||
return {"approved": True, "message": None}
|
||||
10
hermes_code/tools/browser_providers/__init__.py
Normal file
10
hermes_code/tools/browser_providers/__init__.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
"""Cloud browser provider abstraction.
|
||||
|
||||
Import the ABC so callers can do::
|
||||
|
||||
from tools.browser_providers import CloudBrowserProvider
|
||||
"""
|
||||
|
||||
from tools.browser_providers.base import CloudBrowserProvider
|
||||
|
||||
__all__ = ["CloudBrowserProvider"]
|
||||
59
hermes_code/tools/browser_providers/base.py
Normal file
59
hermes_code/tools/browser_providers/base.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
"""Abstract base class for cloud browser providers."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class CloudBrowserProvider(ABC):
|
||||
"""Interface for cloud browser backends (Browserbase, Steel, etc.).
|
||||
|
||||
Implementations live in sibling modules and are registered in
|
||||
``browser_tool._PROVIDER_REGISTRY``. The user selects a provider via
|
||||
``hermes setup`` / ``hermes tools``; the choice is persisted as
|
||||
``config["browser"]["cloud_provider"]``.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def provider_name(self) -> str:
|
||||
"""Short, human-readable name shown in logs and diagnostics."""
|
||||
|
||||
@abstractmethod
|
||||
def is_configured(self) -> bool:
|
||||
"""Return True when all required env vars / credentials are present.
|
||||
|
||||
Called at tool-registration time (``check_browser_requirements``) to
|
||||
gate availability. Must be cheap — no network calls.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def create_session(self, task_id: str) -> Dict[str, object]:
|
||||
"""Create a cloud browser session and return session metadata.
|
||||
|
||||
Must return a dict with at least::
|
||||
|
||||
{
|
||||
"session_name": str, # unique name for agent-browser --session
|
||||
"bb_session_id": str, # provider session ID (for close/cleanup)
|
||||
"cdp_url": str, # CDP websocket URL
|
||||
"features": dict, # feature flags that were enabled
|
||||
}
|
||||
|
||||
``bb_session_id`` is a legacy key name kept for backward compat with
|
||||
the rest of browser_tool.py — it holds the provider's session ID
|
||||
regardless of which provider is in use.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def close_session(self, session_id: str) -> bool:
|
||||
"""Release / terminate a cloud session by its provider session ID.
|
||||
|
||||
Returns True on success, False on failure. Should not raise.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def emergency_cleanup(self, session_id: str) -> None:
|
||||
"""Best-effort session teardown during process exit.
|
||||
|
||||
Called from atexit / signal handlers. Must tolerate missing
|
||||
credentials, network errors, etc. — log and move on.
|
||||
"""
|
||||
107
hermes_code/tools/browser_providers/browser_use.py
Normal file
107
hermes_code/tools/browser_providers/browser_use.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
"""Browser Use cloud browser provider."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Dict
|
||||
|
||||
import requests
|
||||
|
||||
from tools.browser_providers.base import CloudBrowserProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BASE_URL = "https://api.browser-use.com/api/v2"
|
||||
|
||||
|
||||
class BrowserUseProvider(CloudBrowserProvider):
|
||||
"""Browser Use (https://browser-use.com) cloud browser backend."""
|
||||
|
||||
def provider_name(self) -> str:
|
||||
return "Browser Use"
|
||||
|
||||
def is_configured(self) -> bool:
|
||||
return bool(os.environ.get("BROWSER_USE_API_KEY"))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Session lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _headers(self) -> Dict[str, str]:
|
||||
api_key = os.environ.get("BROWSER_USE_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"BROWSER_USE_API_KEY environment variable is required. "
|
||||
"Get your key at https://browser-use.com"
|
||||
)
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"X-Browser-Use-API-Key": api_key,
|
||||
}
|
||||
|
||||
def create_session(self, task_id: str) -> Dict[str, object]:
|
||||
response = requests.post(
|
||||
f"{_BASE_URL}/browsers",
|
||||
headers=self._headers(),
|
||||
json={},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
raise RuntimeError(
|
||||
f"Failed to create Browser Use session: "
|
||||
f"{response.status_code} {response.text}"
|
||||
)
|
||||
|
||||
session_data = response.json()
|
||||
session_name = f"hermes_{task_id}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
logger.info("Created Browser Use session %s", session_name)
|
||||
|
||||
return {
|
||||
"session_name": session_name,
|
||||
"bb_session_id": session_data["id"],
|
||||
"cdp_url": session_data["cdpUrl"],
|
||||
"features": {"browser_use": True},
|
||||
}
|
||||
|
||||
def close_session(self, session_id: str) -> bool:
|
||||
try:
|
||||
response = requests.patch(
|
||||
f"{_BASE_URL}/browsers/{session_id}",
|
||||
headers=self._headers(),
|
||||
json={"action": "stop"},
|
||||
timeout=10,
|
||||
)
|
||||
if response.status_code in (200, 201, 204):
|
||||
logger.debug("Successfully closed Browser Use session %s", session_id)
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
"Failed to close Browser Use session %s: HTTP %s - %s",
|
||||
session_id,
|
||||
response.status_code,
|
||||
response.text[:200],
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error("Exception closing Browser Use session %s: %s", session_id, e)
|
||||
return False
|
||||
|
||||
def emergency_cleanup(self, session_id: str) -> None:
|
||||
api_key = os.environ.get("BROWSER_USE_API_KEY")
|
||||
if not api_key:
|
||||
logger.warning("Cannot emergency-cleanup Browser Use session %s — missing credentials", session_id)
|
||||
return
|
||||
try:
|
||||
requests.patch(
|
||||
f"{_BASE_URL}/browsers/{session_id}",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"X-Browser-Use-API-Key": api_key,
|
||||
},
|
||||
json={"action": "stop"},
|
||||
timeout=5,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Emergency cleanup failed for Browser Use session %s: %s", session_id, e)
|
||||
206
hermes_code/tools/browser_providers/browserbase.py
Normal file
206
hermes_code/tools/browser_providers/browserbase.py
Normal file
|
|
@ -0,0 +1,206 @@
|
|||
"""Browserbase cloud browser provider."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Dict
|
||||
|
||||
import requests
|
||||
|
||||
from tools.browser_providers.base import CloudBrowserProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BrowserbaseProvider(CloudBrowserProvider):
|
||||
"""Browserbase (https://browserbase.com) cloud browser backend."""
|
||||
|
||||
def provider_name(self) -> str:
|
||||
return "Browserbase"
|
||||
|
||||
def is_configured(self) -> bool:
|
||||
return bool(
|
||||
os.environ.get("BROWSERBASE_API_KEY")
|
||||
and os.environ.get("BROWSERBASE_PROJECT_ID")
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Session lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_config(self) -> Dict[str, str]:
|
||||
api_key = os.environ.get("BROWSERBASE_API_KEY")
|
||||
project_id = os.environ.get("BROWSERBASE_PROJECT_ID")
|
||||
if not api_key or not project_id:
|
||||
raise ValueError(
|
||||
"BROWSERBASE_API_KEY and BROWSERBASE_PROJECT_ID environment "
|
||||
"variables are required. Get your credentials at "
|
||||
"https://browserbase.com"
|
||||
)
|
||||
return {"api_key": api_key, "project_id": project_id}
|
||||
|
||||
def create_session(self, task_id: str) -> Dict[str, object]:
|
||||
config = self._get_config()
|
||||
|
||||
# Optional env-var knobs
|
||||
enable_proxies = os.environ.get("BROWSERBASE_PROXIES", "true").lower() != "false"
|
||||
enable_advanced_stealth = os.environ.get("BROWSERBASE_ADVANCED_STEALTH", "false").lower() == "true"
|
||||
enable_keep_alive = os.environ.get("BROWSERBASE_KEEP_ALIVE", "true").lower() != "false"
|
||||
custom_timeout_ms = os.environ.get("BROWSERBASE_SESSION_TIMEOUT")
|
||||
|
||||
features_enabled = {
|
||||
"basic_stealth": True,
|
||||
"proxies": False,
|
||||
"advanced_stealth": False,
|
||||
"keep_alive": False,
|
||||
"custom_timeout": False,
|
||||
}
|
||||
|
||||
session_config: Dict[str, object] = {"projectId": config["project_id"]}
|
||||
|
||||
if enable_keep_alive:
|
||||
session_config["keepAlive"] = True
|
||||
|
||||
if custom_timeout_ms:
|
||||
try:
|
||||
timeout_val = int(custom_timeout_ms)
|
||||
if timeout_val > 0:
|
||||
session_config["timeout"] = timeout_val
|
||||
except ValueError:
|
||||
logger.warning("Invalid BROWSERBASE_SESSION_TIMEOUT value: %s", custom_timeout_ms)
|
||||
|
||||
if enable_proxies:
|
||||
session_config["proxies"] = True
|
||||
|
||||
if enable_advanced_stealth:
|
||||
session_config["browserSettings"] = {"advancedStealth": True}
|
||||
|
||||
# --- Create session via API ---
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-BB-API-Key": config["api_key"],
|
||||
}
|
||||
response = requests.post(
|
||||
"https://api.browserbase.com/v1/sessions",
|
||||
headers=headers,
|
||||
json=session_config,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
proxies_fallback = False
|
||||
keepalive_fallback = False
|
||||
|
||||
# Handle 402 — paid features unavailable
|
||||
if response.status_code == 402:
|
||||
if enable_keep_alive:
|
||||
keepalive_fallback = True
|
||||
logger.warning(
|
||||
"keepAlive may require paid plan (402), retrying without it. "
|
||||
"Sessions may timeout during long operations."
|
||||
)
|
||||
session_config.pop("keepAlive", None)
|
||||
response = requests.post(
|
||||
"https://api.browserbase.com/v1/sessions",
|
||||
headers=headers,
|
||||
json=session_config,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
if response.status_code == 402 and enable_proxies:
|
||||
proxies_fallback = True
|
||||
logger.warning(
|
||||
"Proxies unavailable (402), retrying without proxies. "
|
||||
"Bot detection may be less effective."
|
||||
)
|
||||
session_config.pop("proxies", None)
|
||||
response = requests.post(
|
||||
"https://api.browserbase.com/v1/sessions",
|
||||
headers=headers,
|
||||
json=session_config,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
raise RuntimeError(
|
||||
f"Failed to create Browserbase session: "
|
||||
f"{response.status_code} {response.text}"
|
||||
)
|
||||
|
||||
session_data = response.json()
|
||||
session_name = f"hermes_{task_id}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
if enable_proxies and not proxies_fallback:
|
||||
features_enabled["proxies"] = True
|
||||
if enable_advanced_stealth:
|
||||
features_enabled["advanced_stealth"] = True
|
||||
if enable_keep_alive and not keepalive_fallback:
|
||||
features_enabled["keep_alive"] = True
|
||||
if custom_timeout_ms and "timeout" in session_config:
|
||||
features_enabled["custom_timeout"] = True
|
||||
|
||||
feature_str = ", ".join(k for k, v in features_enabled.items() if v)
|
||||
logger.info("Created Browserbase session %s with features: %s", session_name, feature_str)
|
||||
|
||||
return {
|
||||
"session_name": session_name,
|
||||
"bb_session_id": session_data["id"],
|
||||
"cdp_url": session_data["connectUrl"],
|
||||
"features": features_enabled,
|
||||
}
|
||||
|
||||
def close_session(self, session_id: str) -> bool:
|
||||
try:
|
||||
config = self._get_config()
|
||||
except ValueError:
|
||||
logger.warning("Cannot close Browserbase session %s — missing credentials", session_id)
|
||||
return False
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"https://api.browserbase.com/v1/sessions/{session_id}",
|
||||
headers={
|
||||
"X-BB-API-Key": config["api_key"],
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"projectId": config["project_id"],
|
||||
"status": "REQUEST_RELEASE",
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
if response.status_code in (200, 201, 204):
|
||||
logger.debug("Successfully closed Browserbase session %s", session_id)
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
"Failed to close session %s: HTTP %s - %s",
|
||||
session_id,
|
||||
response.status_code,
|
||||
response.text[:200],
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error("Exception closing Browserbase session %s: %s", session_id, e)
|
||||
return False
|
||||
|
||||
def emergency_cleanup(self, session_id: str) -> None:
|
||||
api_key = os.environ.get("BROWSERBASE_API_KEY")
|
||||
project_id = os.environ.get("BROWSERBASE_PROJECT_ID")
|
||||
if not api_key or not project_id:
|
||||
logger.warning("Cannot emergency-cleanup Browserbase session %s — missing credentials", session_id)
|
||||
return
|
||||
try:
|
||||
requests.post(
|
||||
f"https://api.browserbase.com/v1/sessions/{session_id}",
|
||||
headers={
|
||||
"X-BB-API-Key": api_key,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"projectId": project_id,
|
||||
"status": "REQUEST_RELEASE",
|
||||
},
|
||||
timeout=5,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Emergency cleanup failed for Browserbase session %s: %s", session_id, e)
|
||||
1923
hermes_code/tools/browser_tool.py
Normal file
1923
hermes_code/tools/browser_tool.py
Normal file
File diff suppressed because it is too large
Load diff
86
hermes_code/tools/browser_use_tool.py
Normal file
86
hermes_code/tools/browser_use_tool.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
import json
|
||||
import os
|
||||
import asyncio
|
||||
import socket
|
||||
from browser_use import Agent, Browser, ChatOpenAI
|
||||
from tools.registry import registry
|
||||
|
||||
|
||||
async def run_browser_task(task):
|
||||
browser_host = "browser"
|
||||
browser_port = 9222
|
||||
BROWSER_VIEW_URL = os.getenv("BROWSER_VIEW_URL", "")
|
||||
|
||||
try:
|
||||
browser_ip = socket.gethostbyname(browser_host)
|
||||
cdp_url = f"http://{browser_ip}:{browser_port}"
|
||||
except Exception:
|
||||
cdp_url = f"http://{browser_host}:{browser_port}"
|
||||
|
||||
browser = Browser(cdp_url=cdp_url)
|
||||
|
||||
llm = ChatOpenAI(
|
||||
model=os.getenv("MODEL_DEFAULT", "qwen3.5-122b"),
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
base_url=os.getenv("OPENAI_BASE_URL"),
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
task=task,
|
||||
llm=llm,
|
||||
browser=browser,
|
||||
use_vision=False
|
||||
)
|
||||
|
||||
try:
|
||||
history = await agent.run()
|
||||
final_result = history.final_result()
|
||||
|
||||
response = {
|
||||
"success": True,
|
||||
"result": final_result,
|
||||
"browser_view": BROWSER_VIEW_URL
|
||||
}
|
||||
return json.dumps(response, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": f"Browser automation failed: {str(e)}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
finally:
|
||||
if browser:
|
||||
try:
|
||||
await browser.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
registry.register(
|
||||
name="internet_browser",
|
||||
toolset="browse_cmd",
|
||||
schema={
|
||||
"name": "internet_browser",
|
||||
"description": (
|
||||
"ГЛАВНЫЙ ИНСТРУМЕНТ ДЛЯ ВЕБ-СЕРФИНГА. Вызывай этот инструмент НАПРЯМУЮ (через стандартный tool call/function call). "
|
||||
"КАТЕГОРИЧЕСКИ ЗАПРЕЩЕНО использовать `execute_code` или `delegate_task` для работы с браузером. "
|
||||
"Не пиши Python-скрипты! Просто передай в этот инструмент параметр `task`. "
|
||||
"Используй для любых задач в интернете: поиск товаров (Wildberries, Ozon), чтение статей, клики, навигация."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task": {
|
||||
"type": "string",
|
||||
"description": "Подробная задача на естественном языке. Например: 'Зайди на wildberries.ru, найди черную футболку и верни цену'."
|
||||
}
|
||||
},
|
||||
"required": ["task"]
|
||||
}
|
||||
},
|
||||
|
||||
handler=lambda args, **kw: asyncio.run(run_browser_task(args.get("task"))),
|
||||
emoji="🌐",
|
||||
)
|
||||
548
hermes_code/tools/checkpoint_manager.py
Normal file
548
hermes_code/tools/checkpoint_manager.py
Normal file
|
|
@ -0,0 +1,548 @@
|
|||
"""
|
||||
Checkpoint Manager — Transparent filesystem snapshots via shadow git repos.
|
||||
|
||||
Creates automatic snapshots of working directories before file-mutating
|
||||
operations (write_file, patch), triggered once per conversation turn.
|
||||
Provides rollback to any previous checkpoint.
|
||||
|
||||
This is NOT a tool — the LLM never sees it. It's transparent infrastructure
|
||||
controlled by the ``checkpoints`` config flag or ``--checkpoints`` CLI flag.
|
||||
|
||||
Architecture:
|
||||
~/.hermes/checkpoints/{sha256(abs_dir)[:16]}/ — shadow git repo
|
||||
HEAD, refs/, objects/ — standard git internals
|
||||
HERMES_WORKDIR — original dir path
|
||||
info/exclude — default excludes
|
||||
|
||||
The shadow repo uses GIT_DIR + GIT_WORK_TREE so no git state leaks
|
||||
into the user's project directory.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CHECKPOINT_BASE = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "checkpoints"
|
||||
|
||||
DEFAULT_EXCLUDES = [
|
||||
"node_modules/",
|
||||
"dist/",
|
||||
"build/",
|
||||
".env",
|
||||
".env.*",
|
||||
".env.local",
|
||||
".env.*.local",
|
||||
"__pycache__/",
|
||||
"*.pyc",
|
||||
"*.pyo",
|
||||
".DS_Store",
|
||||
"*.log",
|
||||
".cache/",
|
||||
".next/",
|
||||
".nuxt/",
|
||||
"coverage/",
|
||||
".pytest_cache/",
|
||||
".venv/",
|
||||
"venv/",
|
||||
".git/",
|
||||
]
|
||||
|
||||
# Git subprocess timeout (seconds).
|
||||
_GIT_TIMEOUT: int = max(10, min(60, int(os.getenv("HERMES_CHECKPOINT_TIMEOUT", "30"))))
|
||||
|
||||
# Max files to snapshot — skip huge directories to avoid slowdowns.
|
||||
_MAX_FILES = 50_000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shadow repo helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _shadow_repo_path(working_dir: str) -> Path:
|
||||
"""Deterministic shadow repo path: sha256(abs_path)[:16]."""
|
||||
abs_path = str(Path(working_dir).resolve())
|
||||
dir_hash = hashlib.sha256(abs_path.encode()).hexdigest()[:16]
|
||||
return CHECKPOINT_BASE / dir_hash
|
||||
|
||||
|
||||
def _git_env(shadow_repo: Path, working_dir: str) -> dict:
|
||||
"""Build env dict that redirects git to the shadow repo."""
|
||||
env = os.environ.copy()
|
||||
env["GIT_DIR"] = str(shadow_repo)
|
||||
env["GIT_WORK_TREE"] = str(Path(working_dir).resolve())
|
||||
env.pop("GIT_INDEX_FILE", None)
|
||||
env.pop("GIT_NAMESPACE", None)
|
||||
env.pop("GIT_ALTERNATE_OBJECT_DIRECTORIES", None)
|
||||
return env
|
||||
|
||||
|
||||
def _run_git(
|
||||
args: List[str],
|
||||
shadow_repo: Path,
|
||||
working_dir: str,
|
||||
timeout: int = _GIT_TIMEOUT,
|
||||
allowed_returncodes: Optional[Set[int]] = None,
|
||||
) -> tuple:
|
||||
"""Run a git command against the shadow repo. Returns (ok, stdout, stderr).
|
||||
|
||||
``allowed_returncodes`` suppresses error logging for known/expected non-zero
|
||||
exits while preserving the normal ``ok = (returncode == 0)`` contract.
|
||||
Example: ``git diff --cached --quiet`` returns 1 when changes exist.
|
||||
"""
|
||||
env = _git_env(shadow_repo, working_dir)
|
||||
cmd = ["git"] + list(args)
|
||||
allowed_returncodes = allowed_returncodes or set()
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
env=env,
|
||||
cwd=str(Path(working_dir).resolve()),
|
||||
)
|
||||
ok = result.returncode == 0
|
||||
stdout = result.stdout.strip()
|
||||
stderr = result.stderr.strip()
|
||||
if not ok and result.returncode not in allowed_returncodes:
|
||||
logger.error(
|
||||
"Git command failed: %s (rc=%d) stderr=%s",
|
||||
" ".join(cmd), result.returncode, stderr,
|
||||
)
|
||||
return ok, stdout, stderr
|
||||
except subprocess.TimeoutExpired:
|
||||
msg = f"git timed out after {timeout}s: {' '.join(cmd)}"
|
||||
logger.error(msg, exc_info=True)
|
||||
return False, "", msg
|
||||
except FileNotFoundError:
|
||||
logger.error("Git executable not found: %s", " ".join(cmd), exc_info=True)
|
||||
return False, "", "git not found"
|
||||
except Exception as exc:
|
||||
logger.error("Unexpected git error running %s: %s", " ".join(cmd), exc, exc_info=True)
|
||||
return False, "", str(exc)
|
||||
|
||||
|
||||
def _init_shadow_repo(shadow_repo: Path, working_dir: str) -> Optional[str]:
|
||||
"""Initialise shadow repo if needed. Returns error string or None."""
|
||||
if (shadow_repo / "HEAD").exists():
|
||||
return None
|
||||
|
||||
shadow_repo.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ok, _, err = _run_git(["init"], shadow_repo, working_dir)
|
||||
if not ok:
|
||||
return f"Shadow repo init failed: {err}"
|
||||
|
||||
_run_git(["config", "user.email", "hermes@local"], shadow_repo, working_dir)
|
||||
_run_git(["config", "user.name", "Hermes Checkpoint"], shadow_repo, working_dir)
|
||||
|
||||
info_dir = shadow_repo / "info"
|
||||
info_dir.mkdir(exist_ok=True)
|
||||
(info_dir / "exclude").write_text(
|
||||
"\n".join(DEFAULT_EXCLUDES) + "\n", encoding="utf-8"
|
||||
)
|
||||
|
||||
(shadow_repo / "HERMES_WORKDIR").write_text(
|
||||
str(Path(working_dir).resolve()) + "\n", encoding="utf-8"
|
||||
)
|
||||
|
||||
logger.debug("Initialised checkpoint repo at %s for %s", shadow_repo, working_dir)
|
||||
return None
|
||||
|
||||
|
||||
def _dir_file_count(path: str) -> int:
|
||||
"""Quick file count estimate (stops early if over _MAX_FILES)."""
|
||||
count = 0
|
||||
try:
|
||||
for _ in Path(path).rglob("*"):
|
||||
count += 1
|
||||
if count > _MAX_FILES:
|
||||
return count
|
||||
except (PermissionError, OSError):
|
||||
pass
|
||||
return count
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CheckpointManager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class CheckpointManager:
|
||||
"""Manages automatic filesystem checkpoints.
|
||||
|
||||
Designed to be owned by AIAgent. Call ``new_turn()`` at the start of
|
||||
each conversation turn and ``ensure_checkpoint(dir, reason)`` before
|
||||
any file-mutating tool call. The manager deduplicates so at most one
|
||||
snapshot is taken per directory per turn.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
enabled : bool
|
||||
Master switch (from config / CLI flag).
|
||||
max_snapshots : int
|
||||
Keep at most this many checkpoints per directory.
|
||||
"""
|
||||
|
||||
def __init__(self, enabled: bool = False, max_snapshots: int = 50):
|
||||
self.enabled = enabled
|
||||
self.max_snapshots = max_snapshots
|
||||
self._checkpointed_dirs: Set[str] = set()
|
||||
self._git_available: Optional[bool] = None # lazy probe
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Turn lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def new_turn(self) -> None:
|
||||
"""Reset per-turn dedup. Call at the start of each agent iteration."""
|
||||
self._checkpointed_dirs.clear()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def ensure_checkpoint(self, working_dir: str, reason: str = "auto") -> bool:
|
||||
"""Take a checkpoint if enabled and not already done this turn.
|
||||
|
||||
Returns True if a checkpoint was taken, False otherwise.
|
||||
Never raises — all errors are silently logged.
|
||||
"""
|
||||
if not self.enabled:
|
||||
return False
|
||||
|
||||
# Lazy git probe
|
||||
if self._git_available is None:
|
||||
self._git_available = shutil.which("git") is not None
|
||||
if not self._git_available:
|
||||
logger.debug("Checkpoints disabled: git not found")
|
||||
if not self._git_available:
|
||||
return False
|
||||
|
||||
abs_dir = str(Path(working_dir).resolve())
|
||||
|
||||
# Skip root, home, and other overly broad directories
|
||||
if abs_dir in ("/", str(Path.home())):
|
||||
logger.debug("Checkpoint skipped: directory too broad (%s)", abs_dir)
|
||||
return False
|
||||
|
||||
# Already checkpointed this turn?
|
||||
if abs_dir in self._checkpointed_dirs:
|
||||
return False
|
||||
|
||||
self._checkpointed_dirs.add(abs_dir)
|
||||
|
||||
try:
|
||||
return self._take(abs_dir, reason)
|
||||
except Exception as e:
|
||||
logger.debug("Checkpoint failed (non-fatal): %s", e)
|
||||
return False
|
||||
|
||||
def list_checkpoints(self, working_dir: str) -> List[Dict]:
|
||||
"""List available checkpoints for a directory.
|
||||
|
||||
Returns a list of dicts with keys: hash, short_hash, timestamp, reason,
|
||||
files_changed, insertions, deletions. Most recent first.
|
||||
"""
|
||||
abs_dir = str(Path(working_dir).resolve())
|
||||
shadow = _shadow_repo_path(abs_dir)
|
||||
|
||||
if not (shadow / "HEAD").exists():
|
||||
return []
|
||||
|
||||
ok, stdout, _ = _run_git(
|
||||
["log", "--format=%H|%h|%aI|%s", "-n", str(self.max_snapshots)],
|
||||
shadow, abs_dir,
|
||||
)
|
||||
|
||||
if not ok or not stdout:
|
||||
return []
|
||||
|
||||
results = []
|
||||
for line in stdout.splitlines():
|
||||
parts = line.split("|", 3)
|
||||
if len(parts) == 4:
|
||||
entry = {
|
||||
"hash": parts[0],
|
||||
"short_hash": parts[1],
|
||||
"timestamp": parts[2],
|
||||
"reason": parts[3],
|
||||
"files_changed": 0,
|
||||
"insertions": 0,
|
||||
"deletions": 0,
|
||||
}
|
||||
# Get diffstat for this commit
|
||||
stat_ok, stat_out, _ = _run_git(
|
||||
["diff", "--shortstat", f"{parts[0]}~1", parts[0]],
|
||||
shadow, abs_dir,
|
||||
allowed_returncodes={128, 129}, # first commit has no parent
|
||||
)
|
||||
if stat_ok and stat_out:
|
||||
self._parse_shortstat(stat_out, entry)
|
||||
results.append(entry)
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _parse_shortstat(stat_line: str, entry: Dict) -> None:
|
||||
"""Parse git --shortstat output into entry dict."""
|
||||
import re
|
||||
m = re.search(r'(\d+) file', stat_line)
|
||||
if m:
|
||||
entry["files_changed"] = int(m.group(1))
|
||||
m = re.search(r'(\d+) insertion', stat_line)
|
||||
if m:
|
||||
entry["insertions"] = int(m.group(1))
|
||||
m = re.search(r'(\d+) deletion', stat_line)
|
||||
if m:
|
||||
entry["deletions"] = int(m.group(1))
|
||||
|
||||
def diff(self, working_dir: str, commit_hash: str) -> Dict:
|
||||
"""Show diff between a checkpoint and the current working tree.
|
||||
|
||||
Returns dict with success, diff text, and stat summary.
|
||||
"""
|
||||
abs_dir = str(Path(working_dir).resolve())
|
||||
shadow = _shadow_repo_path(abs_dir)
|
||||
|
||||
if not (shadow / "HEAD").exists():
|
||||
return {"success": False, "error": "No checkpoints exist for this directory"}
|
||||
|
||||
# Verify the commit exists
|
||||
ok, _, err = _run_git(
|
||||
["cat-file", "-t", commit_hash], shadow, abs_dir,
|
||||
)
|
||||
if not ok:
|
||||
return {"success": False, "error": f"Checkpoint '{commit_hash}' not found"}
|
||||
|
||||
# Stage current state to compare against checkpoint
|
||||
_run_git(["add", "-A"], shadow, abs_dir, timeout=_GIT_TIMEOUT * 2)
|
||||
|
||||
# Get stat summary: checkpoint vs current working tree
|
||||
ok_stat, stat_out, _ = _run_git(
|
||||
["diff", "--stat", commit_hash, "--cached"],
|
||||
shadow, abs_dir,
|
||||
)
|
||||
|
||||
# Get actual diff (limited to avoid terminal flood)
|
||||
ok_diff, diff_out, _ = _run_git(
|
||||
["diff", commit_hash, "--cached", "--no-color"],
|
||||
shadow, abs_dir,
|
||||
)
|
||||
|
||||
# Unstage to avoid polluting the shadow repo index
|
||||
_run_git(["reset", "HEAD", "--quiet"], shadow, abs_dir)
|
||||
|
||||
if not ok_stat and not ok_diff:
|
||||
return {"success": False, "error": "Could not generate diff"}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"stat": stat_out if ok_stat else "",
|
||||
"diff": diff_out if ok_diff else "",
|
||||
}
|
||||
|
||||
def restore(self, working_dir: str, commit_hash: str, file_path: str = None) -> Dict:
|
||||
"""Restore files to a checkpoint state.
|
||||
|
||||
Uses ``git checkout <hash> -- .`` (or a specific file) which restores
|
||||
tracked files without moving HEAD — safe and reversible.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_path : str, optional
|
||||
If provided, restore only this file instead of the entire directory.
|
||||
|
||||
Returns dict with success/error info.
|
||||
"""
|
||||
abs_dir = str(Path(working_dir).resolve())
|
||||
shadow = _shadow_repo_path(abs_dir)
|
||||
|
||||
if not (shadow / "HEAD").exists():
|
||||
return {"success": False, "error": "No checkpoints exist for this directory"}
|
||||
|
||||
# Verify the commit exists
|
||||
ok, _, err = _run_git(
|
||||
["cat-file", "-t", commit_hash], shadow, abs_dir,
|
||||
)
|
||||
if not ok:
|
||||
return {"success": False, "error": f"Checkpoint '{commit_hash}' not found", "debug": err or None}
|
||||
|
||||
# Take a checkpoint of current state before restoring (so you can undo the undo)
|
||||
self._take(abs_dir, f"pre-rollback snapshot (restoring to {commit_hash[:8]})")
|
||||
|
||||
# Restore — full directory or single file
|
||||
restore_target = file_path if file_path else "."
|
||||
ok, stdout, err = _run_git(
|
||||
["checkout", commit_hash, "--", restore_target],
|
||||
shadow, abs_dir, timeout=_GIT_TIMEOUT * 2,
|
||||
)
|
||||
|
||||
if not ok:
|
||||
return {"success": False, "error": f"Restore failed: {err}", "debug": err or None}
|
||||
|
||||
# Get info about what was restored
|
||||
ok2, reason_out, _ = _run_git(
|
||||
["log", "--format=%s", "-1", commit_hash], shadow, abs_dir,
|
||||
)
|
||||
reason = reason_out if ok2 else "unknown"
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"restored_to": commit_hash[:8],
|
||||
"reason": reason,
|
||||
"directory": abs_dir,
|
||||
}
|
||||
if file_path:
|
||||
result["file"] = file_path
|
||||
return result
|
||||
|
||||
def get_working_dir_for_path(self, file_path: str) -> str:
|
||||
"""Resolve a file path to its working directory for checkpointing.
|
||||
|
||||
Walks up from the file's parent to find a reasonable project root
|
||||
(directory containing .git, pyproject.toml, package.json, etc.).
|
||||
Falls back to the file's parent directory.
|
||||
"""
|
||||
path = Path(file_path).resolve()
|
||||
if path.is_dir():
|
||||
candidate = path
|
||||
else:
|
||||
candidate = path.parent
|
||||
|
||||
# Walk up looking for project root markers
|
||||
markers = {".git", "pyproject.toml", "package.json", "Cargo.toml",
|
||||
"go.mod", "Makefile", "pom.xml", ".hg", "Gemfile"}
|
||||
check = candidate
|
||||
while check != check.parent:
|
||||
if any((check / m).exists() for m in markers):
|
||||
return str(check)
|
||||
check = check.parent
|
||||
|
||||
# No project root found — use the file's parent
|
||||
return str(candidate)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _take(self, working_dir: str, reason: str) -> bool:
|
||||
"""Take a snapshot. Returns True on success."""
|
||||
shadow = _shadow_repo_path(working_dir)
|
||||
|
||||
# Init if needed
|
||||
err = _init_shadow_repo(shadow, working_dir)
|
||||
if err:
|
||||
logger.debug("Checkpoint init failed: %s", err)
|
||||
return False
|
||||
|
||||
# Quick size guard — don't try to snapshot enormous directories
|
||||
if _dir_file_count(working_dir) > _MAX_FILES:
|
||||
logger.debug("Checkpoint skipped: >%d files in %s", _MAX_FILES, working_dir)
|
||||
return False
|
||||
|
||||
# Stage everything
|
||||
ok, _, err = _run_git(
|
||||
["add", "-A"], shadow, working_dir, timeout=_GIT_TIMEOUT * 2,
|
||||
)
|
||||
if not ok:
|
||||
logger.debug("Checkpoint git-add failed: %s", err)
|
||||
return False
|
||||
|
||||
# Check if there's anything to commit
|
||||
ok_diff, diff_out, _ = _run_git(
|
||||
["diff", "--cached", "--quiet"],
|
||||
shadow,
|
||||
working_dir,
|
||||
allowed_returncodes={1},
|
||||
)
|
||||
if ok_diff:
|
||||
# No changes to commit
|
||||
logger.debug("Checkpoint skipped: no changes in %s", working_dir)
|
||||
return False
|
||||
|
||||
# Commit
|
||||
ok, _, err = _run_git(
|
||||
["commit", "-m", reason, "--allow-empty-message"],
|
||||
shadow, working_dir, timeout=_GIT_TIMEOUT * 2,
|
||||
)
|
||||
if not ok:
|
||||
logger.debug("Checkpoint commit failed: %s", err)
|
||||
return False
|
||||
|
||||
logger.debug("Checkpoint taken in %s: %s", working_dir, reason)
|
||||
|
||||
# Prune old snapshots
|
||||
self._prune(shadow, working_dir)
|
||||
|
||||
return True
|
||||
|
||||
def _prune(self, shadow_repo: Path, working_dir: str) -> None:
|
||||
"""Keep only the last max_snapshots commits via orphan reset."""
|
||||
ok, stdout, _ = _run_git(
|
||||
["rev-list", "--count", "HEAD"], shadow_repo, working_dir,
|
||||
)
|
||||
if not ok:
|
||||
return
|
||||
|
||||
try:
|
||||
count = int(stdout)
|
||||
except ValueError:
|
||||
return
|
||||
|
||||
if count <= self.max_snapshots:
|
||||
return
|
||||
|
||||
# Get the hash of the commit at the cutoff point
|
||||
ok, cutoff_hash, _ = _run_git(
|
||||
["rev-list", "--reverse", "HEAD", "--skip=0",
|
||||
f"--max-count=1"],
|
||||
shadow_repo, working_dir,
|
||||
)
|
||||
|
||||
# For simplicity, we don't actually prune — git's pack mechanism
|
||||
# handles this efficiently, and the objects are small. The log
|
||||
# listing is already limited by max_snapshots.
|
||||
# Full pruning would require rebase --onto or filter-branch which
|
||||
# is fragile for a background feature. We just limit the log view.
|
||||
logger.debug("Checkpoint repo has %d commits (limit %d)", count, self.max_snapshots)
|
||||
|
||||
|
||||
def format_checkpoint_list(checkpoints: List[Dict], directory: str) -> str:
|
||||
"""Format checkpoint list for display to user."""
|
||||
if not checkpoints:
|
||||
return f"No checkpoints found for {directory}"
|
||||
|
||||
lines = [f"📸 Checkpoints for {directory}:\n"]
|
||||
for i, cp in enumerate(checkpoints, 1):
|
||||
# Parse ISO timestamp to something readable
|
||||
ts = cp["timestamp"]
|
||||
if "T" in ts:
|
||||
ts = ts.split("T")[1].split("+")[0].split("-")[0][:5] # HH:MM
|
||||
date = cp["timestamp"].split("T")[0]
|
||||
ts = f"{date} {ts}"
|
||||
|
||||
# Build change summary
|
||||
files = cp.get("files_changed", 0)
|
||||
ins = cp.get("insertions", 0)
|
||||
dele = cp.get("deletions", 0)
|
||||
if files:
|
||||
stat = f" ({files} file{'s' if files != 1 else ''}, +{ins}/-{dele})"
|
||||
else:
|
||||
stat = ""
|
||||
|
||||
lines.append(f" {i}. {cp['short_hash']} {ts} {cp['reason']}{stat}")
|
||||
|
||||
lines.append(f"\n /rollback <N> restore to checkpoint N")
|
||||
lines.append(f" /rollback diff <N> preview changes since checkpoint N")
|
||||
lines.append(f" /rollback <N> <file> restore a single file from checkpoint N")
|
||||
return "\n".join(lines)
|
||||
141
hermes_code/tools/clarify_tool.py
Normal file
141
hermes_code/tools/clarify_tool.py
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Clarify Tool Module - Interactive Clarifying Questions
|
||||
|
||||
Allows the agent to present structured multiple-choice questions or open-ended
|
||||
prompts to the user. In CLI mode, choices are navigable with arrow keys. On
|
||||
messaging platforms, choices are rendered as a numbered list.
|
||||
|
||||
The actual user-interaction logic lives in the platform layer (cli.py for CLI,
|
||||
gateway/run.py for messaging). This module defines the schema, validation, and
|
||||
a thin dispatcher that delegates to a platform-provided callback.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional, Callable
|
||||
|
||||
|
||||
# Maximum number of predefined choices the agent can offer.
|
||||
# A 5th "Other (type your answer)" option is always appended by the UI.
|
||||
MAX_CHOICES = 4
|
||||
|
||||
|
||||
def clarify_tool(
|
||||
question: str,
|
||||
choices: Optional[List[str]] = None,
|
||||
callback: Optional[Callable] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Ask the user a question, optionally with multiple-choice options.
|
||||
|
||||
Args:
|
||||
question: The question text to present.
|
||||
choices: Up to 4 predefined answer choices. When omitted the
|
||||
question is purely open-ended.
|
||||
callback: Platform-provided function that handles the actual UI
|
||||
interaction. Signature: callback(question, choices) -> str.
|
||||
Injected by the agent runner (cli.py / gateway).
|
||||
|
||||
Returns:
|
||||
JSON string with the user's response.
|
||||
"""
|
||||
if not question or not question.strip():
|
||||
return json.dumps({"error": "Question text is required."}, ensure_ascii=False)
|
||||
|
||||
question = question.strip()
|
||||
|
||||
# Validate and trim choices
|
||||
if choices is not None:
|
||||
if not isinstance(choices, list):
|
||||
return json.dumps({"error": "choices must be a list of strings."}, ensure_ascii=False)
|
||||
choices = [str(c).strip() for c in choices if str(c).strip()]
|
||||
if len(choices) > MAX_CHOICES:
|
||||
choices = choices[:MAX_CHOICES]
|
||||
if not choices:
|
||||
choices = None # empty list → open-ended
|
||||
|
||||
if callback is None:
|
||||
return json.dumps(
|
||||
{"error": "Clarify tool is not available in this execution context."},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
try:
|
||||
user_response = callback(question, choices)
|
||||
except Exception as exc:
|
||||
return json.dumps(
|
||||
{"error": f"Failed to get user input: {exc}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
return json.dumps({
|
||||
"question": question,
|
||||
"choices_offered": choices,
|
||||
"user_response": str(user_response).strip(),
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
def check_clarify_requirements() -> bool:
|
||||
"""Clarify tool has no external requirements -- always available."""
|
||||
return True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OpenAI Function-Calling Schema
|
||||
# =============================================================================
|
||||
|
||||
CLARIFY_SCHEMA = {
|
||||
"name": "clarify",
|
||||
"description": (
|
||||
"Ask the user a question when you need clarification, feedback, or a "
|
||||
"decision before proceeding. Supports two modes:\n\n"
|
||||
"1. **Multiple choice** — provide up to 4 choices. The user picks one "
|
||||
"or types their own answer via a 5th 'Other' option.\n"
|
||||
"2. **Open-ended** — omit choices entirely. The user types a free-form "
|
||||
"response.\n\n"
|
||||
"Use this tool when:\n"
|
||||
"- The task is ambiguous and you need the user to choose an approach\n"
|
||||
"- You want post-task feedback ('How did that work out?')\n"
|
||||
"- You want to offer to save a skill or update memory\n"
|
||||
"- A decision has meaningful trade-offs the user should weigh in on\n\n"
|
||||
"Do NOT use this tool for simple yes/no confirmation of dangerous "
|
||||
"commands (the terminal tool handles that). Prefer making a reasonable "
|
||||
"default choice yourself when the decision is low-stakes."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "The question to present to the user.",
|
||||
},
|
||||
"choices": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"maxItems": MAX_CHOICES,
|
||||
"description": (
|
||||
"Up to 4 answer choices. Omit this parameter entirely to "
|
||||
"ask an open-ended question. When provided, the UI "
|
||||
"automatically appends an 'Other (type your answer)' option."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["question"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# --- Registry ---
|
||||
from tools.registry import registry
|
||||
|
||||
registry.register(
|
||||
name="clarify",
|
||||
toolset="clarify",
|
||||
schema=CLARIFY_SCHEMA,
|
||||
handler=lambda args, **kw: clarify_tool(
|
||||
question=args.get("question", ""),
|
||||
choices=args.get("choices"),
|
||||
callback=kw.get("callback")),
|
||||
check_fn=check_clarify_requirements,
|
||||
emoji="❓",
|
||||
)
|
||||
806
hermes_code/tools/code_execution_tool.py
Normal file
806
hermes_code/tools/code_execution_tool.py
Normal file
|
|
@ -0,0 +1,806 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Code Execution Tool -- Programmatic Tool Calling (PTC)
|
||||
|
||||
Lets the LLM write a Python script that calls Hermes tools via RPC,
|
||||
collapsing multi-step tool chains into a single inference turn.
|
||||
|
||||
Architecture:
|
||||
1. Parent generates a `hermes_tools.py` stub module with RPC functions
|
||||
2. Parent opens a Unix domain socket and starts an RPC listener thread
|
||||
3. Parent spawns a child process that runs the LLM's script
|
||||
4. When the script calls a tool function, the call travels over the UDS
|
||||
back to the parent, which dispatches through handle_function_call
|
||||
5. Only the script's stdout is returned to the LLM; intermediate tool
|
||||
results never enter the context window
|
||||
|
||||
Platform: Linux / macOS only (Unix domain sockets). Disabled on Windows.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
# Availability gate: UDS requires a POSIX OS
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SANDBOX_AVAILABLE = sys.platform != "win32"
|
||||
|
||||
# The 7 tools allowed inside the sandbox. The intersection of this list
|
||||
# and the session's enabled tools determines which stubs are generated.
|
||||
SANDBOX_ALLOWED_TOOLS = frozenset([
|
||||
"web_search",
|
||||
"web_extract",
|
||||
"read_file",
|
||||
"write_file",
|
||||
"search_files",
|
||||
"patch",
|
||||
"terminal",
|
||||
])
|
||||
|
||||
# Resource limit defaults (overridable via config.yaml → code_execution.*)
|
||||
DEFAULT_TIMEOUT = 300 # 5 minutes
|
||||
DEFAULT_MAX_TOOL_CALLS = 50
|
||||
MAX_STDOUT_BYTES = 50_000 # 50 KB
|
||||
MAX_STDERR_BYTES = 10_000 # 10 KB
|
||||
|
||||
|
||||
def check_sandbox_requirements() -> bool:
|
||||
"""Code execution sandbox requires a POSIX OS for Unix domain sockets."""
|
||||
return SANDBOX_AVAILABLE
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# hermes_tools.py code generator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Per-tool stub templates: (function_name, signature, docstring, args_dict_expr)
|
||||
# The args_dict_expr builds the JSON payload sent over the RPC socket.
|
||||
_TOOL_STUBS = {
|
||||
"web_search": (
|
||||
"web_search",
|
||||
"query: str, limit: int = 5",
|
||||
'"""Search the web. Returns dict with data.web list of {url, title, description}."""',
|
||||
'{"query": query, "limit": limit}',
|
||||
),
|
||||
"web_extract": (
|
||||
"web_extract",
|
||||
"urls: list",
|
||||
'"""Extract content from URLs. Returns dict with results list of {url, title, content, error}."""',
|
||||
'{"urls": urls}',
|
||||
),
|
||||
"read_file": (
|
||||
"read_file",
|
||||
"path: str, offset: int = 1, limit: int = 500",
|
||||
'"""Read a file (1-indexed lines). Returns dict with "content" and "total_lines"."""',
|
||||
'{"path": path, "offset": offset, "limit": limit}',
|
||||
),
|
||||
"write_file": (
|
||||
"write_file",
|
||||
"path: str, content: str",
|
||||
'"""Write content to a file (always overwrites). Returns dict with status."""',
|
||||
'{"path": path, "content": content}',
|
||||
),
|
||||
"search_files": (
|
||||
"search_files",
|
||||
'pattern: str, target: str = "content", path: str = ".", file_glob: str = None, limit: int = 50, offset: int = 0, output_mode: str = "content", context: int = 0',
|
||||
'"""Search file contents (target="content") or find files by name (target="files"). Returns dict with "matches"."""',
|
||||
'{"pattern": pattern, "target": target, "path": path, "file_glob": file_glob, "limit": limit, "offset": offset, "output_mode": output_mode, "context": context}',
|
||||
),
|
||||
"patch": (
|
||||
"patch",
|
||||
'path: str = None, old_string: str = None, new_string: str = None, replace_all: bool = False, mode: str = "replace", patch: str = None',
|
||||
'"""Targeted find-and-replace (mode="replace") or V4A multi-file patches (mode="patch"). Returns dict with status."""',
|
||||
'{"path": path, "old_string": old_string, "new_string": new_string, "replace_all": replace_all, "mode": mode, "patch": patch}',
|
||||
),
|
||||
"terminal": (
|
||||
"terminal",
|
||||
"command: str, timeout: int = None, workdir: str = None",
|
||||
'"""Run a shell command (foreground only). Returns dict with "output" and "exit_code"."""',
|
||||
'{"command": command, "timeout": timeout, "workdir": workdir}',
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def generate_hermes_tools_module(enabled_tools: List[str]) -> str:
|
||||
"""
|
||||
Build the source code for the hermes_tools.py stub module.
|
||||
|
||||
Only tools in both SANDBOX_ALLOWED_TOOLS and enabled_tools get stubs.
|
||||
"""
|
||||
tools_to_generate = sorted(SANDBOX_ALLOWED_TOOLS & set(enabled_tools))
|
||||
|
||||
stub_functions = []
|
||||
export_names = []
|
||||
for tool_name in tools_to_generate:
|
||||
if tool_name not in _TOOL_STUBS:
|
||||
continue
|
||||
func_name, sig, doc, args_expr = _TOOL_STUBS[tool_name]
|
||||
stub_functions.append(
|
||||
f"def {func_name}({sig}):\n"
|
||||
f" {doc}\n"
|
||||
f" return _call({func_name!r}, {args_expr})\n"
|
||||
)
|
||||
export_names.append(func_name)
|
||||
|
||||
header = '''\
|
||||
"""Auto-generated Hermes tools RPC stubs."""
|
||||
import json, os, socket, shlex, time
|
||||
|
||||
_sock = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Convenience helpers (avoid common scripting pitfalls)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def json_parse(text: str):
|
||||
"""Parse JSON tolerant of control characters (strict=False).
|
||||
Use this instead of json.loads() when parsing output from terminal()
|
||||
or web_extract() that may contain raw tabs/newlines in strings."""
|
||||
return json.loads(text, strict=False)
|
||||
|
||||
|
||||
def shell_quote(s: str) -> str:
|
||||
"""Shell-escape a string for safe interpolation into commands.
|
||||
Use this when inserting dynamic content into terminal() commands:
|
||||
terminal(f"echo {shell_quote(user_input)}")
|
||||
"""
|
||||
return shlex.quote(s)
|
||||
|
||||
|
||||
def retry(fn, max_attempts=3, delay=2):
|
||||
"""Retry a function up to max_attempts times with exponential backoff.
|
||||
Use for transient failures (network errors, API rate limits):
|
||||
result = retry(lambda: terminal("gh issue list ..."))
|
||||
"""
|
||||
last_err = None
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
return fn()
|
||||
except Exception as e:
|
||||
last_err = e
|
||||
if attempt < max_attempts - 1:
|
||||
time.sleep(delay * (2 ** attempt))
|
||||
raise last_err
|
||||
|
||||
def _connect():
|
||||
global _sock
|
||||
if _sock is None:
|
||||
_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
_sock.connect(os.environ["HERMES_RPC_SOCKET"])
|
||||
_sock.settimeout(300)
|
||||
return _sock
|
||||
|
||||
def _call(tool_name, args):
|
||||
"""Send a tool call to the parent process and return the parsed result."""
|
||||
conn = _connect()
|
||||
request = json.dumps({"tool": tool_name, "args": args}) + "\\n"
|
||||
conn.sendall(request.encode())
|
||||
buf = b""
|
||||
while True:
|
||||
chunk = conn.recv(65536)
|
||||
if not chunk:
|
||||
raise RuntimeError("Agent process disconnected")
|
||||
buf += chunk
|
||||
if buf.endswith(b"\\n"):
|
||||
break
|
||||
raw = buf.decode().strip()
|
||||
result = json.loads(raw)
|
||||
if isinstance(result, str):
|
||||
try:
|
||||
return json.loads(result)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return result
|
||||
return result
|
||||
|
||||
'''
|
||||
|
||||
return header + "\n".join(stub_functions)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RPC server (runs in a thread inside the parent process)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Terminal parameters that must not be used from ephemeral sandbox scripts
|
||||
_TERMINAL_BLOCKED_PARAMS = {"background", "check_interval", "pty"}
|
||||
|
||||
|
||||
def _rpc_server_loop(
|
||||
server_sock: socket.socket,
|
||||
task_id: str,
|
||||
tool_call_log: list,
|
||||
tool_call_counter: list, # mutable [int] so the thread can increment
|
||||
max_tool_calls: int,
|
||||
allowed_tools: frozenset,
|
||||
):
|
||||
"""
|
||||
Accept one client connection and dispatch tool-call requests until
|
||||
the client disconnects or the call limit is reached.
|
||||
"""
|
||||
from model_tools import handle_function_call
|
||||
|
||||
conn = None
|
||||
try:
|
||||
server_sock.settimeout(5)
|
||||
conn, _ = server_sock.accept()
|
||||
conn.settimeout(300)
|
||||
|
||||
buf = b""
|
||||
while True:
|
||||
try:
|
||||
chunk = conn.recv(65536)
|
||||
except socket.timeout:
|
||||
break
|
||||
if not chunk:
|
||||
break
|
||||
buf += chunk
|
||||
|
||||
# Process all complete newline-delimited messages in the buffer
|
||||
while b"\n" in buf:
|
||||
line, buf = buf.split(b"\n", 1)
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
call_start = time.monotonic()
|
||||
try:
|
||||
request = json.loads(line.decode())
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
|
||||
resp = json.dumps({"error": f"Invalid RPC request: {exc}"})
|
||||
conn.sendall((resp + "\n").encode())
|
||||
continue
|
||||
|
||||
tool_name = request.get("tool", "")
|
||||
tool_args = request.get("args", {})
|
||||
|
||||
# Enforce the allow-list
|
||||
if tool_name not in allowed_tools:
|
||||
available = ", ".join(sorted(allowed_tools))
|
||||
resp = json.dumps({
|
||||
"error": (
|
||||
f"Tool '{tool_name}' is not available in execute_code. "
|
||||
f"Available: {available}"
|
||||
)
|
||||
})
|
||||
conn.sendall((resp + "\n").encode())
|
||||
continue
|
||||
|
||||
# Enforce tool call limit
|
||||
if tool_call_counter[0] >= max_tool_calls:
|
||||
resp = json.dumps({
|
||||
"error": (
|
||||
f"Tool call limit reached ({max_tool_calls}). "
|
||||
"No more tool calls allowed in this execution."
|
||||
)
|
||||
})
|
||||
conn.sendall((resp + "\n").encode())
|
||||
continue
|
||||
|
||||
# Strip forbidden terminal parameters
|
||||
if tool_name == "terminal" and isinstance(tool_args, dict):
|
||||
for param in _TERMINAL_BLOCKED_PARAMS:
|
||||
tool_args.pop(param, None)
|
||||
|
||||
# Dispatch through the standard tool handler.
|
||||
# Suppress stdout/stderr from internal tool handlers so
|
||||
# their status prints don't leak into the CLI spinner.
|
||||
try:
|
||||
_real_stdout, _real_stderr = sys.stdout, sys.stderr
|
||||
devnull = open(os.devnull, "w")
|
||||
try:
|
||||
sys.stdout = devnull
|
||||
sys.stderr = devnull
|
||||
result = handle_function_call(
|
||||
tool_name, tool_args, task_id=task_id
|
||||
)
|
||||
finally:
|
||||
sys.stdout, sys.stderr = _real_stdout, _real_stderr
|
||||
devnull.close()
|
||||
except Exception as exc:
|
||||
logger.error("Tool call failed in sandbox: %s", exc, exc_info=True)
|
||||
result = json.dumps({"error": str(exc)})
|
||||
|
||||
tool_call_counter[0] += 1
|
||||
call_duration = time.monotonic() - call_start
|
||||
|
||||
# Log for observability
|
||||
args_preview = str(tool_args)[:80]
|
||||
tool_call_log.append({
|
||||
"tool": tool_name,
|
||||
"args_preview": args_preview,
|
||||
"duration": round(call_duration, 2),
|
||||
})
|
||||
|
||||
conn.sendall((result + "\n").encode())
|
||||
|
||||
except socket.timeout:
|
||||
logger.debug("RPC listener socket timeout")
|
||||
except OSError as e:
|
||||
logger.debug("RPC listener socket error: %s", e, exc_info=True)
|
||||
finally:
|
||||
if conn:
|
||||
try:
|
||||
conn.close()
|
||||
except OSError as e:
|
||||
logger.debug("RPC conn close error: %s", e)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def execute_code(
|
||||
code: str,
|
||||
task_id: Optional[str] = None,
|
||||
enabled_tools: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Run a Python script in a sandboxed child process with RPC access
|
||||
to a subset of Hermes tools.
|
||||
|
||||
Args:
|
||||
code: Python source code to execute.
|
||||
task_id: Session task ID for tool isolation (terminal env, etc.).
|
||||
enabled_tools: Tool names enabled in the current session. The sandbox
|
||||
gets the intersection with SANDBOX_ALLOWED_TOOLS.
|
||||
|
||||
Returns:
|
||||
JSON string with execution results.
|
||||
"""
|
||||
if not SANDBOX_AVAILABLE:
|
||||
return json.dumps({
|
||||
"error": "execute_code is not available on Windows. Use normal tool calls instead."
|
||||
})
|
||||
|
||||
if not code or not code.strip():
|
||||
return json.dumps({"error": "No code provided."})
|
||||
|
||||
# Import interrupt event from terminal_tool (cooperative cancellation)
|
||||
from tools.terminal_tool import _interrupt_event
|
||||
|
||||
# Resolve config
|
||||
_cfg = _load_config()
|
||||
timeout = _cfg.get("timeout", DEFAULT_TIMEOUT)
|
||||
max_tool_calls = _cfg.get("max_tool_calls", DEFAULT_MAX_TOOL_CALLS)
|
||||
|
||||
# Determine which tools the sandbox can call
|
||||
session_tools = set(enabled_tools) if enabled_tools else set()
|
||||
sandbox_tools = frozenset(SANDBOX_ALLOWED_TOOLS & session_tools)
|
||||
|
||||
if not sandbox_tools:
|
||||
sandbox_tools = SANDBOX_ALLOWED_TOOLS
|
||||
|
||||
# --- Set up temp directory with hermes_tools.py and script.py ---
|
||||
tmpdir = tempfile.mkdtemp(prefix="hermes_sandbox_")
|
||||
# Use /tmp on macOS to avoid the long /var/folders/... path that pushes
|
||||
# Unix domain socket paths past the 104-byte macOS AF_UNIX limit.
|
||||
# On Linux, tempfile.gettempdir() already returns /tmp.
|
||||
_sock_tmpdir = "/tmp" if sys.platform == "darwin" else tempfile.gettempdir()
|
||||
sock_path = os.path.join(_sock_tmpdir, f"hermes_rpc_{uuid.uuid4().hex}.sock")
|
||||
|
||||
tool_call_log: list = []
|
||||
tool_call_counter = [0] # mutable so the RPC thread can increment
|
||||
exec_start = time.monotonic()
|
||||
server_sock = None
|
||||
|
||||
try:
|
||||
# Write the auto-generated hermes_tools module
|
||||
# sandbox_tools is already the correct set (intersection with session
|
||||
# tools, or SANDBOX_ALLOWED_TOOLS as fallback — see lines above).
|
||||
tools_src = generate_hermes_tools_module(list(sandbox_tools))
|
||||
with open(os.path.join(tmpdir, "hermes_tools.py"), "w") as f:
|
||||
f.write(tools_src)
|
||||
|
||||
# Write the user's script
|
||||
with open(os.path.join(tmpdir, "script.py"), "w") as f:
|
||||
f.write(code)
|
||||
|
||||
# --- Start UDS server ---
|
||||
server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
server_sock.bind(sock_path)
|
||||
server_sock.listen(1)
|
||||
|
||||
rpc_thread = threading.Thread(
|
||||
target=_rpc_server_loop,
|
||||
args=(
|
||||
server_sock, task_id, tool_call_log,
|
||||
tool_call_counter, max_tool_calls, sandbox_tools,
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
rpc_thread.start()
|
||||
|
||||
# --- Spawn child process ---
|
||||
# Build a minimal environment for the child. We intentionally exclude
|
||||
# API keys and tokens to prevent credential exfiltration from LLM-
|
||||
# generated scripts. The child accesses tools via RPC, not direct API.
|
||||
# Exception: env vars declared by loaded skills (via env_passthrough
|
||||
# registry) or explicitly allowed by the user in config.yaml
|
||||
# (terminal.env_passthrough) are passed through.
|
||||
_SAFE_ENV_PREFIXES = ("PATH", "HOME", "USER", "LANG", "LC_", "TERM",
|
||||
"TMPDIR", "TMP", "TEMP", "SHELL", "LOGNAME",
|
||||
"XDG_", "PYTHONPATH", "VIRTUAL_ENV", "CONDA")
|
||||
_SECRET_SUBSTRINGS = ("KEY", "TOKEN", "SECRET", "PASSWORD", "CREDENTIAL",
|
||||
"PASSWD", "AUTH")
|
||||
try:
|
||||
from tools.env_passthrough import is_env_passthrough as _is_passthrough
|
||||
except Exception:
|
||||
_is_passthrough = lambda _: False # noqa: E731
|
||||
child_env = {}
|
||||
for k, v in os.environ.items():
|
||||
# Passthrough vars (skill-declared or user-configured) always pass.
|
||||
if _is_passthrough(k):
|
||||
child_env[k] = v
|
||||
continue
|
||||
# Block vars with secret-like names.
|
||||
if any(s in k.upper() for s in _SECRET_SUBSTRINGS):
|
||||
continue
|
||||
# Allow vars with known safe prefixes.
|
||||
if any(k.startswith(p) for p in _SAFE_ENV_PREFIXES):
|
||||
child_env[k] = v
|
||||
child_env["HERMES_RPC_SOCKET"] = sock_path
|
||||
child_env["PYTHONDONTWRITEBYTECODE"] = "1"
|
||||
# Ensure the hermes-agent root is importable in the sandbox so
|
||||
# repo-root modules are available to child scripts.
|
||||
_hermes_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
_existing_pp = child_env.get("PYTHONPATH", "")
|
||||
child_env["PYTHONPATH"] = _hermes_root + (os.pathsep + _existing_pp if _existing_pp else "")
|
||||
# Inject user's configured timezone so datetime.now() in sandboxed
|
||||
# code reflects the correct wall-clock time.
|
||||
_tz_name = os.getenv("HERMES_TIMEZONE", "").strip()
|
||||
if _tz_name:
|
||||
child_env["TZ"] = _tz_name
|
||||
|
||||
proc = subprocess.Popen(
|
||||
[sys.executable, "script.py"],
|
||||
cwd=tmpdir,
|
||||
env=child_env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
stdin=subprocess.DEVNULL,
|
||||
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
||||
)
|
||||
|
||||
# --- Poll loop: watch for exit, timeout, and interrupt ---
|
||||
deadline = time.monotonic() + timeout
|
||||
stderr_chunks: list = []
|
||||
|
||||
# Background readers to avoid pipe buffer deadlocks.
|
||||
# For stdout we use a head+tail strategy: keep the first HEAD_BYTES
|
||||
# and a rolling window of the last TAIL_BYTES so the final print()
|
||||
# output is never lost. Stderr keeps head-only (errors appear early).
|
||||
_STDOUT_HEAD_BYTES = int(MAX_STDOUT_BYTES * 0.4) # 40% head
|
||||
_STDOUT_TAIL_BYTES = MAX_STDOUT_BYTES - _STDOUT_HEAD_BYTES # 60% tail
|
||||
|
||||
def _drain(pipe, chunks, max_bytes):
|
||||
"""Simple head-only drain (used for stderr)."""
|
||||
total = 0
|
||||
try:
|
||||
while True:
|
||||
data = pipe.read(4096)
|
||||
if not data:
|
||||
break
|
||||
if total < max_bytes:
|
||||
keep = max_bytes - total
|
||||
chunks.append(data[:keep])
|
||||
total += len(data)
|
||||
except (ValueError, OSError) as e:
|
||||
logger.debug("Error reading process output: %s", e, exc_info=True)
|
||||
|
||||
stdout_total_bytes = [0] # mutable ref for total bytes seen
|
||||
|
||||
def _drain_head_tail(pipe, head_chunks, tail_chunks, head_bytes, tail_bytes, total_ref):
|
||||
"""Drain stdout keeping both head and tail data."""
|
||||
head_collected = 0
|
||||
from collections import deque
|
||||
tail_buf = deque()
|
||||
tail_collected = 0
|
||||
try:
|
||||
while True:
|
||||
data = pipe.read(4096)
|
||||
if not data:
|
||||
break
|
||||
total_ref[0] += len(data)
|
||||
# Fill head buffer first
|
||||
if head_collected < head_bytes:
|
||||
keep = min(len(data), head_bytes - head_collected)
|
||||
head_chunks.append(data[:keep])
|
||||
head_collected += keep
|
||||
data = data[keep:] # remaining goes to tail
|
||||
if not data:
|
||||
continue
|
||||
# Everything past head goes into rolling tail buffer
|
||||
tail_buf.append(data)
|
||||
tail_collected += len(data)
|
||||
# Evict old tail data to stay within tail_bytes budget
|
||||
while tail_collected > tail_bytes and tail_buf:
|
||||
oldest = tail_buf.popleft()
|
||||
tail_collected -= len(oldest)
|
||||
except (ValueError, OSError):
|
||||
pass
|
||||
# Transfer final tail to output list
|
||||
tail_chunks.extend(tail_buf)
|
||||
|
||||
stdout_head_chunks: list = []
|
||||
stdout_tail_chunks: list = []
|
||||
|
||||
stdout_reader = threading.Thread(
|
||||
target=_drain_head_tail,
|
||||
args=(proc.stdout, stdout_head_chunks, stdout_tail_chunks,
|
||||
_STDOUT_HEAD_BYTES, _STDOUT_TAIL_BYTES, stdout_total_bytes),
|
||||
daemon=True
|
||||
)
|
||||
stderr_reader = threading.Thread(
|
||||
target=_drain, args=(proc.stderr, stderr_chunks, MAX_STDERR_BYTES), daemon=True
|
||||
)
|
||||
stdout_reader.start()
|
||||
stderr_reader.start()
|
||||
|
||||
status = "success"
|
||||
while proc.poll() is None:
|
||||
if _interrupt_event.is_set():
|
||||
_kill_process_group(proc)
|
||||
status = "interrupted"
|
||||
break
|
||||
if time.monotonic() > deadline:
|
||||
_kill_process_group(proc, escalate=True)
|
||||
status = "timeout"
|
||||
break
|
||||
time.sleep(0.2)
|
||||
|
||||
# Wait for readers to finish draining
|
||||
stdout_reader.join(timeout=3)
|
||||
stderr_reader.join(timeout=3)
|
||||
|
||||
stdout_head = b"".join(stdout_head_chunks).decode("utf-8", errors="replace")
|
||||
stdout_tail = b"".join(stdout_tail_chunks).decode("utf-8", errors="replace")
|
||||
stderr_text = b"".join(stderr_chunks).decode("utf-8", errors="replace")
|
||||
|
||||
# Assemble stdout with head+tail truncation
|
||||
total_stdout = stdout_total_bytes[0]
|
||||
if total_stdout > MAX_STDOUT_BYTES and stdout_tail:
|
||||
omitted = total_stdout - len(stdout_head) - len(stdout_tail)
|
||||
truncated_notice = (
|
||||
f"\n\n... [OUTPUT TRUNCATED - {omitted:,} chars omitted "
|
||||
f"out of {total_stdout:,} total] ...\n\n"
|
||||
)
|
||||
stdout_text = stdout_head + truncated_notice + stdout_tail
|
||||
else:
|
||||
stdout_text = stdout_head + stdout_tail
|
||||
|
||||
exit_code = proc.returncode if proc.returncode is not None else -1
|
||||
duration = round(time.monotonic() - exec_start, 2)
|
||||
|
||||
# Wait for RPC thread to finish
|
||||
server_sock.close() # break accept() so thread exits promptly
|
||||
server_sock = None # prevent double close in finally
|
||||
rpc_thread.join(timeout=3)
|
||||
|
||||
# Strip ANSI escape sequences so the model never sees terminal
|
||||
# formatting — prevents it from copying escapes into file writes.
|
||||
from tools.ansi_strip import strip_ansi
|
||||
stdout_text = strip_ansi(stdout_text)
|
||||
stderr_text = strip_ansi(stderr_text)
|
||||
|
||||
# Build response
|
||||
result: Dict[str, Any] = {
|
||||
"status": status,
|
||||
"output": stdout_text,
|
||||
"tool_calls_made": tool_call_counter[0],
|
||||
"duration_seconds": duration,
|
||||
}
|
||||
|
||||
if status == "timeout":
|
||||
result["error"] = f"Script timed out after {timeout}s and was killed."
|
||||
elif status == "interrupted":
|
||||
result["output"] = stdout_text + "\n[execution interrupted — user sent a new message]"
|
||||
elif exit_code != 0:
|
||||
result["status"] = "error"
|
||||
result["error"] = stderr_text or f"Script exited with code {exit_code}"
|
||||
# Include stderr in output so the LLM sees the traceback
|
||||
if stderr_text:
|
||||
result["output"] = stdout_text + "\n--- stderr ---\n" + stderr_text
|
||||
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
|
||||
except Exception as exc:
|
||||
duration = round(time.monotonic() - exec_start, 2)
|
||||
logger.error(
|
||||
"execute_code failed after %ss with %d tool calls: %s: %s",
|
||||
duration,
|
||||
tool_call_counter[0],
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
return json.dumps({
|
||||
"status": "error",
|
||||
"error": str(exc),
|
||||
"tool_calls_made": tool_call_counter[0],
|
||||
"duration_seconds": duration,
|
||||
}, ensure_ascii=False)
|
||||
|
||||
finally:
|
||||
# Cleanup temp dir and socket
|
||||
if server_sock is not None:
|
||||
try:
|
||||
server_sock.close()
|
||||
except OSError as e:
|
||||
logger.debug("Server socket close error: %s", e)
|
||||
import shutil
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
try:
|
||||
os.unlink(sock_path)
|
||||
except OSError:
|
||||
pass # already cleaned up or never created
|
||||
|
||||
|
||||
def _kill_process_group(proc, escalate: bool = False):
|
||||
"""Kill the child and its entire process group."""
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
proc.terminate()
|
||||
else:
|
||||
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
|
||||
except (ProcessLookupError, PermissionError) as e:
|
||||
logger.debug("Could not kill process group: %s", e, exc_info=True)
|
||||
try:
|
||||
proc.kill()
|
||||
except Exception as e2:
|
||||
logger.debug("Could not kill process: %s", e2, exc_info=True)
|
||||
|
||||
if escalate:
|
||||
# Give the process 5s to exit after SIGTERM, then SIGKILL
|
||||
try:
|
||||
proc.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
proc.kill()
|
||||
else:
|
||||
os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
|
||||
except (ProcessLookupError, PermissionError) as e:
|
||||
logger.debug("Could not kill process group with SIGKILL: %s", e, exc_info=True)
|
||||
try:
|
||||
proc.kill()
|
||||
except Exception as e2:
|
||||
logger.debug("Could not kill process: %s", e2, exc_info=True)
|
||||
|
||||
|
||||
def _load_config() -> dict:
|
||||
"""Load code_execution config from CLI_CONFIG if available."""
|
||||
try:
|
||||
from cli import CLI_CONFIG
|
||||
return CLI_CONFIG.get("code_execution", {})
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAI Function-Calling Schema
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Per-tool documentation lines for the execute_code description.
|
||||
# Ordered to match the canonical display order.
|
||||
_TOOL_DOC_LINES = [
|
||||
("web_search",
|
||||
" web_search(query: str, limit: int = 5) -> dict\n"
|
||||
" Returns {\"data\": {\"web\": [{\"url\", \"title\", \"description\"}, ...]}}"),
|
||||
("web_extract",
|
||||
" web_extract(urls: list[str]) -> dict\n"
|
||||
" Returns {\"results\": [{\"url\", \"title\", \"content\", \"error\"}, ...]} where content is markdown"),
|
||||
("read_file",
|
||||
" read_file(path: str, offset: int = 1, limit: int = 500) -> dict\n"
|
||||
" Lines are 1-indexed. Returns {\"content\": \"...\", \"total_lines\": N}"),
|
||||
("write_file",
|
||||
" write_file(path: str, content: str) -> dict\n"
|
||||
" Always overwrites the entire file."),
|
||||
("search_files",
|
||||
" search_files(pattern: str, target=\"content\", path=\".\", file_glob=None, limit=50) -> dict\n"
|
||||
" target: \"content\" (search inside files) or \"files\" (find files by name). Returns {\"matches\": [...]}"),
|
||||
("patch",
|
||||
" patch(path: str, old_string: str, new_string: str, replace_all: bool = False) -> dict\n"
|
||||
" Replaces old_string with new_string in the file."),
|
||||
("terminal",
|
||||
" terminal(command: str, timeout=None, workdir=None) -> dict\n"
|
||||
" Foreground only (no background/pty). Returns {\"output\": \"...\", \"exit_code\": N}"),
|
||||
]
|
||||
|
||||
|
||||
def build_execute_code_schema(enabled_sandbox_tools: set = None) -> dict:
|
||||
"""Build the execute_code schema with description listing only enabled tools.
|
||||
|
||||
When tools are disabled via ``hermes tools`` (e.g. web is turned off),
|
||||
the schema description should NOT mention web_search / web_extract —
|
||||
otherwise the model thinks they are available and keeps trying to use them.
|
||||
"""
|
||||
if enabled_sandbox_tools is None:
|
||||
enabled_sandbox_tools = SANDBOX_ALLOWED_TOOLS
|
||||
|
||||
# Build tool documentation lines for only the enabled tools
|
||||
tool_lines = "\n".join(
|
||||
doc for name, doc in _TOOL_DOC_LINES if name in enabled_sandbox_tools
|
||||
)
|
||||
|
||||
# Build example import list from enabled tools
|
||||
import_examples = [n for n in ("web_search", "terminal") if n in enabled_sandbox_tools]
|
||||
if not import_examples:
|
||||
import_examples = sorted(enabled_sandbox_tools)[:2]
|
||||
if import_examples:
|
||||
import_str = ", ".join(import_examples) + ", ..."
|
||||
else:
|
||||
import_str = "..."
|
||||
|
||||
description = (
|
||||
"Run a Python script that can call Hermes tools programmatically. "
|
||||
"Use this when you need 3+ tool calls with processing logic between them, "
|
||||
"need to filter/reduce large tool outputs before they enter your context, "
|
||||
"need conditional branching (if X then Y else Z), or need to loop "
|
||||
"(fetch N pages, process N files, retry on failure).\n\n"
|
||||
"Use normal tool calls instead when: single tool call with no processing, "
|
||||
"you need to see the full result and apply complex reasoning, "
|
||||
"or the task requires interactive user input.\n\n"
|
||||
f"Available via `from hermes_tools import ...`:\n\n"
|
||||
f"{tool_lines}\n\n"
|
||||
"Limits: 5-minute timeout, 50KB stdout cap, max 50 tool calls per script. "
|
||||
"terminal() is foreground-only (no background or pty).\n\n"
|
||||
"Print your final result to stdout. Use Python stdlib (json, re, math, csv, "
|
||||
"datetime, collections, etc.) for processing between tool calls.\n\n"
|
||||
"Also available (no import needed — built into hermes_tools):\n"
|
||||
" json_parse(text: str) — json.loads with strict=False; use for terminal() output with control chars\n"
|
||||
" shell_quote(s: str) — shlex.quote(); use when interpolating dynamic strings into shell commands\n"
|
||||
" retry(fn, max_attempts=3, delay=2) — retry with exponential backoff for transient failures"
|
||||
)
|
||||
|
||||
return {
|
||||
"name": "execute_code",
|
||||
"description": description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Python code to execute. Import tools with "
|
||||
f"`from hermes_tools import {import_str}` "
|
||||
"and print your final result to stdout."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["code"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# Default schema used at registration time (all sandbox tools listed)
|
||||
EXECUTE_CODE_SCHEMA = build_execute_code_schema()
|
||||
|
||||
|
||||
# --- Registry ---
|
||||
from tools.registry import registry
|
||||
|
||||
registry.register(
|
||||
name="execute_code",
|
||||
toolset="code_execution",
|
||||
schema=EXECUTE_CODE_SCHEMA,
|
||||
handler=lambda args, **kw: execute_code(
|
||||
code=args.get("code", ""),
|
||||
task_id=kw.get("task_id"),
|
||||
enabled_tools=kw.get("enabled_tools")),
|
||||
check_fn=check_sandbox_requirements,
|
||||
emoji="🐍",
|
||||
)
|
||||
458
hermes_code/tools/cronjob_tools.py
Normal file
458
hermes_code/tools/cronjob_tools.py
Normal file
|
|
@ -0,0 +1,458 @@
|
|||
"""
|
||||
Cron job management tools for Hermes Agent.
|
||||
|
||||
Expose a single compressed action-oriented tool to avoid schema/context bloat.
|
||||
Compatibility wrappers remain for direct Python callers and legacy tests.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
# Import from cron module (will be available when properly installed)
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from cron.jobs import (
|
||||
create_job,
|
||||
get_job,
|
||||
list_jobs,
|
||||
parse_schedule,
|
||||
pause_job,
|
||||
remove_job,
|
||||
resume_job,
|
||||
trigger_job,
|
||||
update_job,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cron prompt scanning — critical-severity patterns only, since cron prompts
|
||||
# run in fresh sessions with full tool access.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CRON_THREAT_PATTERNS = [
|
||||
(r'ignore\s+(?:\w+\s+)*(?:previous|all|above|prior)\s+(?:\w+\s+)*instructions', "prompt_injection"),
|
||||
(r'do\s+not\s+tell\s+the\s+user', "deception_hide"),
|
||||
(r'system\s+prompt\s+override', "sys_prompt_override"),
|
||||
(r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"),
|
||||
(r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"),
|
||||
(r'wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_wget"),
|
||||
(r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass)', "read_secrets"),
|
||||
(r'authorized_keys', "ssh_backdoor"),
|
||||
(r'/etc/sudoers|visudo', "sudoers_mod"),
|
||||
(r'rm\s+-rf\s+/', "destructive_root_rm"),
|
||||
]
|
||||
|
||||
_CRON_INVISIBLE_CHARS = {
|
||||
'\u200b', '\u200c', '\u200d', '\u2060', '\ufeff',
|
||||
'\u202a', '\u202b', '\u202c', '\u202d', '\u202e',
|
||||
}
|
||||
|
||||
|
||||
def _scan_cron_prompt(prompt: str) -> str:
|
||||
"""Scan a cron prompt for critical threats. Returns error string if blocked, else empty."""
|
||||
for char in _CRON_INVISIBLE_CHARS:
|
||||
if char in prompt:
|
||||
return f"Blocked: prompt contains invisible unicode U+{ord(char):04X} (possible injection)."
|
||||
for pattern, pid in _CRON_THREAT_PATTERNS:
|
||||
if re.search(pattern, prompt, re.IGNORECASE):
|
||||
return f"Blocked: prompt matches threat pattern '{pid}'. Cron prompts must not contain injection or exfiltration payloads."
|
||||
return ""
|
||||
|
||||
|
||||
def _origin_from_env() -> Optional[Dict[str, str]]:
|
||||
origin_platform = os.getenv("HERMES_SESSION_PLATFORM")
|
||||
origin_chat_id = os.getenv("HERMES_SESSION_CHAT_ID")
|
||||
if origin_platform and origin_chat_id:
|
||||
return {
|
||||
"platform": origin_platform,
|
||||
"chat_id": origin_chat_id,
|
||||
"chat_name": os.getenv("HERMES_SESSION_CHAT_NAME"),
|
||||
"thread_id": os.getenv("HERMES_SESSION_THREAD_ID"),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def _repeat_display(job: Dict[str, Any]) -> str:
|
||||
times = (job.get("repeat") or {}).get("times")
|
||||
completed = (job.get("repeat") or {}).get("completed", 0)
|
||||
if times is None:
|
||||
return "forever"
|
||||
if times == 1:
|
||||
return "once" if completed == 0 else "1/1"
|
||||
return f"{completed}/{times}" if completed else f"{times} times"
|
||||
|
||||
|
||||
def _canonical_skills(skill: Optional[str] = None, skills: Optional[Any] = None) -> List[str]:
|
||||
if skills is None:
|
||||
raw_items = [skill] if skill else []
|
||||
elif isinstance(skills, str):
|
||||
raw_items = [skills]
|
||||
else:
|
||||
raw_items = list(skills)
|
||||
|
||||
normalized: List[str] = []
|
||||
for item in raw_items:
|
||||
text = str(item or "").strip()
|
||||
if text and text not in normalized:
|
||||
normalized.append(text)
|
||||
return normalized
|
||||
|
||||
|
||||
|
||||
def _normalize_optional_job_value(value: Optional[Any], *, strip_trailing_slash: bool = False) -> Optional[str]:
|
||||
if value is None:
|
||||
return None
|
||||
text = str(value).strip()
|
||||
if strip_trailing_slash:
|
||||
text = text.rstrip("/")
|
||||
return text or None
|
||||
|
||||
|
||||
|
||||
def _format_job(job: Dict[str, Any]) -> Dict[str, Any]:
|
||||
prompt = job.get("prompt", "")
|
||||
skills = _canonical_skills(job.get("skill"), job.get("skills"))
|
||||
return {
|
||||
"job_id": job["id"],
|
||||
"name": job["name"],
|
||||
"skill": skills[0] if skills else None,
|
||||
"skills": skills,
|
||||
"prompt_preview": prompt[:100] + "..." if len(prompt) > 100 else prompt,
|
||||
"model": job.get("model"),
|
||||
"provider": job.get("provider"),
|
||||
"base_url": job.get("base_url"),
|
||||
"schedule": job.get("schedule_display"),
|
||||
"repeat": _repeat_display(job),
|
||||
"deliver": job.get("deliver", "local"),
|
||||
"next_run_at": job.get("next_run_at"),
|
||||
"last_run_at": job.get("last_run_at"),
|
||||
"last_status": job.get("last_status"),
|
||||
"enabled": job.get("enabled", True),
|
||||
"state": job.get("state", "scheduled" if job.get("enabled", True) else "paused"),
|
||||
"paused_at": job.get("paused_at"),
|
||||
"paused_reason": job.get("paused_reason"),
|
||||
}
|
||||
|
||||
|
||||
def cronjob(
|
||||
action: str,
|
||||
job_id: Optional[str] = None,
|
||||
prompt: Optional[str] = None,
|
||||
schedule: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
repeat: Optional[int] = None,
|
||||
deliver: Optional[str] = None,
|
||||
include_disabled: bool = False,
|
||||
skill: Optional[str] = None,
|
||||
skills: Optional[List[str]] = None,
|
||||
model: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
reason: Optional[str] = None,
|
||||
task_id: str = None,
|
||||
) -> str:
|
||||
"""Unified cron job management tool."""
|
||||
del task_id # unused but kept for handler signature compatibility
|
||||
|
||||
try:
|
||||
normalized = (action or "").strip().lower()
|
||||
|
||||
if normalized == "create":
|
||||
if not schedule:
|
||||
return json.dumps({"success": False, "error": "schedule is required for create"}, indent=2)
|
||||
canonical_skills = _canonical_skills(skill, skills)
|
||||
if not prompt and not canonical_skills:
|
||||
return json.dumps({"success": False, "error": "create requires either prompt or at least one skill"}, indent=2)
|
||||
if prompt:
|
||||
scan_error = _scan_cron_prompt(prompt)
|
||||
if scan_error:
|
||||
return json.dumps({"success": False, "error": scan_error}, indent=2)
|
||||
|
||||
job = create_job(
|
||||
prompt=prompt or "",
|
||||
schedule=schedule,
|
||||
name=name,
|
||||
repeat=repeat,
|
||||
deliver=deliver,
|
||||
origin=_origin_from_env(),
|
||||
skills=canonical_skills,
|
||||
model=_normalize_optional_job_value(model),
|
||||
provider=_normalize_optional_job_value(provider),
|
||||
base_url=_normalize_optional_job_value(base_url, strip_trailing_slash=True),
|
||||
)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"job_id": job["id"],
|
||||
"name": job["name"],
|
||||
"skill": job.get("skill"),
|
||||
"skills": job.get("skills", []),
|
||||
"schedule": job["schedule_display"],
|
||||
"repeat": _repeat_display(job),
|
||||
"deliver": job.get("deliver", "local"),
|
||||
"next_run_at": job["next_run_at"],
|
||||
"job": _format_job(job),
|
||||
"message": f"Cron job '{job['name']}' created.",
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
|
||||
if normalized == "list":
|
||||
jobs = [_format_job(job) for job in list_jobs(include_disabled=include_disabled)]
|
||||
return json.dumps({"success": True, "count": len(jobs), "jobs": jobs}, indent=2)
|
||||
|
||||
if not job_id:
|
||||
return json.dumps({"success": False, "error": f"job_id is required for action '{normalized}'"}, indent=2)
|
||||
|
||||
job = get_job(job_id)
|
||||
if not job:
|
||||
return json.dumps(
|
||||
{"success": False, "error": f"Job with ID '{job_id}' not found. Use cronjob(action='list') to inspect jobs."},
|
||||
indent=2,
|
||||
)
|
||||
|
||||
if normalized == "remove":
|
||||
removed = remove_job(job_id)
|
||||
if not removed:
|
||||
return json.dumps({"success": False, "error": f"Failed to remove job '{job_id}'"}, indent=2)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"Cron job '{job['name']}' removed.",
|
||||
"removed_job": {
|
||||
"id": job_id,
|
||||
"name": job["name"],
|
||||
"schedule": job.get("schedule_display"),
|
||||
},
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
|
||||
if normalized == "pause":
|
||||
updated = pause_job(job_id, reason=reason)
|
||||
return json.dumps({"success": True, "job": _format_job(updated)}, indent=2)
|
||||
|
||||
if normalized == "resume":
|
||||
updated = resume_job(job_id)
|
||||
return json.dumps({"success": True, "job": _format_job(updated)}, indent=2)
|
||||
|
||||
if normalized in {"run", "run_now", "trigger"}:
|
||||
updated = trigger_job(job_id)
|
||||
return json.dumps({"success": True, "job": _format_job(updated)}, indent=2)
|
||||
|
||||
if normalized == "update":
|
||||
updates: Dict[str, Any] = {}
|
||||
if prompt is not None:
|
||||
scan_error = _scan_cron_prompt(prompt)
|
||||
if scan_error:
|
||||
return json.dumps({"success": False, "error": scan_error}, indent=2)
|
||||
updates["prompt"] = prompt
|
||||
if name is not None:
|
||||
updates["name"] = name
|
||||
if deliver is not None:
|
||||
updates["deliver"] = deliver
|
||||
if skills is not None or skill is not None:
|
||||
canonical_skills = _canonical_skills(skill, skills)
|
||||
updates["skills"] = canonical_skills
|
||||
updates["skill"] = canonical_skills[0] if canonical_skills else None
|
||||
if model is not None:
|
||||
updates["model"] = _normalize_optional_job_value(model)
|
||||
if provider is not None:
|
||||
updates["provider"] = _normalize_optional_job_value(provider)
|
||||
if base_url is not None:
|
||||
updates["base_url"] = _normalize_optional_job_value(base_url, strip_trailing_slash=True)
|
||||
if repeat is not None:
|
||||
# Normalize: treat 0 or negative as None (infinite)
|
||||
normalized_repeat = None if repeat <= 0 else repeat
|
||||
repeat_state = dict(job.get("repeat") or {})
|
||||
repeat_state["times"] = normalized_repeat
|
||||
updates["repeat"] = repeat_state
|
||||
if schedule is not None:
|
||||
parsed_schedule = parse_schedule(schedule)
|
||||
updates["schedule"] = parsed_schedule
|
||||
updates["schedule_display"] = parsed_schedule.get("display", schedule)
|
||||
if job.get("state") != "paused":
|
||||
updates["state"] = "scheduled"
|
||||
updates["enabled"] = True
|
||||
if not updates:
|
||||
return json.dumps({"success": False, "error": "No updates provided."}, indent=2)
|
||||
updated = update_job(job_id, updates)
|
||||
return json.dumps({"success": True, "job": _format_job(updated)}, indent=2)
|
||||
|
||||
return json.dumps({"success": False, "error": f"Unknown cron action '{action}'"}, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({"success": False, "error": str(e)}, indent=2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Compatibility wrappers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def schedule_cronjob(
|
||||
prompt: str,
|
||||
schedule: str,
|
||||
name: Optional[str] = None,
|
||||
repeat: Optional[int] = None,
|
||||
deliver: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
task_id: str = None,
|
||||
) -> str:
|
||||
return cronjob(
|
||||
action="create",
|
||||
prompt=prompt,
|
||||
schedule=schedule,
|
||||
name=name,
|
||||
repeat=repeat,
|
||||
deliver=deliver,
|
||||
model=model,
|
||||
provider=provider,
|
||||
base_url=base_url,
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
|
||||
def list_cronjobs(include_disabled: bool = False, task_id: str = None) -> str:
|
||||
return cronjob(action="list", include_disabled=include_disabled, task_id=task_id)
|
||||
|
||||
|
||||
def remove_cronjob(job_id: str, task_id: str = None) -> str:
|
||||
return cronjob(action="remove", job_id=job_id, task_id=task_id)
|
||||
|
||||
|
||||
CRONJOB_SCHEMA = {
|
||||
"name": "cronjob",
|
||||
"description": """Manage scheduled cron jobs with a single compressed tool.
|
||||
|
||||
Use action='create' to schedule a new job from a prompt or one or more skills.
|
||||
Use action='list' to inspect jobs.
|
||||
Use action='update', 'pause', 'resume', 'remove', or 'run' to manage an existing job.
|
||||
|
||||
Jobs run in a fresh session with no current-chat context, so prompts must be self-contained.
|
||||
If skill or skills are provided on create, the future cron run loads those skills in order, then follows the prompt as the task instruction.
|
||||
On update, passing skills=[] clears attached skills.
|
||||
|
||||
NOTE: The agent's final response is auto-delivered to the target. Put the primary
|
||||
user-facing content in the final response. Cron jobs run autonomously with no user
|
||||
present — they cannot ask questions or request clarification.
|
||||
|
||||
Important safety rule: cron-run sessions should not recursively schedule more cron jobs.""",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"description": "One of: create, list, update, pause, resume, remove, run"
|
||||
},
|
||||
"job_id": {
|
||||
"type": "string",
|
||||
"description": "Required for update/pause/resume/remove/run"
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "For create: the full self-contained prompt. If skill or skills are also provided, this becomes the task instruction paired with those skills."
|
||||
},
|
||||
"schedule": {
|
||||
"type": "string",
|
||||
"description": "For create/update: '30m', 'every 2h', '0 9 * * *', or ISO timestamp"
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Optional human-friendly name"
|
||||
},
|
||||
"repeat": {
|
||||
"type": "integer",
|
||||
"description": "Optional repeat count. Omit for defaults (once for one-shot, forever for recurring)."
|
||||
},
|
||||
"deliver": {
|
||||
"type": "string",
|
||||
"description": "Delivery target: origin, local, telegram, discord, slack, whatsapp, signal, matrix, mattermost, homeassistant, dingtalk, email, sms, or platform:chat_id or platform:chat_id:thread_id for Telegram topics. Examples: 'origin', 'local', 'telegram', 'telegram:-1001234567890:17585', 'discord:#engineering'"
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"description": "Optional per-job model override used when the cron job runs"
|
||||
},
|
||||
"provider": {
|
||||
"type": "string",
|
||||
"description": "Optional per-job provider override used when resolving runtime credentials"
|
||||
},
|
||||
"base_url": {
|
||||
"type": "string",
|
||||
"description": "Optional per-job base URL override paired with provider/model routing"
|
||||
},
|
||||
"include_disabled": {
|
||||
"type": "boolean",
|
||||
"description": "For list: include paused/completed jobs"
|
||||
},
|
||||
"skill": {
|
||||
"type": "string",
|
||||
"description": "Optional single skill name to load before executing the cron prompt"
|
||||
},
|
||||
"skills": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Optional ordered list of skills to load before executing the cron prompt. On update, pass an empty array to clear attached skills."
|
||||
},
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": "Optional pause reason"
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def check_cronjob_requirements() -> bool:
|
||||
"""
|
||||
Check if cronjob tools can be used.
|
||||
|
||||
Available in interactive CLI mode and gateway/messaging platforms.
|
||||
The cron system is internal (JSON file-based scheduler ticked by the gateway),
|
||||
so no external crontab executable is required.
|
||||
"""
|
||||
return bool(
|
||||
os.getenv("HERMES_INTERACTIVE")
|
||||
or os.getenv("HERMES_GATEWAY_SESSION")
|
||||
or os.getenv("HERMES_EXEC_ASK")
|
||||
)
|
||||
|
||||
|
||||
def get_cronjob_tool_definitions():
|
||||
"""Return tool definitions for cronjob management."""
|
||||
return [CRONJOB_SCHEMA]
|
||||
|
||||
|
||||
# --- Registry ---
|
||||
from tools.registry import registry
|
||||
|
||||
registry.register(
|
||||
name="cronjob",
|
||||
toolset="cronjob",
|
||||
schema=CRONJOB_SCHEMA,
|
||||
handler=lambda args, **kw: cronjob(
|
||||
action=args.get("action", ""),
|
||||
job_id=args.get("job_id"),
|
||||
prompt=args.get("prompt"),
|
||||
schedule=args.get("schedule"),
|
||||
name=args.get("name"),
|
||||
repeat=args.get("repeat"),
|
||||
deliver=args.get("deliver"),
|
||||
include_disabled=args.get("include_disabled", False),
|
||||
skill=args.get("skill"),
|
||||
skills=args.get("skills"),
|
||||
model=args.get("model"),
|
||||
provider=args.get("provider"),
|
||||
base_url=args.get("base_url"),
|
||||
reason=args.get("reason"),
|
||||
task_id=kw.get("task_id"),
|
||||
),
|
||||
check_fn=check_cronjob_requirements,
|
||||
emoji="⏰",
|
||||
)
|
||||
104
hermes_code/tools/debug_helpers.py
Normal file
104
hermes_code/tools/debug_helpers.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""Shared debug session infrastructure for Hermes tools.
|
||||
|
||||
Replaces the identical DEBUG_MODE / _log_debug_call / _save_debug_log /
|
||||
get_debug_session_info boilerplate previously duplicated across web_tools,
|
||||
vision_tools, mixture_of_agents_tool, and image_generation_tool.
|
||||
|
||||
Usage in a tool module:
|
||||
|
||||
from tools.debug_helpers import DebugSession
|
||||
|
||||
_debug = DebugSession("web_tools", env_var="WEB_TOOLS_DEBUG")
|
||||
|
||||
# Log a call (no-op when debug mode is off)
|
||||
_debug.log_call("web_search", {"query": q, "results": len(r)})
|
||||
|
||||
# Save the debug log (no-op when debug mode is off)
|
||||
_debug.save()
|
||||
|
||||
# Expose debug info to external callers
|
||||
def get_debug_session_info():
|
||||
return _debug.get_session_info()
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DebugSession:
|
||||
"""Per-tool debug session that records tool calls to a JSON log file.
|
||||
|
||||
Activated by a tool-specific environment variable (e.g. WEB_TOOLS_DEBUG=true).
|
||||
When disabled, all methods are cheap no-ops.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_name: str, *, env_var: str) -> None:
|
||||
self.tool_name = tool_name
|
||||
self.enabled = os.getenv(env_var, "false").lower() == "true"
|
||||
self.session_id = str(uuid.uuid4()) if self.enabled else ""
|
||||
self.log_dir = Path("./logs")
|
||||
self._calls: list[Dict[str, Any]] = []
|
||||
self._start_time = datetime.datetime.now().isoformat() if self.enabled else ""
|
||||
|
||||
if self.enabled:
|
||||
self.log_dir.mkdir(exist_ok=True)
|
||||
logger.debug("%s debug mode enabled - Session ID: %s",
|
||||
tool_name, self.session_id)
|
||||
|
||||
@property
|
||||
def active(self) -> bool:
|
||||
return self.enabled
|
||||
|
||||
def log_call(self, call_name: str, call_data: Dict[str, Any]) -> None:
|
||||
"""Append a tool-call entry to the in-memory log."""
|
||||
if not self.enabled:
|
||||
return
|
||||
self._calls.append({
|
||||
"timestamp": datetime.datetime.now().isoformat(),
|
||||
"tool_name": call_name,
|
||||
**call_data,
|
||||
})
|
||||
|
||||
def save(self) -> None:
|
||||
"""Flush the in-memory log to a JSON file in the logs directory."""
|
||||
if not self.enabled:
|
||||
return
|
||||
try:
|
||||
filename = f"{self.tool_name}_debug_{self.session_id}.json"
|
||||
filepath = self.log_dir / filename
|
||||
payload = {
|
||||
"session_id": self.session_id,
|
||||
"start_time": self._start_time,
|
||||
"end_time": datetime.datetime.now().isoformat(),
|
||||
"debug_enabled": True,
|
||||
"total_calls": len(self._calls),
|
||||
"tool_calls": self._calls,
|
||||
}
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(payload, f, indent=2, ensure_ascii=False)
|
||||
logger.debug("%s debug log saved: %s", self.tool_name, filepath)
|
||||
except Exception as e:
|
||||
logger.error("Error saving %s debug log: %s", self.tool_name, e)
|
||||
|
||||
def get_session_info(self) -> Dict[str, Any]:
|
||||
"""Return a summary dict suitable for returning from get_debug_session_info()."""
|
||||
if not self.enabled:
|
||||
return {
|
||||
"enabled": False,
|
||||
"session_id": None,
|
||||
"log_path": None,
|
||||
"total_calls": 0,
|
||||
}
|
||||
return {
|
||||
"enabled": True,
|
||||
"session_id": self.session_id,
|
||||
"log_path": str(self.log_dir / f"{self.tool_name}_debug_{self.session_id}.json"),
|
||||
"total_calls": len(self._calls),
|
||||
}
|
||||
789
hermes_code/tools/delegate_tool.py
Normal file
789
hermes_code/tools/delegate_tool.py
Normal file
|
|
@ -0,0 +1,789 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Delegate Tool -- Subagent Architecture
|
||||
|
||||
Spawns child AIAgent instances with isolated context, restricted toolsets,
|
||||
and their own terminal sessions. Supports single-task and batch (parallel)
|
||||
modes. The parent blocks until all children complete.
|
||||
|
||||
Each child gets:
|
||||
- A fresh conversation (no parent history)
|
||||
- Its own task_id (own terminal session, file ops cache)
|
||||
- A restricted toolset (configurable, with blocked tools always stripped)
|
||||
- A focused system prompt built from the delegated goal + context
|
||||
|
||||
The parent's context only sees the delegation call and the summary result,
|
||||
never the child's intermediate tool calls or reasoning.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
import os
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
# Tools that children must never have access to
|
||||
DELEGATE_BLOCKED_TOOLS = frozenset([
|
||||
"delegate_task", # no recursive delegation
|
||||
"clarify", # no user interaction
|
||||
"memory", # no writes to shared MEMORY.md
|
||||
"send_message", # no cross-platform side effects
|
||||
"execute_code", # children should reason step-by-step, not write scripts
|
||||
])
|
||||
|
||||
MAX_CONCURRENT_CHILDREN = 3
|
||||
MAX_DEPTH = 2 # parent (0) -> child (1) -> grandchild rejected (2)
|
||||
DEFAULT_MAX_ITERATIONS = 50
|
||||
DEFAULT_TOOLSETS = ["terminal", "file", "web"]
|
||||
|
||||
|
||||
def check_delegate_requirements() -> bool:
|
||||
"""Delegation has no external requirements -- always available."""
|
||||
return True
|
||||
|
||||
|
||||
def _build_child_system_prompt(goal: str, context: Optional[str] = None) -> str:
|
||||
"""Build a focused system prompt for a child agent."""
|
||||
parts = [
|
||||
"You are a focused subagent working on a specific delegated task.",
|
||||
"",
|
||||
f"YOUR TASK:\n{goal}",
|
||||
]
|
||||
if context and context.strip():
|
||||
parts.append(f"\nCONTEXT:\n{context}")
|
||||
parts.append(
|
||||
"\nComplete this task using the tools available to you. "
|
||||
"When finished, provide a clear, concise summary of:\n"
|
||||
"- What you did\n"
|
||||
"- What you found or accomplished\n"
|
||||
"- Any files you created or modified\n"
|
||||
"- Any issues encountered\n\n"
|
||||
"Be thorough but concise -- your response is returned to the "
|
||||
"parent agent as a summary."
|
||||
)
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _strip_blocked_tools(toolsets: List[str]) -> List[str]:
|
||||
"""Remove toolsets that contain only blocked tools."""
|
||||
blocked_toolset_names = {
|
||||
"delegation", "clarify", "memory", "code_execution",
|
||||
}
|
||||
return [t for t in toolsets if t not in blocked_toolset_names]
|
||||
|
||||
|
||||
def _build_child_progress_callback(task_index: int, parent_agent, task_count: int = 1) -> Optional[callable]:
|
||||
"""Build a callback that relays child agent tool calls to the parent display.
|
||||
|
||||
Two display paths:
|
||||
CLI: prints tree-view lines above the parent's delegation spinner
|
||||
Gateway: batches tool names and relays to parent's progress callback
|
||||
|
||||
Returns None if no display mechanism is available, in which case the
|
||||
child agent runs with no progress callback (identical to current behavior).
|
||||
"""
|
||||
spinner = getattr(parent_agent, '_delegate_spinner', None)
|
||||
parent_cb = getattr(parent_agent, 'tool_progress_callback', None)
|
||||
|
||||
if not spinner and not parent_cb:
|
||||
return None # No display → no callback → zero behavior change
|
||||
|
||||
# Show 1-indexed prefix only in batch mode (multiple tasks)
|
||||
prefix = f"[{task_index + 1}] " if task_count > 1 else ""
|
||||
|
||||
# Gateway: batch tool names, flush periodically
|
||||
_BATCH_SIZE = 5
|
||||
_batch: List[str] = []
|
||||
|
||||
def _callback(tool_name: str, preview: str = None):
|
||||
# Special "_thinking" event: model produced text content (reasoning)
|
||||
if tool_name == "_thinking":
|
||||
if spinner:
|
||||
short = (preview[:55] + "...") if preview and len(preview) > 55 else (preview or "")
|
||||
try:
|
||||
spinner.print_above(f" {prefix}├─ 💭 \"{short}\"")
|
||||
except Exception as e:
|
||||
logger.debug("Spinner print_above failed: %s", e)
|
||||
# Don't relay thinking to gateway (too noisy for chat)
|
||||
return
|
||||
|
||||
# Regular tool call event
|
||||
if spinner:
|
||||
short = (preview[:35] + "...") if preview and len(preview) > 35 else (preview or "")
|
||||
from agent.display import get_tool_emoji
|
||||
emoji = get_tool_emoji(tool_name)
|
||||
line = f" {prefix}├─ {emoji} {tool_name}"
|
||||
if short:
|
||||
line += f" \"{short}\""
|
||||
try:
|
||||
spinner.print_above(line)
|
||||
except Exception as e:
|
||||
logger.debug("Spinner print_above failed: %s", e)
|
||||
|
||||
if parent_cb:
|
||||
_batch.append(tool_name)
|
||||
if len(_batch) >= _BATCH_SIZE:
|
||||
summary = ", ".join(_batch)
|
||||
try:
|
||||
parent_cb("subagent_progress", f"🔀 {prefix}{summary}")
|
||||
except Exception as e:
|
||||
logger.debug("Parent callback failed: %s", e)
|
||||
_batch.clear()
|
||||
|
||||
def _flush():
|
||||
"""Flush remaining batched tool names to gateway on completion."""
|
||||
if parent_cb and _batch:
|
||||
summary = ", ".join(_batch)
|
||||
try:
|
||||
parent_cb("subagent_progress", f"🔀 {prefix}{summary}")
|
||||
except Exception as e:
|
||||
logger.debug("Parent callback flush failed: %s", e)
|
||||
_batch.clear()
|
||||
|
||||
_callback._flush = _flush
|
||||
return _callback
|
||||
|
||||
|
||||
def _build_child_agent(
|
||||
task_index: int,
|
||||
goal: str,
|
||||
context: Optional[str],
|
||||
toolsets: Optional[List[str]],
|
||||
model: Optional[str],
|
||||
max_iterations: int,
|
||||
parent_agent,
|
||||
# Credential overrides from delegation config (provider:model resolution)
|
||||
override_provider: Optional[str] = None,
|
||||
override_base_url: Optional[str] = None,
|
||||
override_api_key: Optional[str] = None,
|
||||
override_api_mode: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Build a child AIAgent on the main thread (thread-safe construction).
|
||||
Returns the constructed child agent without running it.
|
||||
|
||||
When override_* params are set (from delegation config), the child uses
|
||||
those credentials instead of inheriting from the parent. This enables
|
||||
routing subagents to a different provider:model pair (e.g. cheap/fast
|
||||
model on OpenRouter while the parent runs on Nous Portal).
|
||||
"""
|
||||
from run_agent import AIAgent
|
||||
import model_tools
|
||||
|
||||
# When no explicit toolsets given, inherit from parent's enabled toolsets
|
||||
# so disabled tools (e.g. web) don't leak to subagents.
|
||||
if toolsets:
|
||||
child_toolsets = _strip_blocked_tools(toolsets)
|
||||
elif parent_agent and getattr(parent_agent, "enabled_toolsets", None):
|
||||
child_toolsets = _strip_blocked_tools(parent_agent.enabled_toolsets)
|
||||
else:
|
||||
child_toolsets = _strip_blocked_tools(DEFAULT_TOOLSETS)
|
||||
|
||||
child_prompt = _build_child_system_prompt(goal, context)
|
||||
# Extract parent's API key so subagents inherit auth (e.g. Nous Portal).
|
||||
parent_api_key = getattr(parent_agent, "api_key", None)
|
||||
if (not parent_api_key) and hasattr(parent_agent, "_client_kwargs"):
|
||||
parent_api_key = parent_agent._client_kwargs.get("api_key")
|
||||
|
||||
# Build progress callback to relay tool calls to parent display
|
||||
child_progress_cb = _build_child_progress_callback(task_index, parent_agent)
|
||||
|
||||
# Share the parent's iteration budget so subagent tool calls
|
||||
# count toward the session-wide limit.
|
||||
shared_budget = getattr(parent_agent, "iteration_budget", None)
|
||||
|
||||
# Resolve effective credentials: config override > parent inherit
|
||||
effective_model = model or parent_agent.model
|
||||
effective_provider = override_provider or getattr(parent_agent, "provider", None)
|
||||
effective_base_url = override_base_url or parent_agent.base_url
|
||||
effective_api_key = override_api_key or parent_api_key
|
||||
effective_api_mode = override_api_mode or getattr(parent_agent, "api_mode", None)
|
||||
effective_acp_command = getattr(parent_agent, "acp_command", None)
|
||||
effective_acp_args = list(getattr(parent_agent, "acp_args", []) or [])
|
||||
|
||||
child = AIAgent(
|
||||
base_url=effective_base_url,
|
||||
api_key=effective_api_key,
|
||||
model=effective_model,
|
||||
provider=effective_provider,
|
||||
api_mode=effective_api_mode,
|
||||
acp_command=effective_acp_command,
|
||||
acp_args=effective_acp_args,
|
||||
max_iterations=max_iterations,
|
||||
max_tokens=getattr(parent_agent, "max_tokens", None),
|
||||
reasoning_config=getattr(parent_agent, "reasoning_config", None),
|
||||
prefill_messages=getattr(parent_agent, "prefill_messages", None),
|
||||
enabled_toolsets=child_toolsets,
|
||||
quiet_mode=True,
|
||||
ephemeral_system_prompt=child_prompt,
|
||||
log_prefix=f"[subagent-{task_index}]",
|
||||
platform=parent_agent.platform,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
clarify_callback=None,
|
||||
session_db=getattr(parent_agent, '_session_db', None),
|
||||
providers_allowed=parent_agent.providers_allowed,
|
||||
providers_ignored=parent_agent.providers_ignored,
|
||||
providers_order=parent_agent.providers_order,
|
||||
provider_sort=parent_agent.provider_sort,
|
||||
tool_progress_callback=child_progress_cb,
|
||||
iteration_budget=shared_budget,
|
||||
)
|
||||
# Set delegation depth so children can't spawn grandchildren
|
||||
child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1
|
||||
|
||||
# Register child for interrupt propagation
|
||||
if hasattr(parent_agent, '_active_children'):
|
||||
lock = getattr(parent_agent, '_active_children_lock', None)
|
||||
if lock:
|
||||
with lock:
|
||||
parent_agent._active_children.append(child)
|
||||
else:
|
||||
parent_agent._active_children.append(child)
|
||||
|
||||
return child
|
||||
|
||||
def _run_single_child(
|
||||
task_index: int,
|
||||
goal: str,
|
||||
child=None,
|
||||
parent_agent=None,
|
||||
**_kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run a pre-built child agent. Called from within a thread.
|
||||
Returns a structured result dict.
|
||||
"""
|
||||
child_start = time.monotonic()
|
||||
|
||||
# Get the progress callback from the child agent
|
||||
child_progress_cb = getattr(child, 'tool_progress_callback', None)
|
||||
|
||||
# Restore parent tool names using the value saved before child construction
|
||||
# mutated the global. This is the correct parent toolset, not the child's.
|
||||
import model_tools
|
||||
_saved_tool_names = getattr(child, "_delegate_saved_tool_names",
|
||||
list(model_tools._last_resolved_tool_names))
|
||||
|
||||
try:
|
||||
result = child.run_conversation(user_message=goal)
|
||||
|
||||
# Flush any remaining batched progress to gateway
|
||||
if child_progress_cb and hasattr(child_progress_cb, '_flush'):
|
||||
try:
|
||||
child_progress_cb._flush()
|
||||
except Exception as e:
|
||||
logger.debug("Progress callback flush failed: %s", e)
|
||||
|
||||
duration = round(time.monotonic() - child_start, 2)
|
||||
|
||||
summary = result.get("final_response") or ""
|
||||
completed = result.get("completed", False)
|
||||
interrupted = result.get("interrupted", False)
|
||||
api_calls = result.get("api_calls", 0)
|
||||
|
||||
if interrupted:
|
||||
status = "interrupted"
|
||||
elif completed and summary:
|
||||
status = "completed"
|
||||
else:
|
||||
status = "failed"
|
||||
|
||||
# Build tool trace from conversation messages (already in memory).
|
||||
# Uses tool_call_id to correctly pair parallel tool calls with results.
|
||||
tool_trace: list[Dict[str, Any]] = []
|
||||
trace_by_id: Dict[str, Dict[str, Any]] = {}
|
||||
messages = result.get("messages") or []
|
||||
if isinstance(messages, list):
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
if msg.get("role") == "assistant":
|
||||
for tc in (msg.get("tool_calls") or []):
|
||||
fn = tc.get("function", {})
|
||||
entry_t = {
|
||||
"tool": fn.get("name", "unknown"),
|
||||
"args_bytes": len(fn.get("arguments", "")),
|
||||
}
|
||||
tool_trace.append(entry_t)
|
||||
tc_id = tc.get("id")
|
||||
if tc_id:
|
||||
trace_by_id[tc_id] = entry_t
|
||||
elif msg.get("role") == "tool":
|
||||
content = msg.get("content", "")
|
||||
is_error = bool(
|
||||
content and "error" in content[:80].lower()
|
||||
)
|
||||
result_meta = {
|
||||
"result_bytes": len(content),
|
||||
"status": "error" if is_error else "ok",
|
||||
}
|
||||
# Match by tool_call_id for parallel calls
|
||||
tc_id = msg.get("tool_call_id")
|
||||
target = trace_by_id.get(tc_id) if tc_id else None
|
||||
if target is not None:
|
||||
target.update(result_meta)
|
||||
elif tool_trace:
|
||||
# Fallback for messages without tool_call_id
|
||||
tool_trace[-1].update(result_meta)
|
||||
|
||||
# Determine exit reason
|
||||
if interrupted:
|
||||
exit_reason = "interrupted"
|
||||
elif completed:
|
||||
exit_reason = "completed"
|
||||
else:
|
||||
exit_reason = "max_iterations"
|
||||
|
||||
# Extract token counts (safe for mock objects)
|
||||
_input_tokens = getattr(child, "session_prompt_tokens", 0)
|
||||
_output_tokens = getattr(child, "session_completion_tokens", 0)
|
||||
_model = getattr(child, "model", None)
|
||||
|
||||
entry: Dict[str, Any] = {
|
||||
"task_index": task_index,
|
||||
"status": status,
|
||||
"summary": summary,
|
||||
"api_calls": api_calls,
|
||||
"duration_seconds": duration,
|
||||
"model": _model if isinstance(_model, str) else None,
|
||||
"exit_reason": exit_reason,
|
||||
"tokens": {
|
||||
"input": _input_tokens if isinstance(_input_tokens, (int, float)) else 0,
|
||||
"output": _output_tokens if isinstance(_output_tokens, (int, float)) else 0,
|
||||
},
|
||||
"tool_trace": tool_trace,
|
||||
}
|
||||
if status == "failed":
|
||||
entry["error"] = result.get("error", "Subagent did not produce a response.")
|
||||
|
||||
return entry
|
||||
|
||||
except Exception as exc:
|
||||
duration = round(time.monotonic() - child_start, 2)
|
||||
logging.exception(f"[subagent-{task_index}] failed")
|
||||
return {
|
||||
"task_index": task_index,
|
||||
"status": "error",
|
||||
"summary": None,
|
||||
"error": str(exc),
|
||||
"api_calls": 0,
|
||||
"duration_seconds": duration,
|
||||
}
|
||||
|
||||
finally:
|
||||
# Restore the parent's tool names so the process-global is correct
|
||||
# for any subsequent execute_code calls or other consumers.
|
||||
import model_tools
|
||||
|
||||
saved_tool_names = getattr(child, "_delegate_saved_tool_names", None)
|
||||
if isinstance(saved_tool_names, list):
|
||||
model_tools._last_resolved_tool_names = list(saved_tool_names)
|
||||
|
||||
# Unregister child from interrupt propagation
|
||||
if hasattr(parent_agent, '_active_children'):
|
||||
try:
|
||||
lock = getattr(parent_agent, '_active_children_lock', None)
|
||||
if lock:
|
||||
with lock:
|
||||
parent_agent._active_children.remove(child)
|
||||
else:
|
||||
parent_agent._active_children.remove(child)
|
||||
except (ValueError, UnboundLocalError) as e:
|
||||
logger.debug("Could not remove child from active_children: %s", e)
|
||||
|
||||
def delegate_task(
|
||||
goal: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
toolsets: Optional[List[str]] = None,
|
||||
tasks: Optional[List[Dict[str, Any]]] = None,
|
||||
max_iterations: Optional[int] = None,
|
||||
parent_agent=None,
|
||||
) -> str:
|
||||
"""
|
||||
Spawn one or more child agents to handle delegated tasks.
|
||||
|
||||
Supports two modes:
|
||||
- Single: provide goal (+ optional context, toolsets)
|
||||
- Batch: provide tasks array [{goal, context, toolsets}, ...]
|
||||
|
||||
Returns JSON with results array, one entry per task.
|
||||
"""
|
||||
if parent_agent is None:
|
||||
return json.dumps({"error": "delegate_task requires a parent agent context."})
|
||||
|
||||
# Depth limit
|
||||
depth = getattr(parent_agent, '_delegate_depth', 0)
|
||||
if depth >= MAX_DEPTH:
|
||||
return json.dumps({
|
||||
"error": (
|
||||
f"Delegation depth limit reached ({MAX_DEPTH}). "
|
||||
"Subagents cannot spawn further subagents."
|
||||
)
|
||||
})
|
||||
|
||||
# Load config
|
||||
cfg = _load_config()
|
||||
default_max_iter = cfg.get("max_iterations", DEFAULT_MAX_ITERATIONS)
|
||||
effective_max_iter = max_iterations or default_max_iter
|
||||
|
||||
# Resolve delegation credentials (provider:model pair).
|
||||
# When delegation.provider is configured, this resolves the full credential
|
||||
# bundle (base_url, api_key, api_mode) via the same runtime provider system
|
||||
# used by CLI/gateway startup. When unconfigured, returns None values so
|
||||
# children inherit from the parent.
|
||||
try:
|
||||
creds = _resolve_delegation_credentials(cfg, parent_agent)
|
||||
except ValueError as exc:
|
||||
return json.dumps({"error": str(exc)})
|
||||
|
||||
# Normalize to task list
|
||||
if tasks and isinstance(tasks, list):
|
||||
task_list = tasks[:MAX_CONCURRENT_CHILDREN]
|
||||
elif goal and isinstance(goal, str) and goal.strip():
|
||||
task_list = [{"goal": goal, "context": context, "toolsets": toolsets}]
|
||||
else:
|
||||
return json.dumps({"error": "Provide either 'goal' (single task) or 'tasks' (batch)."})
|
||||
|
||||
if not task_list:
|
||||
return json.dumps({"error": "No tasks provided."})
|
||||
|
||||
# Validate each task has a goal
|
||||
for i, task in enumerate(task_list):
|
||||
if not task.get("goal", "").strip():
|
||||
return json.dumps({"error": f"Task {i} is missing a 'goal'."})
|
||||
|
||||
overall_start = time.monotonic()
|
||||
results = []
|
||||
|
||||
n_tasks = len(task_list)
|
||||
# Track goal labels for progress display (truncated for readability)
|
||||
task_labels = [t["goal"][:40] for t in task_list]
|
||||
|
||||
# Save parent tool names BEFORE any child construction mutates the global.
|
||||
# _build_child_agent() calls AIAgent() which calls get_tool_definitions(),
|
||||
# which overwrites model_tools._last_resolved_tool_names with child's toolset.
|
||||
import model_tools as _model_tools
|
||||
_parent_tool_names = list(_model_tools._last_resolved_tool_names)
|
||||
|
||||
# Build all child agents on the main thread (thread-safe construction)
|
||||
# Wrapped in try/finally so the global is always restored even if a
|
||||
# child build raises (otherwise _last_resolved_tool_names stays corrupted).
|
||||
children = []
|
||||
try:
|
||||
for i, t in enumerate(task_list):
|
||||
child = _build_child_agent(
|
||||
task_index=i, goal=t["goal"], context=t.get("context"),
|
||||
toolsets=t.get("toolsets") or toolsets, model=creds["model"],
|
||||
max_iterations=effective_max_iter, parent_agent=parent_agent,
|
||||
override_provider=creds["provider"], override_base_url=creds["base_url"],
|
||||
override_api_key=creds["api_key"],
|
||||
override_api_mode=creds["api_mode"],
|
||||
)
|
||||
# Override with correct parent tool names (before child construction mutated global)
|
||||
child._delegate_saved_tool_names = _parent_tool_names
|
||||
children.append((i, t, child))
|
||||
finally:
|
||||
# Authoritative restore: reset global to parent's tool names after all children built
|
||||
_model_tools._last_resolved_tool_names = _parent_tool_names
|
||||
|
||||
if n_tasks == 1:
|
||||
# Single task -- run directly (no thread pool overhead)
|
||||
_i, _t, child = children[0]
|
||||
result = _run_single_child(0, _t["goal"], child, parent_agent)
|
||||
results.append(result)
|
||||
else:
|
||||
# Batch -- run in parallel with per-task progress lines
|
||||
completed_count = 0
|
||||
spinner_ref = getattr(parent_agent, '_delegate_spinner', None)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CHILDREN) as executor:
|
||||
futures = {}
|
||||
for i, t, child in children:
|
||||
future = executor.submit(
|
||||
_run_single_child,
|
||||
task_index=i,
|
||||
goal=t["goal"],
|
||||
child=child,
|
||||
parent_agent=parent_agent,
|
||||
)
|
||||
futures[future] = i
|
||||
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
entry = future.result()
|
||||
except Exception as exc:
|
||||
idx = futures[future]
|
||||
entry = {
|
||||
"task_index": idx,
|
||||
"status": "error",
|
||||
"summary": None,
|
||||
"error": str(exc),
|
||||
"api_calls": 0,
|
||||
"duration_seconds": 0,
|
||||
}
|
||||
results.append(entry)
|
||||
completed_count += 1
|
||||
|
||||
# Print per-task completion line above the spinner
|
||||
idx = entry["task_index"]
|
||||
label = task_labels[idx] if idx < len(task_labels) else f"Task {idx}"
|
||||
dur = entry.get("duration_seconds", 0)
|
||||
status = entry.get("status", "?")
|
||||
icon = "✓" if status == "completed" else "✗"
|
||||
remaining = n_tasks - completed_count
|
||||
completion_line = f"{icon} [{idx+1}/{n_tasks}] {label} ({dur}s)"
|
||||
if spinner_ref:
|
||||
try:
|
||||
spinner_ref.print_above(completion_line)
|
||||
except Exception:
|
||||
print(f" {completion_line}")
|
||||
else:
|
||||
print(f" {completion_line}")
|
||||
|
||||
# Update spinner text to show remaining count
|
||||
if spinner_ref and remaining > 0:
|
||||
try:
|
||||
spinner_ref.update_text(f"🔀 {remaining} task{'s' if remaining != 1 else ''} remaining")
|
||||
except Exception as e:
|
||||
logger.debug("Spinner update_text failed: %s", e)
|
||||
|
||||
# Sort by task_index so results match input order
|
||||
results.sort(key=lambda r: r["task_index"])
|
||||
|
||||
total_duration = round(time.monotonic() - overall_start, 2)
|
||||
|
||||
return json.dumps({
|
||||
"results": results,
|
||||
"total_duration_seconds": total_duration,
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
def _resolve_delegation_credentials(cfg: dict, parent_agent) -> dict:
|
||||
"""Resolve credentials for subagent delegation.
|
||||
|
||||
If ``delegation.base_url`` is configured, subagents use that direct
|
||||
OpenAI-compatible endpoint. Otherwise, if ``delegation.provider`` is
|
||||
configured, the full credential bundle (base_url, api_key, api_mode,
|
||||
provider) is resolved via the runtime provider system — the same path used
|
||||
by CLI/gateway startup. This lets subagents run on a completely different
|
||||
provider:model pair.
|
||||
|
||||
If neither base_url nor provider is configured, returns None values so the
|
||||
child inherits everything from the parent agent.
|
||||
|
||||
Raises ValueError with a user-friendly message on credential failure.
|
||||
"""
|
||||
configured_model = str(cfg.get("model") or "").strip() or None
|
||||
configured_provider = str(cfg.get("provider") or "").strip() or None
|
||||
configured_base_url = str(cfg.get("base_url") or "").strip() or None
|
||||
configured_api_key = str(cfg.get("api_key") or "").strip() or None
|
||||
|
||||
if configured_base_url:
|
||||
api_key = (
|
||||
configured_api_key
|
||||
or os.getenv("OPENAI_API_KEY", "").strip()
|
||||
)
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"Delegation base_url is configured but no API key was found. "
|
||||
"Set delegation.api_key or OPENAI_API_KEY."
|
||||
)
|
||||
|
||||
base_lower = configured_base_url.lower()
|
||||
provider = "custom"
|
||||
api_mode = "chat_completions"
|
||||
if "chatgpt.com/backend-api/codex" in base_lower:
|
||||
provider = "openai-codex"
|
||||
api_mode = "codex_responses"
|
||||
elif "api.anthropic.com" in base_lower:
|
||||
provider = "anthropic"
|
||||
api_mode = "anthropic_messages"
|
||||
|
||||
return {
|
||||
"model": configured_model,
|
||||
"provider": provider,
|
||||
"base_url": configured_base_url,
|
||||
"api_key": api_key,
|
||||
"api_mode": api_mode,
|
||||
}
|
||||
|
||||
if not configured_provider:
|
||||
# No provider override — child inherits everything from parent
|
||||
return {
|
||||
"model": configured_model,
|
||||
"provider": None,
|
||||
"base_url": None,
|
||||
"api_key": None,
|
||||
"api_mode": None,
|
||||
}
|
||||
|
||||
# Provider is configured — resolve full credentials
|
||||
try:
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
||||
runtime = resolve_runtime_provider(requested=configured_provider)
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
f"Cannot resolve delegation provider '{configured_provider}': {exc}. "
|
||||
f"Check that the provider is configured (API key set, valid provider name), "
|
||||
f"or set delegation.base_url/delegation.api_key for a direct endpoint. "
|
||||
f"Available providers: openrouter, nous, zai, kimi-coding, minimax."
|
||||
) from exc
|
||||
|
||||
api_key = runtime.get("api_key", "")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
f"Delegation provider '{configured_provider}' resolved but has no API key. "
|
||||
f"Set the appropriate environment variable or run 'hermes login'."
|
||||
)
|
||||
|
||||
return {
|
||||
"model": configured_model,
|
||||
"provider": runtime.get("provider"),
|
||||
"base_url": runtime.get("base_url"),
|
||||
"api_key": api_key,
|
||||
"api_mode": runtime.get("api_mode"),
|
||||
"command": runtime.get("command"),
|
||||
"args": list(runtime.get("args") or []),
|
||||
}
|
||||
|
||||
|
||||
def _load_config() -> dict:
|
||||
"""Load delegation config from CLI_CONFIG or persistent config.
|
||||
|
||||
Checks the runtime config (cli.py CLI_CONFIG) first, then falls back
|
||||
to the persistent config (hermes_cli/config.py load_config()) so that
|
||||
``delegation.model`` / ``delegation.provider`` are picked up regardless
|
||||
of the entry point (CLI, gateway, cron).
|
||||
"""
|
||||
try:
|
||||
from cli import CLI_CONFIG
|
||||
cfg = CLI_CONFIG.get("delegation", {})
|
||||
if cfg:
|
||||
return cfg
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
full = load_config()
|
||||
return full.get("delegation", {})
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAI Function-Calling Schema
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DELEGATE_TASK_SCHEMA = {
|
||||
"name": "delegate_task",
|
||||
"description": (
|
||||
"Spawn one or more subagents to work on tasks in isolated contexts. "
|
||||
"Each subagent gets its own conversation, terminal session, and toolset. "
|
||||
"Only the final summary is returned -- intermediate tool results "
|
||||
"never enter your context window.\n\n"
|
||||
"TWO MODES (one of 'goal' or 'tasks' is required):\n"
|
||||
"1. Single task: provide 'goal' (+ optional context, toolsets)\n"
|
||||
"2. Batch (parallel): provide 'tasks' array with up to 3 items. "
|
||||
"All run concurrently and results are returned together.\n\n"
|
||||
"WHEN TO USE delegate_task:\n"
|
||||
"- Reasoning-heavy subtasks (debugging, code review, research synthesis)\n"
|
||||
"- Tasks that would flood your context with intermediate data\n"
|
||||
"- Parallel independent workstreams (research A and B simultaneously)\n\n"
|
||||
"WHEN NOT TO USE (use these instead):\n"
|
||||
"- Mechanical multi-step work with no reasoning needed -> use execute_code\n"
|
||||
"- Single tool call -> just call the tool directly\n"
|
||||
"- Tasks needing user interaction -> subagents cannot use clarify\n\n"
|
||||
"IMPORTANT:\n"
|
||||
"- Subagents have NO memory of your conversation. Pass all relevant "
|
||||
"info (file paths, error messages, constraints) via the 'context' field.\n"
|
||||
"- Subagents CANNOT call: delegate_task, clarify, memory, send_message, "
|
||||
"execute_code.\n"
|
||||
"- Each subagent gets its own terminal session (separate working directory and state).\n"
|
||||
"- Results are always returned as an array, one entry per task."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"goal": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"What the subagent should accomplish. Be specific and "
|
||||
"self-contained -- the subagent knows nothing about your "
|
||||
"conversation history."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Background information the subagent needs: file paths, "
|
||||
"error messages, project structure, constraints. The more "
|
||||
"specific you are, the better the subagent performs."
|
||||
),
|
||||
},
|
||||
"toolsets": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"Toolsets to enable for this subagent. "
|
||||
"Default: inherits your enabled toolsets. "
|
||||
"Common patterns: ['terminal', 'file'] for code work, "
|
||||
"['web'] for research, ['terminal', 'file', 'web'] for "
|
||||
"full-stack tasks."
|
||||
),
|
||||
},
|
||||
"tasks": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"goal": {"type": "string", "description": "Task goal"},
|
||||
"context": {"type": "string", "description": "Task-specific context"},
|
||||
"toolsets": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Toolsets for this specific task",
|
||||
},
|
||||
},
|
||||
"required": ["goal"],
|
||||
},
|
||||
"maxItems": 3,
|
||||
"description": (
|
||||
"Batch mode: up to 3 tasks to run in parallel. Each gets "
|
||||
"its own subagent with isolated context and terminal session. "
|
||||
"When provided, top-level goal/context/toolsets are ignored."
|
||||
),
|
||||
},
|
||||
"max_iterations": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Max tool-calling turns per subagent (default: 50). "
|
||||
"Only set lower for simple tasks."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# --- Registry ---
|
||||
from tools.registry import registry
|
||||
|
||||
registry.register(
|
||||
name="delegate_task",
|
||||
toolset="delegation",
|
||||
schema=DELEGATE_TASK_SCHEMA,
|
||||
handler=lambda args, **kw: delegate_task(
|
||||
goal=args.get("goal"),
|
||||
context=args.get("context"),
|
||||
toolsets=args.get("toolsets"),
|
||||
tasks=args.get("tasks"),
|
||||
max_iterations=args.get("max_iterations"),
|
||||
parent_agent=kw.get("parent_agent")),
|
||||
check_fn=check_delegate_requirements,
|
||||
emoji="🔀",
|
||||
)
|
||||
99
hermes_code/tools/env_passthrough.py
Normal file
99
hermes_code/tools/env_passthrough.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
"""Environment variable passthrough registry.
|
||||
|
||||
Skills that declare ``required_environment_variables`` in their frontmatter
|
||||
need those vars available in sandboxed execution environments (execute_code,
|
||||
terminal). By default both sandboxes strip secrets from the child process
|
||||
environment for security. This module provides a session-scoped allowlist
|
||||
so skill-declared vars (and user-configured overrides) pass through.
|
||||
|
||||
Two sources feed the allowlist:
|
||||
|
||||
1. **Skill declarations** — when a skill is loaded via ``skill_view``, its
|
||||
``required_environment_variables`` are registered here automatically.
|
||||
2. **User config** — ``terminal.env_passthrough`` in config.yaml lets users
|
||||
explicitly allowlist vars for non-skill use cases.
|
||||
|
||||
Both ``code_execution_tool.py`` and ``tools/environments/local.py`` consult
|
||||
:func:`is_env_passthrough` before stripping a variable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Session-scoped set of env var names that should pass through to sandboxes.
|
||||
_allowed_env_vars: set[str] = set()
|
||||
|
||||
# Cache for the config-based allowlist (loaded once per process).
|
||||
_config_passthrough: frozenset[str] | None = None
|
||||
|
||||
|
||||
def register_env_passthrough(var_names: Iterable[str]) -> None:
|
||||
"""Register environment variable names as allowed in sandboxed environments.
|
||||
|
||||
Typically called when a skill declares ``required_environment_variables``.
|
||||
"""
|
||||
for name in var_names:
|
||||
name = name.strip()
|
||||
if name:
|
||||
_allowed_env_vars.add(name)
|
||||
logger.debug("env passthrough: registered %s", name)
|
||||
|
||||
|
||||
def _load_config_passthrough() -> frozenset[str]:
|
||||
"""Load ``tools.env_passthrough`` from config.yaml (cached)."""
|
||||
global _config_passthrough
|
||||
if _config_passthrough is not None:
|
||||
return _config_passthrough
|
||||
|
||||
result: set[str] = set()
|
||||
try:
|
||||
hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
config_path = hermes_home / "config.yaml"
|
||||
if config_path.exists():
|
||||
import yaml
|
||||
|
||||
with open(config_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
passthrough = cfg.get("terminal", {}).get("env_passthrough")
|
||||
if isinstance(passthrough, list):
|
||||
for item in passthrough:
|
||||
if isinstance(item, str) and item.strip():
|
||||
result.add(item.strip())
|
||||
except Exception as e:
|
||||
logger.debug("Could not read tools.env_passthrough from config: %s", e)
|
||||
|
||||
_config_passthrough = frozenset(result)
|
||||
return _config_passthrough
|
||||
|
||||
|
||||
def is_env_passthrough(var_name: str) -> bool:
|
||||
"""Check whether *var_name* is allowed to pass through to sandboxes.
|
||||
|
||||
Returns ``True`` if the variable was registered by a skill or listed in
|
||||
the user's ``tools.env_passthrough`` config.
|
||||
"""
|
||||
if var_name in _allowed_env_vars:
|
||||
return True
|
||||
return var_name in _load_config_passthrough()
|
||||
|
||||
|
||||
def get_all_passthrough() -> frozenset[str]:
|
||||
"""Return the union of skill-registered and config-based passthrough vars."""
|
||||
return frozenset(_allowed_env_vars) | _load_config_passthrough()
|
||||
|
||||
|
||||
def clear_env_passthrough() -> None:
|
||||
"""Reset the skill-scoped allowlist (e.g. on session reset)."""
|
||||
_allowed_env_vars.clear()
|
||||
|
||||
|
||||
def reset_config_cache() -> None:
|
||||
"""Force re-read of config on next access (for testing)."""
|
||||
global _config_passthrough
|
||||
_config_passthrough = None
|
||||
13
hermes_code/tools/environments/__init__.py
Normal file
13
hermes_code/tools/environments/__init__.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
"""Hermes execution environment backends.
|
||||
|
||||
Each backend provides the same interface (BaseEnvironment ABC) for running
|
||||
shell commands in a specific execution context: local, Docker, Singularity,
|
||||
SSH, Modal, or Daytona.
|
||||
|
||||
The terminal_tool.py factory (_create_environment) selects the backend
|
||||
based on the TERMINAL_ENV configuration.
|
||||
"""
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
|
||||
__all__ = ["BaseEnvironment"]
|
||||
99
hermes_code/tools/environments/base.py
Normal file
99
hermes_code/tools/environments/base.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
"""Base class for all Hermes execution environment backends."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from hermes_cli.config import get_hermes_home
|
||||
|
||||
|
||||
def get_sandbox_dir() -> Path:
|
||||
"""Return the host-side root for all sandbox storage (Docker workspaces,
|
||||
Singularity overlays/SIF cache, etc.).
|
||||
|
||||
Configurable via TERMINAL_SANDBOX_DIR. Defaults to {HERMES_HOME}/sandboxes/.
|
||||
"""
|
||||
custom = os.getenv("TERMINAL_SANDBOX_DIR")
|
||||
if custom:
|
||||
p = Path(custom)
|
||||
else:
|
||||
p = get_hermes_home() / "sandboxes"
|
||||
p.mkdir(parents=True, exist_ok=True)
|
||||
return p
|
||||
|
||||
|
||||
class BaseEnvironment(ABC):
|
||||
"""Common interface for all Hermes execution backends.
|
||||
|
||||
Subclasses implement execute() and cleanup(). Shared helpers eliminate
|
||||
duplicated subprocess boilerplate across backends.
|
||||
"""
|
||||
|
||||
def __init__(self, cwd: str, timeout: int, env: dict = None):
|
||||
self.cwd = cwd
|
||||
self.timeout = timeout
|
||||
self.env = env or {}
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
"""Execute a command, return {"output": str, "returncode": int}."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def cleanup(self):
|
||||
"""Release backend resources (container, instance, connection)."""
|
||||
...
|
||||
|
||||
def stop(self):
|
||||
"""Alias for cleanup (compat with older callers)."""
|
||||
self.cleanup()
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
self.cleanup()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Shared helpers (eliminate duplication across backends)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _prepare_command(self, command: str) -> tuple[str, str | None]:
|
||||
"""Transform sudo commands if SUDO_PASSWORD is available.
|
||||
|
||||
Returns:
|
||||
(transformed_command, sudo_stdin) — see _transform_sudo_command
|
||||
for the full contract. Callers that drive a subprocess directly
|
||||
should prepend sudo_stdin (when not None) to any stdin_data they
|
||||
pass to Popen. Callers that embed stdin via heredoc (modal,
|
||||
daytona) handle sudo_stdin in their own execute() method.
|
||||
"""
|
||||
from tools.terminal_tool import _transform_sudo_command
|
||||
return _transform_sudo_command(command)
|
||||
|
||||
def _build_run_kwargs(self, timeout: int | None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
"""Build common subprocess.run kwargs for non-interactive execution."""
|
||||
kw = {
|
||||
"text": True,
|
||||
"timeout": timeout or self.timeout,
|
||||
"encoding": "utf-8",
|
||||
"errors": "replace",
|
||||
"stdout": subprocess.PIPE,
|
||||
"stderr": subprocess.STDOUT,
|
||||
}
|
||||
if stdin_data is not None:
|
||||
kw["input"] = stdin_data
|
||||
else:
|
||||
kw["stdin"] = subprocess.DEVNULL
|
||||
return kw
|
||||
|
||||
def _timeout_result(self, timeout: int | None) -> dict:
|
||||
"""Standard return dict when a command times out."""
|
||||
return {
|
||||
"output": f"Command timed out after {timeout or self.timeout}s",
|
||||
"returncode": 124,
|
||||
}
|
||||
250
hermes_code/tools/environments/daytona.py
Normal file
250
hermes_code/tools/environments/daytona.py
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
"""Daytona cloud execution environment.
|
||||
|
||||
Uses the Daytona Python SDK to run commands in cloud sandboxes.
|
||||
Supports persistent sandboxes: when enabled, sandboxes are stopped on cleanup
|
||||
and resumed on next creation, preserving the filesystem across sessions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import math
|
||||
import shlex
|
||||
import threading
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DaytonaEnvironment(BaseEnvironment):
|
||||
"""Daytona cloud sandbox execution backend.
|
||||
|
||||
Uses stopped/started sandbox lifecycle for filesystem persistence
|
||||
instead of snapshots, making it faster and stateless on the host.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image: str,
|
||||
cwd: str = "/home/daytona",
|
||||
timeout: int = 60,
|
||||
cpu: int = 1,
|
||||
memory: int = 5120, # MB (hermes convention)
|
||||
disk: int = 10240, # MB (Daytona platform max is 10GB)
|
||||
persistent_filesystem: bool = True,
|
||||
task_id: str = "default",
|
||||
):
|
||||
self._requested_cwd = cwd
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
|
||||
from daytona import (
|
||||
Daytona,
|
||||
CreateSandboxFromImageParams,
|
||||
DaytonaError,
|
||||
Resources,
|
||||
SandboxState,
|
||||
)
|
||||
|
||||
self._persistent = persistent_filesystem
|
||||
self._task_id = task_id
|
||||
self._SandboxState = SandboxState
|
||||
self._daytona = Daytona()
|
||||
self._sandbox = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
memory_gib = max(1, math.ceil(memory / 1024))
|
||||
disk_gib = max(1, math.ceil(disk / 1024))
|
||||
if disk_gib > 10:
|
||||
warnings.warn(
|
||||
f"Daytona: requested disk ({disk_gib}GB) exceeds platform limit (10GB). "
|
||||
f"Capping to 10GB. Set container_disk: 10240 in config to silence this.",
|
||||
stacklevel=2,
|
||||
)
|
||||
disk_gib = 10
|
||||
resources = Resources(cpu=cpu, memory=memory_gib, disk=disk_gib)
|
||||
|
||||
labels = {"hermes_task_id": task_id}
|
||||
sandbox_name = f"hermes-{task_id}"
|
||||
|
||||
# Try to resume an existing sandbox for this task
|
||||
if self._persistent:
|
||||
# 1. Try name-based lookup (new path)
|
||||
try:
|
||||
self._sandbox = self._daytona.get(sandbox_name)
|
||||
self._sandbox.start()
|
||||
logger.info("Daytona: resumed sandbox %s for task %s",
|
||||
self._sandbox.id, task_id)
|
||||
except DaytonaError:
|
||||
self._sandbox = None
|
||||
except Exception as e:
|
||||
logger.warning("Daytona: failed to resume sandbox for task %s: %s",
|
||||
task_id, e)
|
||||
self._sandbox = None
|
||||
|
||||
# 2. Legacy fallback: find sandbox created before the naming migration
|
||||
if self._sandbox is None:
|
||||
try:
|
||||
page = self._daytona.list(labels=labels, page=1, limit=1)
|
||||
if page.items:
|
||||
self._sandbox = page.items[0]
|
||||
self._sandbox.start()
|
||||
logger.info("Daytona: resumed legacy sandbox %s for task %s",
|
||||
self._sandbox.id, task_id)
|
||||
except Exception as e:
|
||||
logger.debug("Daytona: no legacy sandbox found for task %s: %s",
|
||||
task_id, e)
|
||||
self._sandbox = None
|
||||
|
||||
# Create a fresh sandbox if we don't have one
|
||||
if self._sandbox is None:
|
||||
self._sandbox = self._daytona.create(
|
||||
CreateSandboxFromImageParams(
|
||||
image=image,
|
||||
name=sandbox_name,
|
||||
labels=labels,
|
||||
auto_stop_interval=0,
|
||||
resources=resources,
|
||||
)
|
||||
)
|
||||
logger.info("Daytona: created sandbox %s for task %s",
|
||||
self._sandbox.id, task_id)
|
||||
|
||||
# Resolve cwd: detect actual home dir inside the sandbox
|
||||
if self._requested_cwd in ("~", "/home/daytona"):
|
||||
try:
|
||||
home = self._sandbox.process.exec("echo $HOME").result.strip()
|
||||
if home:
|
||||
self.cwd = home
|
||||
except Exception:
|
||||
pass # leave cwd as-is; sandbox will use its own default
|
||||
logger.info("Daytona: resolved cwd to %s", self.cwd)
|
||||
|
||||
def _ensure_sandbox_ready(self):
|
||||
"""Restart sandbox if it was stopped (e.g., by a previous interrupt)."""
|
||||
self._sandbox.refresh_data()
|
||||
if self._sandbox.state in (self._SandboxState.STOPPED, self._SandboxState.ARCHIVED):
|
||||
self._sandbox.start()
|
||||
logger.info("Daytona: restarted sandbox %s", self._sandbox.id)
|
||||
|
||||
def _exec_in_thread(self, exec_command: str, cwd: Optional[str], timeout: int) -> dict:
|
||||
"""Run exec in a background thread with interrupt polling.
|
||||
|
||||
The Daytona SDK's exec(timeout=...) parameter is unreliable (the
|
||||
server-side timeout is not enforced and the SDK has no client-side
|
||||
fallback), so we wrap the command with the shell ``timeout`` utility
|
||||
which reliably kills the process and returns exit code 124.
|
||||
"""
|
||||
# Wrap with shell `timeout` to enforce the deadline reliably.
|
||||
# Add a small buffer so the shell timeout fires before any SDK-level
|
||||
# timeout would, giving us a clean exit code 124.
|
||||
timed_command = f"timeout {timeout} sh -c {shlex.quote(exec_command)}"
|
||||
|
||||
result_holder: dict = {"value": None, "error": None}
|
||||
|
||||
def _run():
|
||||
try:
|
||||
response = self._sandbox.process.exec(
|
||||
timed_command, cwd=cwd,
|
||||
)
|
||||
result_holder["value"] = {
|
||||
"output": response.result or "",
|
||||
"returncode": response.exit_code,
|
||||
}
|
||||
except Exception as e:
|
||||
result_holder["error"] = e
|
||||
|
||||
t = threading.Thread(target=_run, daemon=True)
|
||||
t.start()
|
||||
# Wait for timeout + generous buffer for network/SDK overhead
|
||||
deadline = time.monotonic() + timeout + 10
|
||||
while t.is_alive():
|
||||
t.join(timeout=0.2)
|
||||
if is_interrupted():
|
||||
with self._lock:
|
||||
try:
|
||||
self._sandbox.stop()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"output": "[Command interrupted - Daytona sandbox stopped]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if time.monotonic() > deadline:
|
||||
# Shell timeout didn't fire and SDK is hung — force stop
|
||||
with self._lock:
|
||||
try:
|
||||
self._sandbox.stop()
|
||||
except Exception:
|
||||
pass
|
||||
return self._timeout_result(timeout)
|
||||
|
||||
if result_holder["error"]:
|
||||
return {"error": result_holder["error"]}
|
||||
return result_holder["value"]
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: Optional[int] = None,
|
||||
stdin_data: Optional[str] = None) -> dict:
|
||||
with self._lock:
|
||||
self._ensure_sandbox_ready()
|
||||
|
||||
if stdin_data is not None:
|
||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
||||
while marker in stdin_data:
|
||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
||||
command = f"{command} << '{marker}'\n{stdin_data}\n{marker}"
|
||||
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
|
||||
# Daytona sandboxes execute commands via the Daytona SDK and cannot
|
||||
# pipe subprocess stdin directly the way a local Popen can. When a
|
||||
# sudo password is present, use a shell-level pipe from printf so that
|
||||
# the password feeds sudo -S without appearing as an echo argument
|
||||
# embedded in the shell string. The password is still visible in the
|
||||
# remote sandbox's command line, but it is not exposed on the user's
|
||||
# local machine — which is the primary threat being mitigated.
|
||||
if sudo_stdin is not None:
|
||||
import shlex
|
||||
exec_command = (
|
||||
f"printf '%s\\n' {shlex.quote(sudo_stdin.rstrip())} | {exec_command}"
|
||||
)
|
||||
effective_cwd = cwd or self.cwd or None
|
||||
effective_timeout = timeout or self.timeout
|
||||
|
||||
result = self._exec_in_thread(exec_command, effective_cwd, effective_timeout)
|
||||
|
||||
if "error" in result:
|
||||
from daytona import DaytonaError
|
||||
err = result["error"]
|
||||
if isinstance(err, DaytonaError):
|
||||
with self._lock:
|
||||
try:
|
||||
self._ensure_sandbox_ready()
|
||||
except Exception:
|
||||
return {"output": f"Daytona execution error: {err}", "returncode": 1}
|
||||
result = self._exec_in_thread(exec_command, effective_cwd, effective_timeout)
|
||||
if "error" not in result:
|
||||
return result
|
||||
return {"output": f"Daytona execution error: {err}", "returncode": 1}
|
||||
|
||||
return result
|
||||
|
||||
def cleanup(self):
|
||||
with self._lock:
|
||||
if self._sandbox is None:
|
||||
return
|
||||
try:
|
||||
if self._persistent:
|
||||
self._sandbox.stop()
|
||||
logger.info("Daytona: stopped sandbox %s (filesystem preserved)",
|
||||
self._sandbox.id)
|
||||
else:
|
||||
self._daytona.delete(self._sandbox)
|
||||
logger.info("Daytona: deleted sandbox %s", self._sandbox.id)
|
||||
except Exception as e:
|
||||
logger.warning("Daytona: cleanup failed: %s", e)
|
||||
self._sandbox = None
|
||||
494
hermes_code/tools/environments/docker.py
Normal file
494
hermes_code/tools/environments/docker.py
Normal file
|
|
@ -0,0 +1,494 @@
|
|||
"""Docker execution environment for sandboxed command execution.
|
||||
|
||||
Security hardened (cap-drop ALL, no-new-privileges, PID limits),
|
||||
configurable resource limits (CPU, memory, disk), and optional filesystem
|
||||
persistence via bind mounts.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Common Docker Desktop install paths checked when 'docker' is not in PATH.
|
||||
# macOS Intel: /usr/local/bin, macOS Apple Silicon (Homebrew): /opt/homebrew/bin,
|
||||
# Docker Desktop app bundle: /Applications/Docker.app/Contents/Resources/bin
|
||||
_DOCKER_SEARCH_PATHS = [
|
||||
"/usr/local/bin/docker",
|
||||
"/opt/homebrew/bin/docker",
|
||||
"/Applications/Docker.app/Contents/Resources/bin/docker",
|
||||
]
|
||||
|
||||
_docker_executable: Optional[str] = None # resolved once, cached
|
||||
_ENV_VAR_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
|
||||
|
||||
def _normalize_forward_env_names(forward_env: list[str] | None) -> list[str]:
|
||||
"""Return a deduplicated list of valid environment variable names."""
|
||||
normalized: list[str] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
for item in forward_env or []:
|
||||
if not isinstance(item, str):
|
||||
logger.warning("Ignoring non-string docker_forward_env entry: %r", item)
|
||||
continue
|
||||
|
||||
key = item.strip()
|
||||
if not key:
|
||||
continue
|
||||
if not _ENV_VAR_NAME_RE.match(key):
|
||||
logger.warning("Ignoring invalid docker_forward_env entry: %r", item)
|
||||
continue
|
||||
if key in seen:
|
||||
continue
|
||||
|
||||
seen.add(key)
|
||||
normalized.append(key)
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def _load_hermes_env_vars() -> dict[str, str]:
|
||||
"""Load ~/.hermes/.env values without failing Docker command execution."""
|
||||
try:
|
||||
from hermes_cli.config import load_env
|
||||
|
||||
return load_env() or {}
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def find_docker() -> Optional[str]:
|
||||
"""Locate the docker CLI binary.
|
||||
|
||||
Checks ``shutil.which`` first (respects PATH), then probes well-known
|
||||
install locations on macOS where Docker Desktop may not be in PATH
|
||||
(e.g. when running as a gateway service via launchd).
|
||||
|
||||
Returns the absolute path, or ``None`` if docker cannot be found.
|
||||
"""
|
||||
global _docker_executable
|
||||
if _docker_executable is not None:
|
||||
return _docker_executable
|
||||
|
||||
found = shutil.which("docker")
|
||||
if found:
|
||||
_docker_executable = found
|
||||
return found
|
||||
|
||||
for path in _DOCKER_SEARCH_PATHS:
|
||||
if os.path.isfile(path) and os.access(path, os.X_OK):
|
||||
_docker_executable = path
|
||||
logger.info("Found docker at non-PATH location: %s", path)
|
||||
return path
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Security flags applied to every container.
|
||||
# The container itself is the security boundary (isolated from host).
|
||||
# We drop all capabilities then add back the minimum needed:
|
||||
# DAC_OVERRIDE - root can write to bind-mounted dirs owned by host user
|
||||
# CHOWN/FOWNER - package managers (pip, npm, apt) need to set file ownership
|
||||
# Block privilege escalation and limit PIDs.
|
||||
# /tmp is size-limited and nosuid but allows exec (needed by pip/npm builds).
|
||||
_SECURITY_ARGS = [
|
||||
"--cap-drop", "ALL",
|
||||
"--cap-add", "DAC_OVERRIDE",
|
||||
"--cap-add", "CHOWN",
|
||||
"--cap-add", "FOWNER",
|
||||
"--security-opt", "no-new-privileges",
|
||||
"--pids-limit", "256",
|
||||
"--tmpfs", "/tmp:rw,nosuid,size=512m",
|
||||
"--tmpfs", "/var/tmp:rw,noexec,nosuid,size=256m",
|
||||
"--tmpfs", "/run:rw,noexec,nosuid,size=64m",
|
||||
]
|
||||
|
||||
|
||||
_storage_opt_ok: Optional[bool] = None # cached result across instances
|
||||
|
||||
|
||||
def _ensure_docker_available() -> None:
|
||||
"""Best-effort check that the docker CLI is available before use.
|
||||
|
||||
Reuses ``find_docker()`` so this preflight stays consistent with the rest of
|
||||
the Docker backend, including known non-PATH Docker Desktop locations.
|
||||
"""
|
||||
docker_exe = find_docker()
|
||||
if not docker_exe:
|
||||
logger.error(
|
||||
"Docker backend selected but no docker executable was found in PATH "
|
||||
"or known install locations. Install Docker Desktop and ensure the "
|
||||
"CLI is available."
|
||||
)
|
||||
raise RuntimeError(
|
||||
"Docker executable not found in PATH or known install locations. "
|
||||
"Install Docker and ensure the 'docker' command is available."
|
||||
)
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[docker_exe, "version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"Docker backend selected but the resolved docker executable '%s' could "
|
||||
"not be executed.",
|
||||
docker_exe,
|
||||
exc_info=True,
|
||||
)
|
||||
raise RuntimeError(
|
||||
"Docker executable could not be executed. Check your Docker installation."
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error(
|
||||
"Docker backend selected but '%s version' timed out. "
|
||||
"The Docker daemon may not be running.",
|
||||
docker_exe,
|
||||
exc_info=True,
|
||||
)
|
||||
raise RuntimeError(
|
||||
"Docker daemon is not responding. Ensure Docker is running and try again."
|
||||
)
|
||||
except Exception:
|
||||
logger.error(
|
||||
"Unexpected error while checking Docker availability.",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
else:
|
||||
if result.returncode != 0:
|
||||
logger.error(
|
||||
"Docker backend selected but '%s version' failed "
|
||||
"(exit code %d, stderr=%s)",
|
||||
docker_exe,
|
||||
result.returncode,
|
||||
result.stderr.strip(),
|
||||
)
|
||||
raise RuntimeError(
|
||||
"Docker command is available but 'docker version' failed. "
|
||||
"Check your Docker installation."
|
||||
)
|
||||
|
||||
|
||||
class DockerEnvironment(BaseEnvironment):
|
||||
"""Hardened Docker container execution with resource limits and persistence.
|
||||
|
||||
Security: all capabilities dropped, no privilege escalation, PID limits,
|
||||
size-limited tmpfs for scratch dirs. The container itself is the security
|
||||
boundary — the filesystem inside is writable so agents can install packages
|
||||
(pip, npm, apt) as needed. Writable workspace via tmpfs or bind mounts.
|
||||
|
||||
Persistence: when enabled, bind mounts preserve /workspace and /root
|
||||
across container restarts.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image: str,
|
||||
cwd: str = "/root",
|
||||
timeout: int = 60,
|
||||
cpu: float = 0,
|
||||
memory: int = 0,
|
||||
disk: int = 0,
|
||||
persistent_filesystem: bool = False,
|
||||
task_id: str = "default",
|
||||
volumes: list = None,
|
||||
forward_env: list[str] | None = None,
|
||||
network: bool = True,
|
||||
host_cwd: str = None,
|
||||
auto_mount_cwd: bool = False,
|
||||
):
|
||||
if cwd == "~":
|
||||
cwd = "/root"
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
self._base_image = image
|
||||
self._persistent = persistent_filesystem
|
||||
self._task_id = task_id
|
||||
self._forward_env = _normalize_forward_env_names(forward_env)
|
||||
self._container_id: Optional[str] = None
|
||||
logger.info(f"DockerEnvironment volumes: {volumes}")
|
||||
# Ensure volumes is a list (config.yaml could be malformed)
|
||||
if volumes is not None and not isinstance(volumes, list):
|
||||
logger.warning(f"docker_volumes config is not a list: {volumes!r}")
|
||||
volumes = []
|
||||
|
||||
# Fail fast if Docker is not available.
|
||||
_ensure_docker_available()
|
||||
|
||||
# Build resource limit args
|
||||
resource_args = []
|
||||
if cpu > 0:
|
||||
resource_args.extend(["--cpus", str(cpu)])
|
||||
if memory > 0:
|
||||
resource_args.extend(["--memory", f"{memory}m"])
|
||||
if disk > 0 and sys.platform != "darwin":
|
||||
if self._storage_opt_supported():
|
||||
resource_args.extend(["--storage-opt", f"size={disk}m"])
|
||||
else:
|
||||
logger.warning(
|
||||
"Docker storage driver does not support per-container disk limits "
|
||||
"(requires overlay2 on XFS with pquota). Container will run without disk quota."
|
||||
)
|
||||
if not network:
|
||||
resource_args.append("--network=none")
|
||||
|
||||
# Persistent workspace via bind mounts from a configurable host directory
|
||||
# (TERMINAL_SANDBOX_DIR, default ~/.hermes/sandboxes/). Non-persistent
|
||||
# mode uses tmpfs (ephemeral, fast, gone on cleanup).
|
||||
from tools.environments.base import get_sandbox_dir
|
||||
|
||||
# User-configured volume mounts (from config.yaml docker_volumes)
|
||||
volume_args = []
|
||||
workspace_explicitly_mounted = False
|
||||
for vol in (volumes or []):
|
||||
if not isinstance(vol, str):
|
||||
logger.warning(f"Docker volume entry is not a string: {vol!r}")
|
||||
continue
|
||||
vol = vol.strip()
|
||||
if not vol:
|
||||
continue
|
||||
if ":" in vol:
|
||||
volume_args.extend(["-v", vol])
|
||||
if ":/workspace" in vol:
|
||||
workspace_explicitly_mounted = True
|
||||
else:
|
||||
logger.warning(f"Docker volume '{vol}' missing colon, skipping")
|
||||
|
||||
host_cwd_abs = os.path.abspath(os.path.expanduser(host_cwd)) if host_cwd else ""
|
||||
bind_host_cwd = (
|
||||
auto_mount_cwd
|
||||
and bool(host_cwd_abs)
|
||||
and os.path.isdir(host_cwd_abs)
|
||||
and not workspace_explicitly_mounted
|
||||
)
|
||||
if auto_mount_cwd and host_cwd and not os.path.isdir(host_cwd_abs):
|
||||
logger.debug(f"Skipping docker cwd mount: host_cwd is not a valid directory: {host_cwd}")
|
||||
|
||||
self._workspace_dir: Optional[str] = None
|
||||
self._home_dir: Optional[str] = None
|
||||
writable_args = []
|
||||
if self._persistent:
|
||||
sandbox = get_sandbox_dir() / "docker" / task_id
|
||||
self._home_dir = str(sandbox / "home")
|
||||
os.makedirs(self._home_dir, exist_ok=True)
|
||||
writable_args.extend([
|
||||
"-v", f"{self._home_dir}:/root",
|
||||
])
|
||||
if not bind_host_cwd and not workspace_explicitly_mounted:
|
||||
self._workspace_dir = str(sandbox / "workspace")
|
||||
os.makedirs(self._workspace_dir, exist_ok=True)
|
||||
writable_args.extend([
|
||||
"-v", f"{self._workspace_dir}:/workspace",
|
||||
])
|
||||
else:
|
||||
if not bind_host_cwd and not workspace_explicitly_mounted:
|
||||
writable_args.extend([
|
||||
"--tmpfs", "/workspace:rw,exec,size=10g",
|
||||
])
|
||||
writable_args.extend([
|
||||
"--tmpfs", "/home:rw,exec,size=1g",
|
||||
"--tmpfs", "/root:rw,exec,size=1g",
|
||||
])
|
||||
|
||||
if bind_host_cwd:
|
||||
logger.info(f"Mounting configured host cwd to /workspace: {host_cwd_abs}")
|
||||
volume_args = ["-v", f"{host_cwd_abs}:/workspace", *volume_args]
|
||||
elif workspace_explicitly_mounted:
|
||||
logger.debug("Skipping docker cwd mount: /workspace already mounted by user config")
|
||||
|
||||
logger.info(f"Docker volume_args: {volume_args}")
|
||||
all_run_args = list(_SECURITY_ARGS) + writable_args + resource_args + volume_args
|
||||
logger.info(f"Docker run_args: {all_run_args}")
|
||||
|
||||
# Resolve the docker executable once so it works even when
|
||||
# /usr/local/bin is not in PATH (common on macOS gateway/service).
|
||||
self._docker_exe = find_docker() or "docker"
|
||||
|
||||
# Start the container directly via `docker run -d`.
|
||||
container_name = f"hermes-{uuid.uuid4().hex[:8]}"
|
||||
run_cmd = [
|
||||
self._docker_exe, "run", "-d",
|
||||
"--name", container_name,
|
||||
"-w", cwd,
|
||||
*all_run_args,
|
||||
image,
|
||||
"sleep", "2h",
|
||||
]
|
||||
logger.debug(f"Starting container: {' '.join(run_cmd)}")
|
||||
result = subprocess.run(
|
||||
run_cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120, # image pull may take a while
|
||||
check=True,
|
||||
)
|
||||
self._container_id = result.stdout.strip()
|
||||
logger.info(f"Started container {container_name} ({self._container_id[:12]})")
|
||||
|
||||
@staticmethod
|
||||
def _storage_opt_supported() -> bool:
|
||||
"""Check if Docker's storage driver supports --storage-opt size=.
|
||||
|
||||
Only overlay2 on XFS with pquota supports per-container disk quotas.
|
||||
Ubuntu (and most distros) default to ext4, where this flag errors out.
|
||||
"""
|
||||
global _storage_opt_ok
|
||||
if _storage_opt_ok is not None:
|
||||
return _storage_opt_ok
|
||||
try:
|
||||
docker = find_docker() or "docker"
|
||||
result = subprocess.run(
|
||||
[docker, "info", "--format", "{{.Driver}}"],
|
||||
capture_output=True, text=True, timeout=10,
|
||||
)
|
||||
driver = result.stdout.strip().lower()
|
||||
if driver != "overlay2":
|
||||
_storage_opt_ok = False
|
||||
return False
|
||||
# overlay2 only supports storage-opt on XFS with pquota.
|
||||
# Probe by attempting a dry-ish run — the fastest reliable check.
|
||||
probe = subprocess.run(
|
||||
[docker, "create", "--storage-opt", "size=1m", "hello-world"],
|
||||
capture_output=True, text=True, timeout=15,
|
||||
)
|
||||
if probe.returncode == 0:
|
||||
# Clean up the created container
|
||||
container_id = probe.stdout.strip()
|
||||
if container_id:
|
||||
subprocess.run([docker, "rm", container_id],
|
||||
capture_output=True, timeout=5)
|
||||
_storage_opt_ok = True
|
||||
else:
|
||||
_storage_opt_ok = False
|
||||
except Exception:
|
||||
_storage_opt_ok = False
|
||||
logger.debug("Docker --storage-opt support: %s", _storage_opt_ok)
|
||||
return _storage_opt_ok
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
work_dir = cwd or self.cwd
|
||||
effective_timeout = timeout or self.timeout
|
||||
|
||||
# Merge sudo password (if any) with caller-supplied stdin_data.
|
||||
if sudo_stdin is not None and stdin_data is not None:
|
||||
effective_stdin = sudo_stdin + stdin_data
|
||||
elif sudo_stdin is not None:
|
||||
effective_stdin = sudo_stdin
|
||||
else:
|
||||
effective_stdin = stdin_data
|
||||
|
||||
# docker exec -w doesn't expand ~, so prepend a cd into the command
|
||||
if work_dir == "~" or work_dir.startswith("~/"):
|
||||
exec_command = f"cd {work_dir} && {exec_command}"
|
||||
work_dir = "/"
|
||||
|
||||
assert self._container_id, "Container not started"
|
||||
cmd = [self._docker_exe, "exec"]
|
||||
if effective_stdin is not None:
|
||||
cmd.append("-i")
|
||||
cmd.extend(["-w", work_dir])
|
||||
hermes_env = _load_hermes_env_vars() if self._forward_env else {}
|
||||
for key in self._forward_env:
|
||||
value = os.getenv(key)
|
||||
if value is None:
|
||||
value = hermes_env.get(key)
|
||||
if value is not None:
|
||||
cmd.extend(["-e", f"{key}={value}"])
|
||||
cmd.extend([self._container_id, "bash", "-lc", exec_command])
|
||||
|
||||
try:
|
||||
_output_chunks = []
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
if effective_stdin:
|
||||
try:
|
||||
proc.stdin.write(effective_stdin)
|
||||
proc.stdin.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _drain():
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
_output_chunks.append(line)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
reader = threading.Thread(target=_drain, daemon=True)
|
||||
reader.start()
|
||||
deadline = time.monotonic() + effective_timeout
|
||||
|
||||
while proc.poll() is None:
|
||||
if is_interrupted():
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=1)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return {
|
||||
"output": "".join(_output_chunks) + "\n[Command interrupted]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if time.monotonic() > deadline:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return self._timeout_result(effective_timeout)
|
||||
time.sleep(0.2)
|
||||
|
||||
reader.join(timeout=5)
|
||||
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
|
||||
except Exception as e:
|
||||
return {"output": f"Docker execution error: {e}", "returncode": 1}
|
||||
|
||||
def cleanup(self):
|
||||
"""Stop and remove the container. Bind-mount dirs persist if persistent=True."""
|
||||
if self._container_id:
|
||||
try:
|
||||
# Stop in background so cleanup doesn't block
|
||||
stop_cmd = (
|
||||
f"(timeout 60 {self._docker_exe} stop {self._container_id} || "
|
||||
f"{self._docker_exe} rm -f {self._container_id}) >/dev/null 2>&1 &"
|
||||
)
|
||||
subprocess.Popen(stop_cmd, shell=True)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to stop container %s: %s", self._container_id, e)
|
||||
|
||||
if not self._persistent:
|
||||
# Also schedule removal (stop only leaves it as stopped)
|
||||
try:
|
||||
subprocess.Popen(
|
||||
f"sleep 3 && {self._docker_exe} rm -f {self._container_id} >/dev/null 2>&1 &",
|
||||
shell=True,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
self._container_id = None
|
||||
|
||||
if not self._persistent:
|
||||
for d in (self._workspace_dir, self._home_dir):
|
||||
if d:
|
||||
shutil.rmtree(d, ignore_errors=True)
|
||||
476
hermes_code/tools/environments/local.py
Normal file
476
hermes_code/tools/environments/local.py
Normal file
|
|
@ -0,0 +1,476 @@
|
|||
"""Local execution environment with interrupt support and non-blocking I/O."""
|
||||
|
||||
import glob
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.environments.persistent_shell import PersistentShellMixin
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
# Unique marker to isolate real command output from shell init/exit noise.
|
||||
# printf (no trailing newline) keeps the boundaries clean for splitting.
|
||||
_OUTPUT_FENCE = "__HERMES_FENCE_a9f7b3__"
|
||||
|
||||
# Hermes-internal env vars that should NOT leak into terminal subprocesses.
|
||||
# These are loaded from ~/.hermes/.env for Hermes' own LLM/provider calls
|
||||
# but can break external CLIs (e.g. codex) that also honor them.
|
||||
# See: https://github.com/NousResearch/hermes-agent/issues/1002
|
||||
#
|
||||
# Built dynamically from the provider registry so new providers are
|
||||
# automatically covered without manual blocklist maintenance.
|
||||
_HERMES_PROVIDER_ENV_FORCE_PREFIX = "_HERMES_FORCE_"
|
||||
|
||||
|
||||
def _build_provider_env_blocklist() -> frozenset:
|
||||
"""Derive the blocklist from provider, tool, and gateway config.
|
||||
|
||||
Automatically picks up api_key_env_vars and base_url_env_var from
|
||||
every registered provider, plus tool/messaging env vars from the
|
||||
optional config registry, so new Hermes-managed secrets are blocked
|
||||
in subprocesses without having to maintain multiple static lists.
|
||||
"""
|
||||
blocked: set[str] = set()
|
||||
|
||||
try:
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY
|
||||
for pconfig in PROVIDER_REGISTRY.values():
|
||||
blocked.update(pconfig.api_key_env_vars)
|
||||
if pconfig.base_url_env_var:
|
||||
blocked.add(pconfig.base_url_env_var)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from hermes_cli.config import OPTIONAL_ENV_VARS
|
||||
for name, metadata in OPTIONAL_ENV_VARS.items():
|
||||
category = metadata.get("category")
|
||||
if category in {"tool", "messaging"}:
|
||||
blocked.add(name)
|
||||
elif category == "setting" and metadata.get("password"):
|
||||
blocked.add(name)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Vars not covered above but still Hermes-internal / conflict-prone.
|
||||
blocked.update({
|
||||
"OPENAI_BASE_URL",
|
||||
"OPENAI_API_KEY",
|
||||
"OPENAI_API_BASE", # legacy alias
|
||||
"OPENAI_ORG_ID",
|
||||
"OPENAI_ORGANIZATION",
|
||||
"OPENROUTER_API_KEY",
|
||||
"ANTHROPIC_BASE_URL",
|
||||
"ANTHROPIC_TOKEN", # OAuth token (not in registry as env var)
|
||||
"CLAUDE_CODE_OAUTH_TOKEN",
|
||||
"LLM_MODEL",
|
||||
# Expanded isolation for other major providers (Issue #1002)
|
||||
"GOOGLE_API_KEY", # Gemini / Google AI Studio
|
||||
"DEEPSEEK_API_KEY", # DeepSeek
|
||||
"MISTRAL_API_KEY", # Mistral AI
|
||||
"GROQ_API_KEY", # Groq
|
||||
"TOGETHER_API_KEY", # Together AI
|
||||
"PERPLEXITY_API_KEY", # Perplexity
|
||||
"COHERE_API_KEY", # Cohere
|
||||
"FIREWORKS_API_KEY", # Fireworks AI
|
||||
"XAI_API_KEY", # xAI (Grok)
|
||||
"HELICONE_API_KEY", # LLM Observability proxy
|
||||
"PARALLEL_API_KEY",
|
||||
"FIRECRAWL_API_KEY",
|
||||
"FIRECRAWL_API_URL",
|
||||
# Gateway/runtime config not represented in OPTIONAL_ENV_VARS.
|
||||
"TELEGRAM_HOME_CHANNEL",
|
||||
"TELEGRAM_HOME_CHANNEL_NAME",
|
||||
"DISCORD_HOME_CHANNEL",
|
||||
"DISCORD_HOME_CHANNEL_NAME",
|
||||
"DISCORD_REQUIRE_MENTION",
|
||||
"DISCORD_FREE_RESPONSE_CHANNELS",
|
||||
"DISCORD_AUTO_THREAD",
|
||||
"SLACK_HOME_CHANNEL",
|
||||
"SLACK_HOME_CHANNEL_NAME",
|
||||
"SLACK_ALLOWED_USERS",
|
||||
"WHATSAPP_ENABLED",
|
||||
"WHATSAPP_MODE",
|
||||
"WHATSAPP_ALLOWED_USERS",
|
||||
"SIGNAL_HTTP_URL",
|
||||
"SIGNAL_ACCOUNT",
|
||||
"SIGNAL_ALLOWED_USERS",
|
||||
"SIGNAL_GROUP_ALLOWED_USERS",
|
||||
"SIGNAL_HOME_CHANNEL",
|
||||
"SIGNAL_HOME_CHANNEL_NAME",
|
||||
"SIGNAL_IGNORE_STORIES",
|
||||
"HASS_TOKEN",
|
||||
"HASS_URL",
|
||||
"EMAIL_ADDRESS",
|
||||
"EMAIL_PASSWORD",
|
||||
"EMAIL_IMAP_HOST",
|
||||
"EMAIL_SMTP_HOST",
|
||||
"EMAIL_HOME_ADDRESS",
|
||||
"EMAIL_HOME_ADDRESS_NAME",
|
||||
"GATEWAY_ALLOWED_USERS",
|
||||
# Skills Hub / GitHub app auth paths and aliases.
|
||||
"GH_TOKEN",
|
||||
"GITHUB_APP_ID",
|
||||
"GITHUB_APP_PRIVATE_KEY_PATH",
|
||||
"GITHUB_APP_INSTALLATION_ID",
|
||||
# Remote sandbox backend credentials.
|
||||
"MODAL_TOKEN_ID",
|
||||
"MODAL_TOKEN_SECRET",
|
||||
"DAYTONA_API_KEY",
|
||||
})
|
||||
return frozenset(blocked)
|
||||
|
||||
|
||||
_HERMES_PROVIDER_ENV_BLOCKLIST = _build_provider_env_blocklist()
|
||||
|
||||
|
||||
def _sanitize_subprocess_env(base_env: dict | None, extra_env: dict | None = None) -> dict:
|
||||
"""Filter Hermes-managed secrets from a subprocess environment.
|
||||
|
||||
`_HERMES_FORCE_<VAR>` entries in ``extra_env`` opt a blocked variable back in
|
||||
intentionally for callers that truly need it. Vars registered via
|
||||
:mod:`tools.env_passthrough` (skill-declared or user-configured) also
|
||||
bypass the blocklist.
|
||||
"""
|
||||
try:
|
||||
from tools.env_passthrough import is_env_passthrough as _is_passthrough
|
||||
except Exception:
|
||||
_is_passthrough = lambda _: False # noqa: E731
|
||||
|
||||
sanitized: dict[str, str] = {}
|
||||
|
||||
for key, value in (base_env or {}).items():
|
||||
if key.startswith(_HERMES_PROVIDER_ENV_FORCE_PREFIX):
|
||||
continue
|
||||
if key not in _HERMES_PROVIDER_ENV_BLOCKLIST or _is_passthrough(key):
|
||||
sanitized[key] = value
|
||||
|
||||
for key, value in (extra_env or {}).items():
|
||||
if key.startswith(_HERMES_PROVIDER_ENV_FORCE_PREFIX):
|
||||
real_key = key[len(_HERMES_PROVIDER_ENV_FORCE_PREFIX):]
|
||||
sanitized[real_key] = value
|
||||
elif key not in _HERMES_PROVIDER_ENV_BLOCKLIST or _is_passthrough(key):
|
||||
sanitized[key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def _find_bash() -> str:
|
||||
"""Find bash for command execution.
|
||||
|
||||
The fence wrapper uses bash syntax (semicolons, $?, printf), so we
|
||||
must use bash — not the user's $SHELL which could be fish/zsh/etc.
|
||||
On Windows: uses Git Bash (bundled with Git for Windows).
|
||||
"""
|
||||
if not _IS_WINDOWS:
|
||||
return (
|
||||
shutil.which("bash")
|
||||
or ("/usr/bin/bash" if os.path.isfile("/usr/bin/bash") else None)
|
||||
or ("/bin/bash" if os.path.isfile("/bin/bash") else None)
|
||||
or os.environ.get("SHELL") # last resort: whatever they have
|
||||
or "/bin/sh"
|
||||
)
|
||||
|
||||
# Windows: look for Git Bash (installed with Git for Windows).
|
||||
# Allow override via env var (same pattern as Claude Code).
|
||||
custom = os.environ.get("HERMES_GIT_BASH_PATH")
|
||||
if custom and os.path.isfile(custom):
|
||||
return custom
|
||||
|
||||
# shutil.which finds bash.exe if Git\bin is on PATH
|
||||
found = shutil.which("bash")
|
||||
if found:
|
||||
return found
|
||||
|
||||
# Check common Git for Windows install locations
|
||||
for candidate in (
|
||||
os.path.join(os.environ.get("ProgramFiles", r"C:\Program Files"), "Git", "bin", "bash.exe"),
|
||||
os.path.join(os.environ.get("ProgramFiles(x86)", r"C:\Program Files (x86)"), "Git", "bin", "bash.exe"),
|
||||
os.path.join(os.environ.get("LOCALAPPDATA", ""), "Programs", "Git", "bin", "bash.exe"),
|
||||
):
|
||||
if candidate and os.path.isfile(candidate):
|
||||
return candidate
|
||||
|
||||
raise RuntimeError(
|
||||
"Git Bash not found. Hermes Agent requires Git for Windows on Windows.\n"
|
||||
"Install it from: https://git-scm.com/download/win\n"
|
||||
"Or set HERMES_GIT_BASH_PATH to your bash.exe location."
|
||||
)
|
||||
|
||||
|
||||
# Backward compat — process_registry.py imports this name
|
||||
_find_shell = _find_bash
|
||||
|
||||
|
||||
# Noise lines emitted by interactive shells when stdin is not a terminal.
|
||||
# Used as a fallback when output fence markers are missing.
|
||||
_SHELL_NOISE_SUBSTRINGS = (
|
||||
# bash
|
||||
"bash: cannot set terminal process group",
|
||||
"bash: no job control in this shell",
|
||||
"no job control in this shell",
|
||||
"cannot set terminal process group",
|
||||
"tcsetattr: Inappropriate ioctl for device",
|
||||
# zsh / oh-my-zsh / macOS terminal session
|
||||
"Restored session:",
|
||||
"Saving session...",
|
||||
"Last login:",
|
||||
"command not found:",
|
||||
"Oh My Zsh",
|
||||
"compinit:",
|
||||
)
|
||||
|
||||
|
||||
def _clean_shell_noise(output: str) -> str:
|
||||
"""Strip shell startup/exit warnings that leak when using -i without a TTY.
|
||||
|
||||
Removes lines matching known noise patterns from both the beginning
|
||||
and end of the output. Lines in the middle are left untouched.
|
||||
"""
|
||||
|
||||
def _is_noise(line: str) -> bool:
|
||||
return any(noise in line for noise in _SHELL_NOISE_SUBSTRINGS)
|
||||
|
||||
lines = output.split("\n")
|
||||
|
||||
# Strip leading noise
|
||||
while lines and _is_noise(lines[0]):
|
||||
lines.pop(0)
|
||||
|
||||
# Strip trailing noise (walk backwards, skip empty lines from split)
|
||||
end = len(lines) - 1
|
||||
while end >= 0 and (not lines[end] or _is_noise(lines[end])):
|
||||
end -= 1
|
||||
|
||||
if end < 0:
|
||||
return ""
|
||||
|
||||
cleaned = lines[: end + 1]
|
||||
result = "\n".join(cleaned)
|
||||
|
||||
# Preserve trailing newline if original had one
|
||||
if output.endswith("\n") and result and not result.endswith("\n"):
|
||||
result += "\n"
|
||||
return result
|
||||
|
||||
|
||||
# Standard PATH entries for environments with minimal PATH (e.g. systemd services).
|
||||
# Includes macOS Homebrew paths (/opt/homebrew/* for Apple Silicon).
|
||||
_SANE_PATH = (
|
||||
"/opt/homebrew/bin:/opt/homebrew/sbin:"
|
||||
"/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
||||
)
|
||||
|
||||
|
||||
def _make_run_env(env: dict) -> dict:
|
||||
"""Build a run environment with a sane PATH and provider-var stripping."""
|
||||
try:
|
||||
from tools.env_passthrough import is_env_passthrough as _is_passthrough
|
||||
except Exception:
|
||||
_is_passthrough = lambda _: False # noqa: E731
|
||||
|
||||
merged = dict(os.environ | env)
|
||||
run_env = {}
|
||||
for k, v in merged.items():
|
||||
if k.startswith(_HERMES_PROVIDER_ENV_FORCE_PREFIX):
|
||||
real_key = k[len(_HERMES_PROVIDER_ENV_FORCE_PREFIX):]
|
||||
run_env[real_key] = v
|
||||
elif k not in _HERMES_PROVIDER_ENV_BLOCKLIST or _is_passthrough(k):
|
||||
run_env[k] = v
|
||||
existing_path = run_env.get("PATH", "")
|
||||
if "/usr/bin" not in existing_path.split(":"):
|
||||
run_env["PATH"] = f"{existing_path}:{_SANE_PATH}" if existing_path else _SANE_PATH
|
||||
return run_env
|
||||
|
||||
|
||||
def _extract_fenced_output(raw: str) -> str:
|
||||
"""Extract real command output from between fence markers.
|
||||
|
||||
The execute() method wraps each command with printf(FENCE) markers.
|
||||
This function finds the first and last fence and returns only the
|
||||
content between them, which is the actual command output free of
|
||||
any shell init/exit noise.
|
||||
|
||||
Falls back to pattern-based _clean_shell_noise if fences are missing.
|
||||
"""
|
||||
first = raw.find(_OUTPUT_FENCE)
|
||||
if first == -1:
|
||||
return _clean_shell_noise(raw)
|
||||
|
||||
start = first + len(_OUTPUT_FENCE)
|
||||
last = raw.rfind(_OUTPUT_FENCE)
|
||||
|
||||
if last <= first:
|
||||
# Only start fence found (e.g. user command called `exit`)
|
||||
return _clean_shell_noise(raw[start:])
|
||||
|
||||
return raw[start:last]
|
||||
|
||||
|
||||
class LocalEnvironment(PersistentShellMixin, BaseEnvironment):
|
||||
"""Run commands directly on the host machine.
|
||||
|
||||
Features:
|
||||
- Popen + polling for interrupt support (user can cancel mid-command)
|
||||
- Background stdout drain thread to prevent pipe buffer deadlocks
|
||||
- stdin_data support for piping content (bypasses ARG_MAX limits)
|
||||
- sudo -S transform via SUDO_PASSWORD env var
|
||||
- Uses interactive login shell so full user env is available
|
||||
- Optional persistent shell mode (cwd/env vars survive across calls)
|
||||
"""
|
||||
|
||||
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None,
|
||||
persistent: bool = False):
|
||||
super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env)
|
||||
self.persistent = persistent
|
||||
if self.persistent:
|
||||
self._init_persistent_shell()
|
||||
|
||||
@property
|
||||
def _temp_prefix(self) -> str:
|
||||
return f"/tmp/hermes-local-{self._session_id}"
|
||||
|
||||
def _spawn_shell_process(self) -> subprocess.Popen:
|
||||
user_shell = _find_bash()
|
||||
run_env = _make_run_env(self.env)
|
||||
return subprocess.Popen(
|
||||
[user_shell, "-l"],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
env=run_env,
|
||||
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
||||
)
|
||||
|
||||
def _read_temp_files(self, *paths: str) -> list[str]:
|
||||
results = []
|
||||
for path in paths:
|
||||
if os.path.exists(path):
|
||||
with open(path) as f:
|
||||
results.append(f.read())
|
||||
else:
|
||||
results.append("")
|
||||
return results
|
||||
|
||||
def _kill_shell_children(self):
|
||||
if self._shell_pid is None:
|
||||
return
|
||||
try:
|
||||
subprocess.run(
|
||||
["pkill", "-P", str(self._shell_pid)],
|
||||
capture_output=True, timeout=5,
|
||||
)
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
pass
|
||||
|
||||
def _cleanup_temp_files(self):
|
||||
for f in glob.glob(f"{self._temp_prefix}-*"):
|
||||
if os.path.exists(f):
|
||||
os.remove(f)
|
||||
|
||||
def _execute_oneshot(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
work_dir = cwd or self.cwd or os.getcwd()
|
||||
effective_timeout = timeout or self.timeout
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
|
||||
if sudo_stdin is not None and stdin_data is not None:
|
||||
effective_stdin = sudo_stdin + stdin_data
|
||||
elif sudo_stdin is not None:
|
||||
effective_stdin = sudo_stdin
|
||||
else:
|
||||
effective_stdin = stdin_data
|
||||
|
||||
user_shell = _find_bash()
|
||||
fenced_cmd = (
|
||||
f"printf '{_OUTPUT_FENCE}';"
|
||||
f" {exec_command};"
|
||||
f" __hermes_rc=$?;"
|
||||
f" printf '{_OUTPUT_FENCE}';"
|
||||
f" exit $__hermes_rc"
|
||||
)
|
||||
run_env = _make_run_env(self.env)
|
||||
|
||||
proc = subprocess.Popen(
|
||||
[user_shell, "-lic", fenced_cmd],
|
||||
text=True,
|
||||
cwd=work_dir,
|
||||
env=run_env,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if effective_stdin is not None else subprocess.DEVNULL,
|
||||
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
||||
)
|
||||
|
||||
if effective_stdin is not None:
|
||||
def _write_stdin():
|
||||
try:
|
||||
proc.stdin.write(effective_stdin)
|
||||
proc.stdin.close()
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
threading.Thread(target=_write_stdin, daemon=True).start()
|
||||
|
||||
_output_chunks: list[str] = []
|
||||
|
||||
def _drain_stdout():
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
_output_chunks.append(line)
|
||||
except ValueError:
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
proc.stdout.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
reader = threading.Thread(target=_drain_stdout, daemon=True)
|
||||
reader.start()
|
||||
deadline = time.monotonic() + effective_timeout
|
||||
|
||||
while proc.poll() is None:
|
||||
if is_interrupted():
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
proc.terminate()
|
||||
else:
|
||||
pgid = os.getpgid(proc.pid)
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
try:
|
||||
proc.wait(timeout=1.0)
|
||||
except subprocess.TimeoutExpired:
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return {
|
||||
"output": "".join(_output_chunks) + "\n[Command interrupted — user sent a new message]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if time.monotonic() > deadline:
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
proc.terminate()
|
||||
else:
|
||||
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return self._timeout_result(effective_timeout)
|
||||
time.sleep(0.2)
|
||||
|
||||
reader.join(timeout=5)
|
||||
output = _extract_fenced_output("".join(_output_chunks))
|
||||
return {"output": output, "returncode": proc.returncode}
|
||||
259
hermes_code/tools/environments/modal.py
Normal file
259
hermes_code/tools/environments/modal.py
Normal file
|
|
@ -0,0 +1,259 @@
|
|||
"""Modal cloud execution environment using SWE-ReX directly.
|
||||
|
||||
Supports persistent filesystem snapshots: when enabled, the sandbox's filesystem
|
||||
is snapshotted on cleanup and restored on next creation, so installed packages,
|
||||
project files, and config changes survive across sessions.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from hermes_cli.config import get_hermes_home
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SNAPSHOT_STORE = get_hermes_home() / "modal_snapshots.json"
|
||||
|
||||
|
||||
def _load_snapshots() -> Dict[str, str]:
|
||||
"""Load snapshot ID mapping from disk."""
|
||||
if _SNAPSHOT_STORE.exists():
|
||||
try:
|
||||
return json.loads(_SNAPSHOT_STORE.read_text())
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
def _save_snapshots(data: Dict[str, str]) -> None:
|
||||
"""Persist snapshot ID mapping to disk."""
|
||||
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
|
||||
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
|
||||
|
||||
|
||||
class _AsyncWorker:
|
||||
"""Background thread with its own event loop for async-safe swe-rex calls.
|
||||
|
||||
Allows sync code to submit async coroutines and block for results,
|
||||
even when called from inside another running event loop (e.g. Atropos).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._started = threading.Event()
|
||||
|
||||
def start(self):
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
self._started.wait(timeout=30)
|
||||
|
||||
def _run_loop(self):
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._started.set()
|
||||
self._loop.run_forever()
|
||||
|
||||
def run_coroutine(self, coro, timeout=600):
|
||||
if self._loop is None or self._loop.is_closed():
|
||||
raise RuntimeError("AsyncWorker loop is not running")
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
|
||||
return future.result(timeout=timeout)
|
||||
|
||||
def stop(self):
|
||||
if self._loop and self._loop.is_running():
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
if self._thread:
|
||||
self._thread.join(timeout=10)
|
||||
|
||||
|
||||
class ModalEnvironment(BaseEnvironment):
|
||||
"""Modal cloud execution via SWE-ReX.
|
||||
|
||||
Uses swe-rex's ModalDeployment directly for sandbox management.
|
||||
Adds sudo -S support, configurable resources (CPU, memory, disk),
|
||||
and optional filesystem persistence via Modal's snapshot API.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image: str,
|
||||
cwd: str = "/root",
|
||||
timeout: int = 60,
|
||||
modal_sandbox_kwargs: Optional[Dict[str, Any]] = None,
|
||||
persistent_filesystem: bool = True,
|
||||
task_id: str = "default",
|
||||
):
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
|
||||
self._persistent = persistent_filesystem
|
||||
self._task_id = task_id
|
||||
self._base_image = image
|
||||
self._deployment = None
|
||||
self._worker = _AsyncWorker()
|
||||
|
||||
sandbox_kwargs = dict(modal_sandbox_kwargs or {})
|
||||
|
||||
# If persistent, try to restore from a previous snapshot
|
||||
restored_image = None
|
||||
if self._persistent:
|
||||
snapshot_id = _load_snapshots().get(self._task_id)
|
||||
if snapshot_id:
|
||||
try:
|
||||
import modal
|
||||
restored_image = modal.Image.from_id(snapshot_id)
|
||||
logger.info("Modal: restoring from snapshot %s", snapshot_id[:20])
|
||||
except Exception as e:
|
||||
logger.warning("Modal: failed to restore snapshot, using base image: %s", e)
|
||||
restored_image = None
|
||||
|
||||
effective_image = restored_image if restored_image else image
|
||||
|
||||
# Pre-build a modal.Image with pip fix for Modal's legacy image builder.
|
||||
# Some task images have broken pip; fix via ensurepip before Modal uses it.
|
||||
import modal as _modal
|
||||
if isinstance(effective_image, str):
|
||||
effective_image = _modal.Image.from_registry(
|
||||
effective_image,
|
||||
setup_dockerfile_commands=[
|
||||
"RUN rm -rf /usr/local/lib/python*/site-packages/pip* 2>/dev/null; "
|
||||
"python -m ensurepip --upgrade --default-pip 2>/dev/null || true",
|
||||
],
|
||||
)
|
||||
|
||||
# Start the async worker thread and create the deployment on it
|
||||
# so all gRPC channels are bound to the worker's event loop.
|
||||
self._worker.start()
|
||||
|
||||
from swerex.deployment.modal import ModalDeployment
|
||||
|
||||
async def _create_and_start():
|
||||
deployment = ModalDeployment(
|
||||
image=effective_image,
|
||||
startup_timeout=180.0,
|
||||
runtime_timeout=3600.0,
|
||||
deployment_timeout=3600.0,
|
||||
install_pipx=True,
|
||||
modal_sandbox_kwargs=sandbox_kwargs,
|
||||
)
|
||||
await deployment.start()
|
||||
return deployment
|
||||
|
||||
self._deployment = self._worker.run_coroutine(_create_and_start())
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
if stdin_data is not None:
|
||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
||||
while marker in stdin_data:
|
||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
||||
command = f"{command} << '{marker}'\n{stdin_data}\n{marker}"
|
||||
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
|
||||
# Modal sandboxes execute commands via the Modal SDK and cannot pipe
|
||||
# subprocess stdin directly the way a local Popen can. When a sudo
|
||||
# password is present, use a shell-level pipe from printf so that the
|
||||
# password feeds sudo -S without appearing as an echo argument embedded
|
||||
# in the shell string.
|
||||
if sudo_stdin is not None:
|
||||
import shlex
|
||||
exec_command = (
|
||||
f"printf '%s\\n' {shlex.quote(sudo_stdin.rstrip())} | {exec_command}"
|
||||
)
|
||||
|
||||
from swerex.runtime.abstract import Command as RexCommand
|
||||
|
||||
effective_cwd = cwd or self.cwd
|
||||
effective_timeout = timeout or self.timeout
|
||||
|
||||
# Run in a background thread so we can poll for interrupts
|
||||
result_holder = {"value": None, "error": None}
|
||||
|
||||
def _run():
|
||||
try:
|
||||
async def _do_execute():
|
||||
return await self._deployment.runtime.execute(
|
||||
RexCommand(
|
||||
command=exec_command,
|
||||
shell=True,
|
||||
check=False,
|
||||
cwd=effective_cwd,
|
||||
timeout=effective_timeout,
|
||||
merge_output_streams=True,
|
||||
)
|
||||
)
|
||||
output = self._worker.run_coroutine(_do_execute())
|
||||
result_holder["value"] = {
|
||||
"output": output.stdout,
|
||||
"returncode": output.exit_code,
|
||||
}
|
||||
except Exception as e:
|
||||
result_holder["error"] = e
|
||||
|
||||
t = threading.Thread(target=_run, daemon=True)
|
||||
t.start()
|
||||
while t.is_alive():
|
||||
t.join(timeout=0.2)
|
||||
if is_interrupted():
|
||||
try:
|
||||
self._worker.run_coroutine(
|
||||
asyncio.wait_for(self._deployment.stop(), timeout=10),
|
||||
timeout=15,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"output": "[Command interrupted - Modal sandbox terminated]",
|
||||
"returncode": 130,
|
||||
}
|
||||
|
||||
if result_holder["error"]:
|
||||
return {"output": f"Modal execution error: {result_holder['error']}", "returncode": 1}
|
||||
return result_holder["value"]
|
||||
|
||||
def cleanup(self):
|
||||
"""Snapshot the filesystem (if persistent) then stop the sandbox."""
|
||||
if self._deployment is None:
|
||||
return
|
||||
|
||||
if self._persistent:
|
||||
try:
|
||||
sandbox = getattr(self._deployment, '_sandbox', None)
|
||||
if sandbox:
|
||||
async def _snapshot():
|
||||
img = await sandbox.snapshot_filesystem.aio()
|
||||
return img.object_id
|
||||
|
||||
try:
|
||||
snapshot_id = self._worker.run_coroutine(_snapshot(), timeout=60)
|
||||
except Exception:
|
||||
snapshot_id = None
|
||||
|
||||
if snapshot_id:
|
||||
snapshots = _load_snapshots()
|
||||
snapshots[self._task_id] = snapshot_id
|
||||
_save_snapshots(snapshots)
|
||||
logger.info("Modal: saved filesystem snapshot %s for task %s",
|
||||
snapshot_id[:20], self._task_id)
|
||||
except Exception as e:
|
||||
logger.warning("Modal: filesystem snapshot failed: %s", e)
|
||||
|
||||
try:
|
||||
self._worker.run_coroutine(
|
||||
asyncio.wait_for(self._deployment.stop(), timeout=10),
|
||||
timeout=15,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
self._worker.stop()
|
||||
self._deployment = None
|
||||
272
hermes_code/tools/environments/persistent_shell.py
Normal file
272
hermes_code/tools/environments/persistent_shell.py
Normal file
|
|
@ -0,0 +1,272 @@
|
|||
"""Persistent shell mixin: file-based IPC protocol for long-lived bash shells."""
|
||||
|
||||
import logging
|
||||
import shlex
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PersistentShellMixin:
|
||||
"""Mixin that adds persistent shell capability to any BaseEnvironment.
|
||||
|
||||
Subclasses must implement ``_spawn_shell_process()``, ``_read_temp_files()``,
|
||||
``_kill_shell_children()``, ``_execute_oneshot()``, and ``_cleanup_temp_files()``.
|
||||
"""
|
||||
|
||||
persistent: bool
|
||||
|
||||
@abstractmethod
|
||||
def _spawn_shell_process(self) -> subprocess.Popen: ...
|
||||
|
||||
@abstractmethod
|
||||
def _read_temp_files(self, *paths: str) -> list[str]: ...
|
||||
|
||||
@abstractmethod
|
||||
def _kill_shell_children(self): ...
|
||||
|
||||
@abstractmethod
|
||||
def _execute_oneshot(self, command: str, cwd: str, *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict: ...
|
||||
|
||||
@abstractmethod
|
||||
def _cleanup_temp_files(self): ...
|
||||
|
||||
_session_id: str = ""
|
||||
_poll_interval: float = 0.01
|
||||
|
||||
@property
|
||||
def _temp_prefix(self) -> str:
|
||||
return f"/tmp/hermes-persistent-{self._session_id}"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_persistent_shell(self):
|
||||
self._shell_lock = threading.Lock()
|
||||
self._shell_proc: subprocess.Popen | None = None
|
||||
self._shell_alive: bool = False
|
||||
self._shell_pid: int | None = None
|
||||
|
||||
self._session_id = uuid.uuid4().hex[:12]
|
||||
p = self._temp_prefix
|
||||
self._pshell_stdout = f"{p}-stdout"
|
||||
self._pshell_stderr = f"{p}-stderr"
|
||||
self._pshell_status = f"{p}-status"
|
||||
self._pshell_cwd = f"{p}-cwd"
|
||||
self._pshell_pid_file = f"{p}-pid"
|
||||
|
||||
self._shell_proc = self._spawn_shell_process()
|
||||
self._shell_alive = True
|
||||
|
||||
self._drain_thread = threading.Thread(
|
||||
target=self._drain_shell_output, daemon=True,
|
||||
)
|
||||
self._drain_thread.start()
|
||||
|
||||
init_script = (
|
||||
f"export TERM=${{TERM:-dumb}}\n"
|
||||
f"touch {self._pshell_stdout} {self._pshell_stderr} "
|
||||
f"{self._pshell_status} {self._pshell_cwd} {self._pshell_pid_file}\n"
|
||||
f"echo $$ > {self._pshell_pid_file}\n"
|
||||
f"pwd > {self._pshell_cwd}\n"
|
||||
)
|
||||
self._send_to_shell(init_script)
|
||||
|
||||
deadline = time.monotonic() + 3.0
|
||||
while time.monotonic() < deadline:
|
||||
pid_str = self._read_temp_files(self._pshell_pid_file)[0].strip()
|
||||
if pid_str.isdigit():
|
||||
self._shell_pid = int(pid_str)
|
||||
break
|
||||
time.sleep(0.05)
|
||||
else:
|
||||
logger.warning("Could not read persistent shell PID")
|
||||
self._shell_pid = None
|
||||
|
||||
if self._shell_pid:
|
||||
logger.info(
|
||||
"Persistent shell started (session=%s, pid=%d)",
|
||||
self._session_id, self._shell_pid,
|
||||
)
|
||||
|
||||
reported_cwd = self._read_temp_files(self._pshell_cwd)[0].strip()
|
||||
if reported_cwd:
|
||||
self.cwd = reported_cwd
|
||||
|
||||
def _cleanup_persistent_shell(self):
|
||||
if self._shell_proc is None:
|
||||
return
|
||||
|
||||
if self._session_id:
|
||||
self._cleanup_temp_files()
|
||||
|
||||
try:
|
||||
self._shell_proc.stdin.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self._shell_proc.terminate()
|
||||
self._shell_proc.wait(timeout=3)
|
||||
except subprocess.TimeoutExpired:
|
||||
self._shell_proc.kill()
|
||||
|
||||
self._shell_alive = False
|
||||
self._shell_proc = None
|
||||
|
||||
if hasattr(self, "_drain_thread") and self._drain_thread.is_alive():
|
||||
self._drain_thread.join(timeout=1.0)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# execute() / cleanup() — shared dispatcher, subclasses inherit
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
if self.persistent:
|
||||
return self._execute_persistent(
|
||||
command, cwd, timeout=timeout, stdin_data=stdin_data,
|
||||
)
|
||||
return self._execute_oneshot(
|
||||
command, cwd, timeout=timeout, stdin_data=stdin_data,
|
||||
)
|
||||
|
||||
def cleanup(self):
|
||||
if self.persistent:
|
||||
self._cleanup_persistent_shell()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Shell I/O
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _drain_shell_output(self):
|
||||
try:
|
||||
for _ in self._shell_proc.stdout:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
self._shell_alive = False
|
||||
|
||||
def _send_to_shell(self, text: str):
|
||||
if not self._shell_alive or self._shell_proc is None:
|
||||
return
|
||||
try:
|
||||
self._shell_proc.stdin.write(text)
|
||||
self._shell_proc.stdin.flush()
|
||||
except (BrokenPipeError, OSError):
|
||||
self._shell_alive = False
|
||||
|
||||
def _read_persistent_output(self) -> tuple[str, int, str]:
|
||||
stdout, stderr, status_raw, cwd = self._read_temp_files(
|
||||
self._pshell_stdout, self._pshell_stderr,
|
||||
self._pshell_status, self._pshell_cwd,
|
||||
)
|
||||
output = self._merge_output(stdout, stderr)
|
||||
status = status_raw.strip()
|
||||
if ":" in status:
|
||||
status = status.split(":", 1)[1]
|
||||
try:
|
||||
exit_code = int(status.strip())
|
||||
except ValueError:
|
||||
exit_code = 1
|
||||
return output, exit_code, cwd.strip()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _execute_persistent(self, command: str, cwd: str, *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
if not self._shell_alive:
|
||||
logger.info("Persistent shell died, restarting...")
|
||||
self._init_persistent_shell()
|
||||
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
effective_timeout = timeout or self.timeout
|
||||
if stdin_data or sudo_stdin:
|
||||
return self._execute_oneshot(
|
||||
command, cwd, timeout=timeout, stdin_data=stdin_data,
|
||||
)
|
||||
|
||||
with self._shell_lock:
|
||||
return self._execute_persistent_locked(
|
||||
exec_command, cwd, effective_timeout,
|
||||
)
|
||||
|
||||
def _execute_persistent_locked(self, command: str, cwd: str,
|
||||
timeout: int) -> dict:
|
||||
work_dir = cwd or self.cwd
|
||||
cmd_id = uuid.uuid4().hex[:8]
|
||||
truncate = (
|
||||
f": > {self._pshell_stdout}\n"
|
||||
f": > {self._pshell_stderr}\n"
|
||||
f": > {self._pshell_status}\n"
|
||||
)
|
||||
self._send_to_shell(truncate)
|
||||
escaped = command.replace("'", "'\\''")
|
||||
|
||||
ipc_script = (
|
||||
f"cd {shlex.quote(work_dir)}\n"
|
||||
f"eval '{escaped}' < /dev/null > {self._pshell_stdout} 2> {self._pshell_stderr}\n"
|
||||
f"__EC=$?\n"
|
||||
f"pwd > {self._pshell_cwd}\n"
|
||||
f"echo {cmd_id}:$__EC > {self._pshell_status}\n"
|
||||
)
|
||||
self._send_to_shell(ipc_script)
|
||||
deadline = time.monotonic() + timeout
|
||||
poll_interval = self._poll_interval
|
||||
|
||||
while True:
|
||||
if is_interrupted():
|
||||
self._kill_shell_children()
|
||||
output, _, _ = self._read_persistent_output()
|
||||
return {
|
||||
"output": output + "\n[Command interrupted]",
|
||||
"returncode": 130,
|
||||
}
|
||||
|
||||
if time.monotonic() > deadline:
|
||||
self._kill_shell_children()
|
||||
output, _, _ = self._read_persistent_output()
|
||||
if output:
|
||||
return {
|
||||
"output": output + f"\n[Command timed out after {timeout}s]",
|
||||
"returncode": 124,
|
||||
}
|
||||
return self._timeout_result(timeout)
|
||||
|
||||
if not self._shell_alive:
|
||||
return {
|
||||
"output": "Persistent shell died during execution",
|
||||
"returncode": 1,
|
||||
}
|
||||
|
||||
status_content = self._read_temp_files(self._pshell_status)[0].strip()
|
||||
if status_content.startswith(cmd_id + ":"):
|
||||
break
|
||||
|
||||
time.sleep(poll_interval)
|
||||
|
||||
output, exit_code, new_cwd = self._read_persistent_output()
|
||||
if new_cwd:
|
||||
self.cwd = new_cwd
|
||||
return {"output": output, "returncode": exit_code}
|
||||
|
||||
@staticmethod
|
||||
def _merge_output(stdout: str, stderr: str) -> str:
|
||||
parts = []
|
||||
if stdout.strip():
|
||||
parts.append(stdout.rstrip("\n"))
|
||||
if stderr.strip():
|
||||
parts.append(stderr.rstrip("\n"))
|
||||
return "\n".join(parts)
|
||||
369
hermes_code/tools/environments/singularity.py
Normal file
369
hermes_code/tools/environments/singularity.py
Normal file
|
|
@ -0,0 +1,369 @@
|
|||
"""Singularity/Apptainer persistent container environment.
|
||||
|
||||
Security-hardened with --containall, --no-home, capability dropping.
|
||||
Supports configurable resource limits and optional filesystem persistence
|
||||
via writable overlay directories that survive across sessions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import threading
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from hermes_cli.config import get_hermes_home
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SNAPSHOT_STORE = get_hermes_home() / "singularity_snapshots.json"
|
||||
|
||||
|
||||
def _find_singularity_executable() -> str:
|
||||
"""Locate the apptainer or singularity CLI binary.
|
||||
|
||||
Returns the executable name (``"apptainer"`` or ``"singularity"``).
|
||||
Raises ``RuntimeError`` with install instructions if neither is found.
|
||||
"""
|
||||
if shutil.which("apptainer"):
|
||||
return "apptainer"
|
||||
if shutil.which("singularity"):
|
||||
return "singularity"
|
||||
raise RuntimeError(
|
||||
"Neither 'apptainer' nor 'singularity' was found in PATH. "
|
||||
"Install Apptainer (https://apptainer.org/docs/admin/main/installation.html) "
|
||||
"or Singularity and ensure the CLI is available."
|
||||
)
|
||||
|
||||
|
||||
def _ensure_singularity_available() -> str:
|
||||
"""Preflight check: resolve the executable and verify it responds.
|
||||
|
||||
Returns the executable name on success.
|
||||
Raises ``RuntimeError`` with an actionable message on failure.
|
||||
"""
|
||||
exe = _find_singularity_executable()
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[exe, "version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError(
|
||||
f"Singularity backend selected but the resolved executable '{exe}' "
|
||||
"could not be executed. Check your installation."
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError(
|
||||
f"'{exe} version' timed out. The runtime may be misconfigured."
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
stderr = result.stderr.strip()[:200]
|
||||
raise RuntimeError(
|
||||
f"'{exe} version' failed (exit code {result.returncode}): {stderr}"
|
||||
)
|
||||
|
||||
return exe
|
||||
|
||||
|
||||
def _load_snapshots() -> Dict[str, str]:
|
||||
if _SNAPSHOT_STORE.exists():
|
||||
try:
|
||||
return json.loads(_SNAPSHOT_STORE.read_text())
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
def _save_snapshots(data: Dict[str, str]) -> None:
|
||||
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
|
||||
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Singularity helpers (scratch dir, SIF cache, SIF building)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _get_scratch_dir() -> Path:
|
||||
"""Get the best directory for Singularity sandboxes.
|
||||
|
||||
Resolution order:
|
||||
1. TERMINAL_SCRATCH_DIR (explicit override)
|
||||
2. TERMINAL_SANDBOX_DIR / singularity (shared sandbox root)
|
||||
3. /scratch (common on HPC clusters)
|
||||
4. ~/.hermes/sandboxes/singularity (fallback)
|
||||
"""
|
||||
custom_scratch = os.getenv("TERMINAL_SCRATCH_DIR")
|
||||
if custom_scratch:
|
||||
scratch_path = Path(custom_scratch)
|
||||
scratch_path.mkdir(parents=True, exist_ok=True)
|
||||
return scratch_path
|
||||
|
||||
from tools.environments.base import get_sandbox_dir
|
||||
sandbox = get_sandbox_dir() / "singularity"
|
||||
|
||||
scratch = Path("/scratch")
|
||||
if scratch.exists() and os.access(scratch, os.W_OK):
|
||||
user_scratch = scratch / os.getenv("USER", "hermes") / "hermes-agent"
|
||||
user_scratch.mkdir(parents=True, exist_ok=True)
|
||||
logger.info("Using /scratch for sandboxes: %s", user_scratch)
|
||||
return user_scratch
|
||||
|
||||
sandbox.mkdir(parents=True, exist_ok=True)
|
||||
return sandbox
|
||||
|
||||
|
||||
def _get_apptainer_cache_dir() -> Path:
|
||||
"""Get the Apptainer cache directory for SIF images."""
|
||||
cache_dir = os.getenv("APPTAINER_CACHEDIR")
|
||||
if cache_dir:
|
||||
cache_path = Path(cache_dir)
|
||||
cache_path.mkdir(parents=True, exist_ok=True)
|
||||
return cache_path
|
||||
scratch = _get_scratch_dir()
|
||||
cache_path = scratch / ".apptainer"
|
||||
cache_path.mkdir(parents=True, exist_ok=True)
|
||||
return cache_path
|
||||
|
||||
|
||||
_sif_build_lock = threading.Lock()
|
||||
|
||||
|
||||
def _get_or_build_sif(image: str, executable: str = "apptainer") -> str:
|
||||
"""Get or build a SIF image from a docker:// URL.
|
||||
|
||||
Returns the path unchanged if it's already a .sif file.
|
||||
For docker:// URLs, checks the cache and builds if needed.
|
||||
"""
|
||||
if image.endswith('.sif') and Path(image).exists():
|
||||
return image
|
||||
if not image.startswith('docker://'):
|
||||
return image
|
||||
|
||||
image_name = image.replace('docker://', '').replace('/', '-').replace(':', '-')
|
||||
cache_dir = _get_apptainer_cache_dir()
|
||||
sif_path = cache_dir / f"{image_name}.sif"
|
||||
|
||||
if sif_path.exists():
|
||||
return str(sif_path)
|
||||
|
||||
with _sif_build_lock:
|
||||
if sif_path.exists():
|
||||
return str(sif_path)
|
||||
|
||||
logger.info("Building SIF image (one-time setup)...")
|
||||
logger.info(" Source: %s", image)
|
||||
logger.info(" Target: %s", sif_path)
|
||||
|
||||
tmp_dir = cache_dir / "tmp"
|
||||
tmp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
env = os.environ.copy()
|
||||
env["APPTAINER_TMPDIR"] = str(tmp_dir)
|
||||
env["APPTAINER_CACHEDIR"] = str(cache_dir)
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[executable, "build", str(sif_path), image],
|
||||
capture_output=True, text=True, timeout=600, env=env,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.warning("SIF build failed, falling back to docker:// URL")
|
||||
logger.warning(" Error: %s", result.stderr[:500])
|
||||
return image
|
||||
logger.info("SIF image built successfully")
|
||||
return str(sif_path)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("SIF build timed out, falling back to docker:// URL")
|
||||
if sif_path.exists():
|
||||
sif_path.unlink()
|
||||
return image
|
||||
except Exception as e:
|
||||
logger.warning("SIF build error: %s, falling back to docker:// URL", e)
|
||||
return image
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# SingularityEnvironment
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
class SingularityEnvironment(BaseEnvironment):
|
||||
"""Hardened Singularity/Apptainer container with resource limits and persistence.
|
||||
|
||||
Security: --containall (isolated PID/IPC/mount namespaces, no host home mount),
|
||||
--no-home, writable-tmpfs for scratch space. The container cannot see or modify
|
||||
the host filesystem outside of explicitly bound paths.
|
||||
|
||||
Persistence: when enabled, the writable overlay directory is preserved across
|
||||
sessions so installed packages and files survive cleanup/restore.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image: str,
|
||||
cwd: str = "~",
|
||||
timeout: int = 60,
|
||||
cpu: float = 0,
|
||||
memory: int = 0,
|
||||
disk: int = 0,
|
||||
persistent_filesystem: bool = False,
|
||||
task_id: str = "default",
|
||||
):
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
self.executable = _ensure_singularity_available()
|
||||
self.image = _get_or_build_sif(image, self.executable)
|
||||
self.instance_id = f"hermes_{uuid.uuid4().hex[:12]}"
|
||||
self._instance_started = False
|
||||
self._persistent = persistent_filesystem
|
||||
self._task_id = task_id
|
||||
self._overlay_dir: Optional[Path] = None
|
||||
|
||||
# Resource limits
|
||||
self._cpu = cpu
|
||||
self._memory = memory
|
||||
|
||||
# Persistent overlay directory
|
||||
if self._persistent:
|
||||
overlay_base = _get_scratch_dir() / "hermes-overlays"
|
||||
overlay_base.mkdir(parents=True, exist_ok=True)
|
||||
self._overlay_dir = overlay_base / f"overlay-{task_id}"
|
||||
self._overlay_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._start_instance()
|
||||
|
||||
def _start_instance(self):
|
||||
cmd = [self.executable, "instance", "start"]
|
||||
|
||||
# Security: full isolation from host
|
||||
cmd.extend(["--containall", "--no-home"])
|
||||
|
||||
# Writable layer
|
||||
if self._persistent and self._overlay_dir:
|
||||
# Persistent writable overlay -- survives across restarts
|
||||
cmd.extend(["--overlay", str(self._overlay_dir)])
|
||||
else:
|
||||
cmd.append("--writable-tmpfs")
|
||||
|
||||
# Resource limits (cgroup-based, may require root or appropriate config)
|
||||
if self._memory > 0:
|
||||
cmd.extend(["--memory", f"{self._memory}M"])
|
||||
if self._cpu > 0:
|
||||
cmd.extend(["--cpus", str(self._cpu)])
|
||||
|
||||
cmd.extend([str(self.image), self.instance_id])
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Failed to start instance: {result.stderr}")
|
||||
self._instance_started = True
|
||||
logger.info("Singularity instance %s started (persistent=%s)",
|
||||
self.instance_id, self._persistent)
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError("Instance start timed out")
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
if not self._instance_started:
|
||||
return {"output": "Instance not started", "returncode": -1}
|
||||
|
||||
effective_timeout = timeout or self.timeout
|
||||
work_dir = cwd or self.cwd
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
|
||||
# Merge sudo password (if any) with caller-supplied stdin_data.
|
||||
if sudo_stdin is not None and stdin_data is not None:
|
||||
effective_stdin = sudo_stdin + stdin_data
|
||||
elif sudo_stdin is not None:
|
||||
effective_stdin = sudo_stdin
|
||||
else:
|
||||
effective_stdin = stdin_data
|
||||
|
||||
# apptainer exec --pwd doesn't expand ~, so prepend a cd into the command
|
||||
if work_dir == "~" or work_dir.startswith("~/"):
|
||||
exec_command = f"cd {work_dir} && {exec_command}"
|
||||
work_dir = "/tmp"
|
||||
|
||||
cmd = [self.executable, "exec", "--pwd", work_dir,
|
||||
f"instance://{self.instance_id}",
|
||||
"bash", "-c", exec_command]
|
||||
|
||||
try:
|
||||
import time as _time
|
||||
_output_chunks = []
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
if effective_stdin:
|
||||
try:
|
||||
proc.stdin.write(effective_stdin)
|
||||
proc.stdin.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _drain():
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
_output_chunks.append(line)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
reader = threading.Thread(target=_drain, daemon=True)
|
||||
reader.start()
|
||||
deadline = _time.monotonic() + effective_timeout
|
||||
|
||||
while proc.poll() is None:
|
||||
if is_interrupted():
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=1)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return {
|
||||
"output": "".join(_output_chunks) + "\n[Command interrupted]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if _time.monotonic() > deadline:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return self._timeout_result(effective_timeout)
|
||||
_time.sleep(0.2)
|
||||
|
||||
reader.join(timeout=5)
|
||||
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
|
||||
except Exception as e:
|
||||
return {"output": f"Singularity execution error: {e}", "returncode": 1}
|
||||
|
||||
def cleanup(self):
|
||||
"""Stop the instance. If persistent, the overlay dir survives for next creation."""
|
||||
if self._instance_started:
|
||||
try:
|
||||
subprocess.run(
|
||||
[self.executable, "instance", "stop", self.instance_id],
|
||||
capture_output=True, text=True, timeout=30,
|
||||
)
|
||||
logger.info("Singularity instance %s stopped", self.instance_id)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to stop Singularity instance %s: %s", self.instance_id, e)
|
||||
self._instance_started = False
|
||||
|
||||
# Record overlay path for persistence restoration
|
||||
if self._persistent and self._overlay_dir:
|
||||
snapshots = _load_snapshots()
|
||||
snapshots[self._task_id] = str(self._overlay_dir)
|
||||
_save_snapshots(snapshots)
|
||||
232
hermes_code/tools/environments/ssh.py
Normal file
232
hermes_code/tools/environments/ssh.py
Normal file
|
|
@ -0,0 +1,232 @@
|
|||
"""SSH remote execution environment with ControlMaster connection persistence."""
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.environments.persistent_shell import PersistentShellMixin
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _ensure_ssh_available() -> None:
|
||||
"""Fail fast with a clear error when the SSH client is unavailable."""
|
||||
if not shutil.which("ssh"):
|
||||
raise RuntimeError(
|
||||
"SSH is not installed or not in PATH. Install OpenSSH client: apt install openssh-client"
|
||||
)
|
||||
|
||||
|
||||
class SSHEnvironment(PersistentShellMixin, BaseEnvironment):
|
||||
"""Run commands on a remote machine over SSH.
|
||||
|
||||
Uses SSH ControlMaster for connection persistence so subsequent
|
||||
commands are fast. Security benefit: the agent cannot modify its
|
||||
own code since execution happens on a separate machine.
|
||||
|
||||
Foreground commands are interruptible: the local ssh process is killed
|
||||
and a remote kill is attempted over the ControlMaster socket.
|
||||
|
||||
When ``persistent=True``, a single long-lived bash shell is kept alive
|
||||
over SSH and state (cwd, env vars, shell variables) persists across
|
||||
``execute()`` calls. Output capture uses file-based IPC on the remote
|
||||
host (stdout/stderr/exit-code written to temp files, polled via fast
|
||||
ControlMaster one-shot reads).
|
||||
"""
|
||||
|
||||
def __init__(self, host: str, user: str, cwd: str = "~",
|
||||
timeout: int = 60, port: int = 22, key_path: str = "",
|
||||
persistent: bool = False):
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
self.host = host
|
||||
self.user = user
|
||||
self.port = port
|
||||
self.key_path = key_path
|
||||
self.persistent = persistent
|
||||
|
||||
self.control_dir = Path(tempfile.gettempdir()) / "hermes-ssh"
|
||||
self.control_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.control_socket = self.control_dir / f"{user}@{host}:{port}.sock"
|
||||
_ensure_ssh_available()
|
||||
self._establish_connection()
|
||||
|
||||
if self.persistent:
|
||||
self._init_persistent_shell()
|
||||
|
||||
def _build_ssh_command(self, extra_args: list | None = None) -> list:
|
||||
cmd = ["ssh"]
|
||||
cmd.extend(["-o", f"ControlPath={self.control_socket}"])
|
||||
cmd.extend(["-o", "ControlMaster=auto"])
|
||||
cmd.extend(["-o", "ControlPersist=300"])
|
||||
cmd.extend(["-o", "BatchMode=yes"])
|
||||
cmd.extend(["-o", "StrictHostKeyChecking=accept-new"])
|
||||
cmd.extend(["-o", "ConnectTimeout=10"])
|
||||
if self.port != 22:
|
||||
cmd.extend(["-p", str(self.port)])
|
||||
if self.key_path:
|
||||
cmd.extend(["-i", self.key_path])
|
||||
if extra_args:
|
||||
cmd.extend(extra_args)
|
||||
cmd.append(f"{self.user}@{self.host}")
|
||||
return cmd
|
||||
|
||||
def _establish_connection(self):
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append("echo 'SSH connection established'")
|
||||
try:
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=15)
|
||||
if result.returncode != 0:
|
||||
error_msg = result.stderr.strip() or result.stdout.strip()
|
||||
raise RuntimeError(f"SSH connection failed: {error_msg}")
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError(f"SSH connection to {self.user}@{self.host} timed out")
|
||||
|
||||
_poll_interval: float = 0.15
|
||||
|
||||
@property
|
||||
def _temp_prefix(self) -> str:
|
||||
return f"/tmp/hermes-ssh-{self._session_id}"
|
||||
|
||||
def _spawn_shell_process(self) -> subprocess.Popen:
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append("bash -l")
|
||||
return subprocess.Popen(
|
||||
cmd,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
|
||||
def _read_temp_files(self, *paths: str) -> list[str]:
|
||||
if len(paths) == 1:
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(f"cat {paths[0]} 2>/dev/null")
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd, capture_output=True, text=True, timeout=10,
|
||||
)
|
||||
return [result.stdout]
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
return [""]
|
||||
|
||||
delim = f"__HERMES_SEP_{self._session_id}__"
|
||||
script = "; ".join(
|
||||
f"cat {p} 2>/dev/null; echo '{delim}'" for p in paths
|
||||
)
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(script)
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd, capture_output=True, text=True, timeout=10,
|
||||
)
|
||||
parts = result.stdout.split(delim + "\n")
|
||||
return [parts[i] if i < len(parts) else "" for i in range(len(paths))]
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
return [""] * len(paths)
|
||||
|
||||
def _kill_shell_children(self):
|
||||
if self._shell_pid is None:
|
||||
return
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(f"pkill -P {self._shell_pid} 2>/dev/null; true")
|
||||
try:
|
||||
subprocess.run(cmd, capture_output=True, timeout=5)
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
pass
|
||||
|
||||
def _cleanup_temp_files(self):
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(f"rm -f {self._temp_prefix}-*")
|
||||
try:
|
||||
subprocess.run(cmd, capture_output=True, timeout=5)
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
pass
|
||||
|
||||
def _execute_oneshot(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
work_dir = cwd or self.cwd
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
wrapped = f'cd {work_dir} && {exec_command}'
|
||||
effective_timeout = timeout or self.timeout
|
||||
|
||||
if sudo_stdin is not None and stdin_data is not None:
|
||||
effective_stdin = sudo_stdin + stdin_data
|
||||
elif sudo_stdin is not None:
|
||||
effective_stdin = sudo_stdin
|
||||
else:
|
||||
effective_stdin = stdin_data
|
||||
|
||||
cmd = self._build_ssh_command()
|
||||
cmd.append(wrapped)
|
||||
|
||||
kwargs = self._build_run_kwargs(timeout, effective_stdin)
|
||||
kwargs.pop("timeout", None)
|
||||
_output_chunks = []
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
|
||||
if effective_stdin:
|
||||
try:
|
||||
proc.stdin.write(effective_stdin)
|
||||
proc.stdin.close()
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
|
||||
def _drain():
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
_output_chunks.append(line)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
reader = threading.Thread(target=_drain, daemon=True)
|
||||
reader.start()
|
||||
deadline = time.monotonic() + effective_timeout
|
||||
|
||||
while proc.poll() is None:
|
||||
if is_interrupted():
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=1)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return {
|
||||
"output": "".join(_output_chunks) + "\n[Command interrupted]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if time.monotonic() > deadline:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return self._timeout_result(effective_timeout)
|
||||
time.sleep(0.2)
|
||||
|
||||
reader.join(timeout=5)
|
||||
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
|
||||
|
||||
def cleanup(self):
|
||||
super().cleanup()
|
||||
if self.control_socket.exists():
|
||||
try:
|
||||
cmd = ["ssh", "-o", f"ControlPath={self.control_socket}",
|
||||
"-O", "exit", f"{self.user}@{self.host}"]
|
||||
subprocess.run(cmd, capture_output=True, timeout=5)
|
||||
except (OSError, subprocess.SubprocessError):
|
||||
pass
|
||||
try:
|
||||
self.control_socket.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
1165
hermes_code/tools/file_operations.py
Normal file
1165
hermes_code/tools/file_operations.py
Normal file
File diff suppressed because it is too large
Load diff
522
hermes_code/tools/file_tools.py
Normal file
522
hermes_code/tools/file_tools.py
Normal file
|
|
@ -0,0 +1,522 @@
|
|||
#!/usr/bin/env python3
|
||||
"""File Tools Module - LLM agent file manipulation tools."""
|
||||
|
||||
import errno
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import Optional
|
||||
from tools.file_operations import ShellFileOperations
|
||||
from agent.redact import redact_sensitive_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_EXPECTED_WRITE_ERRNOS = {errno.EACCES, errno.EPERM, errno.EROFS}
|
||||
|
||||
|
||||
def _is_expected_write_exception(exc: Exception) -> bool:
|
||||
"""Return True for expected write denials that should not hit error logs."""
|
||||
if isinstance(exc, PermissionError):
|
||||
return True
|
||||
if isinstance(exc, OSError) and exc.errno in _EXPECTED_WRITE_ERRNOS:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
_file_ops_lock = threading.Lock()
|
||||
_file_ops_cache: dict = {}
|
||||
|
||||
# Track files read per task to detect re-read loops after context compression.
|
||||
# Per task_id we store:
|
||||
# "last_key": the key of the most recent read/search call (or None)
|
||||
# "consecutive": how many times that exact call has been repeated in a row
|
||||
# "read_history": set of (path, offset, limit) tuples for get_read_files_summary
|
||||
_read_tracker_lock = threading.Lock()
|
||||
_read_tracker: dict = {}
|
||||
|
||||
|
||||
def _get_file_ops(task_id: str = "default") -> ShellFileOperations:
|
||||
"""Get or create ShellFileOperations for a terminal environment.
|
||||
|
||||
Respects the TERMINAL_ENV setting -- if the task_id doesn't have an
|
||||
environment yet, creates one using the configured backend (local, docker,
|
||||
modal, etc.) rather than always defaulting to local.
|
||||
|
||||
Thread-safe: uses the same per-task creation locks as terminal_tool to
|
||||
prevent duplicate sandbox creation from concurrent tool calls.
|
||||
"""
|
||||
from tools.terminal_tool import (
|
||||
_active_environments, _env_lock, _create_environment,
|
||||
_get_env_config, _last_activity, _start_cleanup_thread,
|
||||
_check_disk_usage_warning,
|
||||
_creation_locks, _creation_locks_lock,
|
||||
)
|
||||
import time
|
||||
|
||||
# Fast path: check cache -- but also verify the underlying environment
|
||||
# is still alive (it may have been killed by the cleanup thread).
|
||||
with _file_ops_lock:
|
||||
cached = _file_ops_cache.get(task_id)
|
||||
if cached is not None:
|
||||
with _env_lock:
|
||||
if task_id in _active_environments:
|
||||
_last_activity[task_id] = time.time()
|
||||
return cached
|
||||
else:
|
||||
# Environment was cleaned up -- invalidate stale cache entry
|
||||
with _file_ops_lock:
|
||||
_file_ops_cache.pop(task_id, None)
|
||||
|
||||
# Need to ensure the environment exists before building file_ops.
|
||||
# Acquire per-task lock so only one thread creates the sandbox.
|
||||
with _creation_locks_lock:
|
||||
if task_id not in _creation_locks:
|
||||
_creation_locks[task_id] = threading.Lock()
|
||||
task_lock = _creation_locks[task_id]
|
||||
|
||||
with task_lock:
|
||||
# Double-check: another thread may have created it while we waited
|
||||
with _env_lock:
|
||||
if task_id in _active_environments:
|
||||
_last_activity[task_id] = time.time()
|
||||
terminal_env = _active_environments[task_id]
|
||||
else:
|
||||
terminal_env = None
|
||||
|
||||
if terminal_env is None:
|
||||
from tools.terminal_tool import _task_env_overrides
|
||||
|
||||
config = _get_env_config()
|
||||
env_type = config["env_type"]
|
||||
overrides = _task_env_overrides.get(task_id, {})
|
||||
|
||||
if env_type == "docker":
|
||||
image = overrides.get("docker_image") or config["docker_image"]
|
||||
elif env_type == "singularity":
|
||||
image = overrides.get("singularity_image") or config["singularity_image"]
|
||||
elif env_type == "modal":
|
||||
image = overrides.get("modal_image") or config["modal_image"]
|
||||
elif env_type == "daytona":
|
||||
image = overrides.get("daytona_image") or config["daytona_image"]
|
||||
else:
|
||||
image = ""
|
||||
|
||||
cwd = overrides.get("cwd") or config["cwd"]
|
||||
logger.info("Creating new %s environment for task %s...", env_type, task_id[:8])
|
||||
|
||||
container_config = None
|
||||
if env_type in ("docker", "singularity", "modal", "daytona"):
|
||||
container_config = {
|
||||
"container_cpu": config.get("container_cpu", 1),
|
||||
"container_memory": config.get("container_memory", 5120),
|
||||
"container_disk": config.get("container_disk", 51200),
|
||||
"container_persistent": config.get("container_persistent", True),
|
||||
"docker_volumes": config.get("docker_volumes", []),
|
||||
}
|
||||
|
||||
ssh_config = None
|
||||
if env_type == "ssh":
|
||||
ssh_config = {
|
||||
"host": config.get("ssh_host", ""),
|
||||
"user": config.get("ssh_user", ""),
|
||||
"port": config.get("ssh_port", 22),
|
||||
"key": config.get("ssh_key", ""),
|
||||
"persistent": config.get("ssh_persistent", False),
|
||||
}
|
||||
|
||||
local_config = None
|
||||
if env_type == "local":
|
||||
local_config = {
|
||||
"persistent": config.get("local_persistent", False),
|
||||
}
|
||||
|
||||
terminal_env = _create_environment(
|
||||
env_type=env_type,
|
||||
image=image,
|
||||
cwd=cwd,
|
||||
timeout=config["timeout"],
|
||||
ssh_config=ssh_config,
|
||||
container_config=container_config,
|
||||
local_config=local_config,
|
||||
task_id=task_id,
|
||||
host_cwd=config.get("host_cwd"),
|
||||
)
|
||||
|
||||
with _env_lock:
|
||||
_active_environments[task_id] = terminal_env
|
||||
_last_activity[task_id] = time.time()
|
||||
|
||||
_start_cleanup_thread()
|
||||
logger.info("%s environment ready for task %s", env_type, task_id[:8])
|
||||
|
||||
# Build file_ops from the (guaranteed live) environment and cache it
|
||||
file_ops = ShellFileOperations(terminal_env)
|
||||
with _file_ops_lock:
|
||||
_file_ops_cache[task_id] = file_ops
|
||||
return file_ops
|
||||
|
||||
|
||||
def clear_file_ops_cache(task_id: str = None):
|
||||
"""Clear the file operations cache."""
|
||||
with _file_ops_lock:
|
||||
if task_id:
|
||||
_file_ops_cache.pop(task_id, None)
|
||||
else:
|
||||
_file_ops_cache.clear()
|
||||
|
||||
|
||||
def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str = "default") -> str:
|
||||
"""Read a file with pagination and line numbers."""
|
||||
try:
|
||||
# Security: block direct reads of internal Hermes cache/index files
|
||||
# to prevent prompt injection via catalog or hub metadata files.
|
||||
import pathlib as _pathlib
|
||||
_resolved = _pathlib.Path(path).expanduser().resolve()
|
||||
_hermes_home = _pathlib.Path("~/.hermes").expanduser().resolve()
|
||||
_blocked_dirs = [
|
||||
_hermes_home / "skills" / ".hub" / "index-cache",
|
||||
_hermes_home / "skills" / ".hub",
|
||||
]
|
||||
for _blocked in _blocked_dirs:
|
||||
try:
|
||||
_resolved.relative_to(_blocked)
|
||||
return json.dumps({
|
||||
"error": (
|
||||
f"Access denied: {path} is an internal Hermes cache file "
|
||||
"and cannot be read directly to prevent prompt injection. "
|
||||
"Use the skills_list or skill_view tools instead."
|
||||
)
|
||||
})
|
||||
except ValueError:
|
||||
pass
|
||||
file_ops = _get_file_ops(task_id)
|
||||
result = file_ops.read_file(path, offset, limit)
|
||||
if result.content:
|
||||
result.content = redact_sensitive_text(result.content)
|
||||
result_dict = result.to_dict()
|
||||
|
||||
# Track reads to detect *consecutive* re-read loops.
|
||||
# The counter resets whenever any other tool is called in between,
|
||||
# so only truly back-to-back identical reads trigger warnings/blocks.
|
||||
read_key = ("read", path, offset, limit)
|
||||
with _read_tracker_lock:
|
||||
task_data = _read_tracker.setdefault(task_id, {
|
||||
"last_key": None, "consecutive": 0, "read_history": set(),
|
||||
})
|
||||
task_data["read_history"].add((path, offset, limit))
|
||||
if task_data["last_key"] == read_key:
|
||||
task_data["consecutive"] += 1
|
||||
else:
|
||||
task_data["last_key"] = read_key
|
||||
task_data["consecutive"] = 1
|
||||
count = task_data["consecutive"]
|
||||
|
||||
if count >= 4:
|
||||
# Hard block: stop returning content to break the loop
|
||||
return json.dumps({
|
||||
"error": (
|
||||
f"BLOCKED: You have read this exact file region {count} times in a row. "
|
||||
"The content has NOT changed. You already have this information. "
|
||||
"STOP re-reading and proceed with your task."
|
||||
),
|
||||
"path": path,
|
||||
"already_read": count,
|
||||
}, ensure_ascii=False)
|
||||
elif count >= 3:
|
||||
result_dict["_warning"] = (
|
||||
f"You have read this exact file region {count} times consecutively. "
|
||||
"The content has not changed since your last read. Use the information you already have. "
|
||||
"If you are stuck in a loop, stop reading and proceed with writing or responding."
|
||||
)
|
||||
|
||||
return json.dumps(result_dict, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||||
|
||||
|
||||
def get_read_files_summary(task_id: str = "default") -> list:
|
||||
"""Return a list of files read in this session for the given task.
|
||||
|
||||
Used by context compression to preserve file-read history across
|
||||
compression boundaries.
|
||||
"""
|
||||
with _read_tracker_lock:
|
||||
task_data = _read_tracker.get(task_id, {})
|
||||
read_history = task_data.get("read_history", set())
|
||||
seen_paths: dict = {}
|
||||
for (path, offset, limit) in read_history:
|
||||
if path not in seen_paths:
|
||||
seen_paths[path] = []
|
||||
seen_paths[path].append(f"lines {offset}-{offset + limit - 1}")
|
||||
return [
|
||||
{"path": p, "regions": regions}
|
||||
for p, regions in sorted(seen_paths.items())
|
||||
]
|
||||
|
||||
|
||||
def clear_read_tracker(task_id: str = None):
|
||||
"""Clear the read tracker.
|
||||
|
||||
Call with a task_id to clear just that task, or without to clear all.
|
||||
Should be called when a session is destroyed to prevent memory leaks
|
||||
in long-running gateway processes.
|
||||
"""
|
||||
with _read_tracker_lock:
|
||||
if task_id:
|
||||
_read_tracker.pop(task_id, None)
|
||||
else:
|
||||
_read_tracker.clear()
|
||||
|
||||
|
||||
def notify_other_tool_call(task_id: str = "default"):
|
||||
"""Reset consecutive read/search counter for a task.
|
||||
|
||||
Called by the tool dispatcher (model_tools.py) whenever a tool OTHER
|
||||
than read_file / search_files is executed. This ensures we only warn
|
||||
or block on *truly consecutive* repeated reads — if the agent does
|
||||
anything else in between (write, patch, terminal, etc.) the counter
|
||||
resets and the next read is treated as fresh.
|
||||
"""
|
||||
with _read_tracker_lock:
|
||||
task_data = _read_tracker.get(task_id)
|
||||
if task_data:
|
||||
task_data["last_key"] = None
|
||||
task_data["consecutive"] = 0
|
||||
|
||||
|
||||
def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
|
||||
"""Write content to a file."""
|
||||
try:
|
||||
file_ops = _get_file_ops(task_id)
|
||||
result = file_ops.write_file(path, content)
|
||||
return json.dumps(result.to_dict(), ensure_ascii=False)
|
||||
except Exception as e:
|
||||
if _is_expected_write_exception(e):
|
||||
logger.debug("write_file expected denial: %s: %s", type(e).__name__, e)
|
||||
else:
|
||||
logger.error("write_file error: %s: %s", type(e).__name__, e, exc_info=True)
|
||||
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||||
|
||||
|
||||
def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
|
||||
new_string: str = None, replace_all: bool = False, patch: str = None,
|
||||
task_id: str = "default") -> str:
|
||||
"""Patch a file using replace mode or V4A patch format."""
|
||||
try:
|
||||
file_ops = _get_file_ops(task_id)
|
||||
|
||||
if mode == "replace":
|
||||
if not path:
|
||||
return json.dumps({"error": "path required"})
|
||||
if old_string is None or new_string is None:
|
||||
return json.dumps({"error": "old_string and new_string required"})
|
||||
result = file_ops.patch_replace(path, old_string, new_string, replace_all)
|
||||
elif mode == "patch":
|
||||
if not patch:
|
||||
return json.dumps({"error": "patch content required"})
|
||||
result = file_ops.patch_v4a(patch)
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown mode: {mode}"})
|
||||
|
||||
result_dict = result.to_dict()
|
||||
result_json = json.dumps(result_dict, ensure_ascii=False)
|
||||
# Hint when old_string not found — saves iterations where the agent
|
||||
# retries with stale content instead of re-reading the file.
|
||||
if result_dict.get("error") and "Could not find" in str(result_dict["error"]):
|
||||
result_json += "\n\n[Hint: old_string not found. Use read_file to verify the current content, or search_files to locate the text.]"
|
||||
return result_json
|
||||
except Exception as e:
|
||||
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||||
|
||||
|
||||
def search_tool(pattern: str, target: str = "content", path: str = ".",
|
||||
file_glob: str = None, limit: int = 50, offset: int = 0,
|
||||
output_mode: str = "content", context: int = 0,
|
||||
task_id: str = "default") -> str:
|
||||
"""Search for content or files."""
|
||||
try:
|
||||
# Track searches to detect *consecutive* repeated search loops.
|
||||
# Include pagination args so users can page through truncated
|
||||
# results without tripping the repeated-search guard.
|
||||
search_key = (
|
||||
"search",
|
||||
pattern,
|
||||
target,
|
||||
str(path),
|
||||
file_glob or "",
|
||||
limit,
|
||||
offset,
|
||||
)
|
||||
with _read_tracker_lock:
|
||||
task_data = _read_tracker.setdefault(task_id, {
|
||||
"last_key": None, "consecutive": 0, "read_history": set(),
|
||||
})
|
||||
if task_data["last_key"] == search_key:
|
||||
task_data["consecutive"] += 1
|
||||
else:
|
||||
task_data["last_key"] = search_key
|
||||
task_data["consecutive"] = 1
|
||||
count = task_data["consecutive"]
|
||||
|
||||
if count >= 4:
|
||||
return json.dumps({
|
||||
"error": (
|
||||
f"BLOCKED: You have run this exact search {count} times in a row. "
|
||||
"The results have NOT changed. You already have this information. "
|
||||
"STOP re-searching and proceed with your task."
|
||||
),
|
||||
"pattern": pattern,
|
||||
"already_searched": count,
|
||||
}, ensure_ascii=False)
|
||||
|
||||
file_ops = _get_file_ops(task_id)
|
||||
result = file_ops.search(
|
||||
pattern=pattern, path=path, target=target, file_glob=file_glob,
|
||||
limit=limit, offset=offset, output_mode=output_mode, context=context
|
||||
)
|
||||
if hasattr(result, 'matches'):
|
||||
for m in result.matches:
|
||||
if hasattr(m, 'content') and m.content:
|
||||
m.content = redact_sensitive_text(m.content)
|
||||
result_dict = result.to_dict()
|
||||
|
||||
if count >= 3:
|
||||
result_dict["_warning"] = (
|
||||
f"You have run this exact search {count} times consecutively. "
|
||||
"The results have not changed. Use the information you already have."
|
||||
)
|
||||
|
||||
result_json = json.dumps(result_dict, ensure_ascii=False)
|
||||
# Hint when results were truncated — explicit next offset is clearer
|
||||
# than relying on the model to infer it from total_count vs match count.
|
||||
if result_dict.get("truncated"):
|
||||
next_offset = offset + limit
|
||||
result_json += f"\n\n[Hint: Results truncated. Use offset={next_offset} to see more, or narrow with a more specific pattern or file_glob.]"
|
||||
return result_json
|
||||
except Exception as e:
|
||||
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||||
|
||||
|
||||
FILE_TOOLS = [
|
||||
{"name": "read_file", "function": read_file_tool},
|
||||
{"name": "write_file", "function": write_file_tool},
|
||||
{"name": "patch", "function": patch_tool},
|
||||
{"name": "search_files", "function": search_tool}
|
||||
]
|
||||
|
||||
|
||||
def get_file_tools():
|
||||
"""Get the list of file tool definitions."""
|
||||
return FILE_TOOLS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schemas + Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
from tools.registry import registry
|
||||
|
||||
|
||||
def _check_file_reqs():
|
||||
"""Lazy wrapper to avoid circular import with tools/__init__.py."""
|
||||
from tools import check_file_requirements
|
||||
return check_file_requirements()
|
||||
|
||||
READ_FILE_SCHEMA = {
|
||||
"name": "read_file",
|
||||
"description": "Read a text file with line numbers and pagination. Use this instead of cat/head/tail in terminal. Output format: 'LINE_NUM|CONTENT'. Suggests similar filenames if not found. Use offset and limit for large files. NOTE: Cannot read images or binary files — use vision_analyze for images.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Path to the file to read (absolute, relative, or ~/path)"},
|
||||
"offset": {"type": "integer", "description": "Line number to start reading from (1-indexed, default: 1)", "default": 1, "minimum": 1},
|
||||
"limit": {"type": "integer", "description": "Maximum number of lines to read (default: 500, max: 2000)", "default": 500, "maximum": 2000}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
}
|
||||
|
||||
WRITE_FILE_SCHEMA = {
|
||||
"name": "write_file",
|
||||
"description": "Write content to a file, completely replacing existing content. Use this instead of echo/cat heredoc in terminal. Creates parent directories automatically. OVERWRITES the entire file — use 'patch' for targeted edits.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Path to the file to write (will be created if it doesn't exist, overwritten if it does)"},
|
||||
"content": {"type": "string", "description": "Complete content to write to the file"}
|
||||
},
|
||||
"required": ["path", "content"]
|
||||
}
|
||||
}
|
||||
|
||||
PATCH_SCHEMA = {
|
||||
"name": "patch",
|
||||
"description": "Targeted find-and-replace edits in files. Use this instead of sed/awk in terminal. Uses fuzzy matching (9 strategies) so minor whitespace/indentation differences won't break it. Returns a unified diff. Auto-runs syntax checks after editing.\n\nReplace mode (default): find a unique string and replace it.\nPatch mode: apply V4A multi-file patches for bulk changes.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"mode": {"type": "string", "enum": ["replace", "patch"], "description": "Edit mode: 'replace' for targeted find-and-replace, 'patch' for V4A multi-file patches", "default": "replace"},
|
||||
"path": {"type": "string", "description": "File path to edit (required for 'replace' mode)"},
|
||||
"old_string": {"type": "string", "description": "Text to find in the file (required for 'replace' mode). Must be unique in the file unless replace_all=true. Include enough surrounding context to ensure uniqueness."},
|
||||
"new_string": {"type": "string", "description": "Replacement text (required for 'replace' mode). Can be empty string to delete the matched text."},
|
||||
"replace_all": {"type": "boolean", "description": "Replace all occurrences instead of requiring a unique match (default: false)", "default": False},
|
||||
"patch": {"type": "string", "description": "V4A format patch content (required for 'patch' mode). Format:\n*** Begin Patch\n*** Update File: path/to/file\n@@ context hint @@\n context line\n-removed line\n+added line\n*** End Patch"}
|
||||
},
|
||||
"required": ["mode"]
|
||||
}
|
||||
}
|
||||
|
||||
SEARCH_FILES_SCHEMA = {
|
||||
"name": "search_files",
|
||||
"description": "Search file contents or find files by name. Use this instead of grep/rg/find/ls in terminal. Ripgrep-backed, faster than shell equivalents.\n\nContent search (target='content'): Regex search inside files. Output modes: full matches with line numbers, file paths only, or match counts.\n\nFile search (target='files'): Find files by glob pattern (e.g., '*.py', '*config*'). Also use this instead of ls — results sorted by modification time.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {"type": "string", "description": "Regex pattern for content search, or glob pattern (e.g., '*.py') for file search"},
|
||||
"target": {"type": "string", "enum": ["content", "files"], "description": "'content' searches inside file contents, 'files' searches for files by name", "default": "content"},
|
||||
"path": {"type": "string", "description": "Directory or file to search in (default: current working directory)", "default": "."},
|
||||
"file_glob": {"type": "string", "description": "Filter files by pattern in grep mode (e.g., '*.py' to only search Python files)"},
|
||||
"limit": {"type": "integer", "description": "Maximum number of results to return (default: 50)", "default": 50},
|
||||
"offset": {"type": "integer", "description": "Skip first N results for pagination (default: 0)", "default": 0},
|
||||
"output_mode": {"type": "string", "enum": ["content", "files_only", "count"], "description": "Output format for grep mode: 'content' shows matching lines with line numbers, 'files_only' lists file paths, 'count' shows match counts per file", "default": "content"},
|
||||
"context": {"type": "integer", "description": "Number of context lines before and after each match (grep mode only)", "default": 0}
|
||||
},
|
||||
"required": ["pattern"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _handle_read_file(args, **kw):
|
||||
tid = kw.get("task_id") or "default"
|
||||
return read_file_tool(path=args.get("path", ""), offset=args.get("offset", 1), limit=args.get("limit", 500), task_id=tid)
|
||||
|
||||
|
||||
def _handle_write_file(args, **kw):
|
||||
tid = kw.get("task_id") or "default"
|
||||
return write_file_tool(path=args.get("path", ""), content=args.get("content", ""), task_id=tid)
|
||||
|
||||
|
||||
def _handle_patch(args, **kw):
|
||||
tid = kw.get("task_id") or "default"
|
||||
return patch_tool(
|
||||
mode=args.get("mode", "replace"), path=args.get("path"),
|
||||
old_string=args.get("old_string"), new_string=args.get("new_string"),
|
||||
replace_all=args.get("replace_all", False), patch=args.get("patch"), task_id=tid)
|
||||
|
||||
|
||||
def _handle_search_files(args, **kw):
|
||||
tid = kw.get("task_id") or "default"
|
||||
target_map = {"grep": "content", "find": "files"}
|
||||
raw_target = args.get("target", "content")
|
||||
target = target_map.get(raw_target, raw_target)
|
||||
return search_tool(
|
||||
pattern=args.get("pattern", ""), target=target, path=args.get("path", "."),
|
||||
file_glob=args.get("file_glob"), limit=args.get("limit", 50), offset=args.get("offset", 0),
|
||||
output_mode=args.get("output_mode", "content"), context=args.get("context", 0), task_id=tid)
|
||||
|
||||
|
||||
registry.register(name="read_file", toolset="file", schema=READ_FILE_SCHEMA, handler=_handle_read_file, check_fn=_check_file_reqs, emoji="📖")
|
||||
registry.register(name="write_file", toolset="file", schema=WRITE_FILE_SCHEMA, handler=_handle_write_file, check_fn=_check_file_reqs, emoji="✍️")
|
||||
registry.register(name="patch", toolset="file", schema=PATCH_SCHEMA, handler=_handle_patch, check_fn=_check_file_reqs, emoji="🔧")
|
||||
registry.register(name="search_files", toolset="file", schema=SEARCH_FILES_SCHEMA, handler=_handle_search_files, check_fn=_check_file_reqs, emoji="🔎")
|
||||
487
hermes_code/tools/fuzzy_match.py
Normal file
487
hermes_code/tools/fuzzy_match.py
Normal file
|
|
@ -0,0 +1,487 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Fuzzy Matching Module for File Operations
|
||||
|
||||
Implements a multi-strategy matching chain to robustly find and replace text,
|
||||
accommodating variations in whitespace, indentation, and escaping common
|
||||
in LLM-generated code.
|
||||
|
||||
The 8-strategy chain (inspired by OpenCode), tried in order:
|
||||
1. Exact match - Direct string comparison
|
||||
2. Line-trimmed - Strip leading/trailing whitespace per line
|
||||
3. Whitespace normalized - Collapse multiple spaces/tabs to single space
|
||||
4. Indentation flexible - Ignore indentation differences entirely
|
||||
5. Escape normalized - Convert \\n literals to actual newlines
|
||||
6. Trimmed boundary - Trim first/last line whitespace only
|
||||
7. Block anchor - Match first+last lines, use similarity for middle
|
||||
8. Context-aware - 50% line similarity threshold
|
||||
|
||||
Multi-occurrence matching is handled via the replace_all flag.
|
||||
|
||||
Usage:
|
||||
from tools.fuzzy_match import fuzzy_find_and_replace
|
||||
|
||||
new_content, match_count, error = fuzzy_find_and_replace(
|
||||
content="def foo():\\n pass",
|
||||
old_string="def foo():",
|
||||
new_string="def bar():",
|
||||
replace_all=False
|
||||
)
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Tuple, Optional, List, Callable
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
UNICODE_MAP = {
|
||||
"\u201c": '"', "\u201d": '"', # smart double quotes
|
||||
"\u2018": "'", "\u2019": "'", # smart single quotes
|
||||
"\u2014": "--", "\u2013": "-", # em/en dashes
|
||||
"\u2026": "...", "\u00a0": " ", # ellipsis and non-breaking space
|
||||
}
|
||||
|
||||
def _unicode_normalize(text: str) -> str:
|
||||
"""Normalizes Unicode characters to their standard ASCII equivalents."""
|
||||
for char, repl in UNICODE_MAP.items():
|
||||
text = text.replace(char, repl)
|
||||
return text
|
||||
|
||||
|
||||
def fuzzy_find_and_replace(content: str, old_string: str, new_string: str,
|
||||
replace_all: bool = False) -> Tuple[str, int, Optional[str]]:
|
||||
"""
|
||||
Find and replace text using a chain of increasingly fuzzy matching strategies.
|
||||
|
||||
Args:
|
||||
content: The file content to search in
|
||||
old_string: The text to find
|
||||
new_string: The replacement text
|
||||
replace_all: If True, replace all occurrences; if False, require uniqueness
|
||||
|
||||
Returns:
|
||||
Tuple of (new_content, match_count, error_message)
|
||||
- If successful: (modified_content, number_of_replacements, None)
|
||||
- If failed: (original_content, 0, error_description)
|
||||
"""
|
||||
if not old_string:
|
||||
return content, 0, "old_string cannot be empty"
|
||||
|
||||
if old_string == new_string:
|
||||
return content, 0, "old_string and new_string are identical"
|
||||
|
||||
# Try each matching strategy in order
|
||||
strategies: List[Tuple[str, Callable]] = [
|
||||
("exact", _strategy_exact),
|
||||
("line_trimmed", _strategy_line_trimmed),
|
||||
("whitespace_normalized", _strategy_whitespace_normalized),
|
||||
("indentation_flexible", _strategy_indentation_flexible),
|
||||
("escape_normalized", _strategy_escape_normalized),
|
||||
("trimmed_boundary", _strategy_trimmed_boundary),
|
||||
("block_anchor", _strategy_block_anchor),
|
||||
("context_aware", _strategy_context_aware),
|
||||
]
|
||||
|
||||
for strategy_name, strategy_fn in strategies:
|
||||
matches = strategy_fn(content, old_string)
|
||||
|
||||
if matches:
|
||||
# Found matches with this strategy
|
||||
if len(matches) > 1 and not replace_all:
|
||||
return content, 0, (
|
||||
f"Found {len(matches)} matches for old_string. "
|
||||
f"Provide more context to make it unique, or use replace_all=True."
|
||||
)
|
||||
|
||||
# Perform replacement
|
||||
new_content = _apply_replacements(content, matches, new_string)
|
||||
return new_content, len(matches), None
|
||||
|
||||
# No strategy found a match
|
||||
return content, 0, "Could not find a match for old_string in the file"
|
||||
|
||||
|
||||
def _apply_replacements(content: str, matches: List[Tuple[int, int]], new_string: str) -> str:
|
||||
"""
|
||||
Apply replacements at the given positions.
|
||||
|
||||
Args:
|
||||
content: Original content
|
||||
matches: List of (start, end) positions to replace
|
||||
new_string: Replacement text
|
||||
|
||||
Returns:
|
||||
Content with replacements applied
|
||||
"""
|
||||
# Sort matches by position (descending) to replace from end to start
|
||||
# This preserves positions of earlier matches
|
||||
sorted_matches = sorted(matches, key=lambda x: x[0], reverse=True)
|
||||
|
||||
result = content
|
||||
for start, end in sorted_matches:
|
||||
result = result[:start] + new_string + result[end:]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Matching Strategies
|
||||
# =============================================================================
|
||||
|
||||
def _strategy_exact(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""Strategy 1: Exact string match."""
|
||||
matches = []
|
||||
start = 0
|
||||
while True:
|
||||
pos = content.find(pattern, start)
|
||||
if pos == -1:
|
||||
break
|
||||
matches.append((pos, pos + len(pattern)))
|
||||
start = pos + 1
|
||||
return matches
|
||||
|
||||
|
||||
def _strategy_line_trimmed(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 2: Match with line-by-line whitespace trimming.
|
||||
|
||||
Strips leading/trailing whitespace from each line before matching.
|
||||
"""
|
||||
# Normalize pattern and content by trimming each line
|
||||
pattern_lines = [line.strip() for line in pattern.split('\n')]
|
||||
pattern_normalized = '\n'.join(pattern_lines)
|
||||
|
||||
content_lines = content.split('\n')
|
||||
content_normalized_lines = [line.strip() for line in content_lines]
|
||||
|
||||
# Build mapping from normalized positions back to original positions
|
||||
return _find_normalized_matches(
|
||||
content, content_lines, content_normalized_lines,
|
||||
pattern, pattern_normalized
|
||||
)
|
||||
|
||||
|
||||
def _strategy_whitespace_normalized(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 3: Collapse multiple whitespace to single space.
|
||||
"""
|
||||
def normalize(s):
|
||||
# Collapse multiple spaces/tabs to single space, preserve newlines
|
||||
return re.sub(r'[ \t]+', ' ', s)
|
||||
|
||||
pattern_normalized = normalize(pattern)
|
||||
content_normalized = normalize(content)
|
||||
|
||||
# Find in normalized, map back to original
|
||||
matches_in_normalized = _strategy_exact(content_normalized, pattern_normalized)
|
||||
|
||||
if not matches_in_normalized:
|
||||
return []
|
||||
|
||||
# Map positions back to original content
|
||||
return _map_normalized_positions(content, content_normalized, matches_in_normalized)
|
||||
|
||||
|
||||
def _strategy_indentation_flexible(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 4: Ignore indentation differences entirely.
|
||||
|
||||
Strips all leading whitespace from lines before matching.
|
||||
"""
|
||||
def strip_indent(s):
|
||||
return '\n'.join(line.lstrip() for line in s.split('\n'))
|
||||
|
||||
pattern_stripped = strip_indent(pattern)
|
||||
|
||||
content_lines = content.split('\n')
|
||||
content_stripped_lines = [line.lstrip() for line in content_lines]
|
||||
pattern_lines = [line.lstrip() for line in pattern.split('\n')]
|
||||
|
||||
return _find_normalized_matches(
|
||||
content, content_lines, content_stripped_lines,
|
||||
pattern, '\n'.join(pattern_lines)
|
||||
)
|
||||
|
||||
|
||||
def _strategy_escape_normalized(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 5: Convert escape sequences to actual characters.
|
||||
|
||||
Handles \\n -> newline, \\t -> tab, etc.
|
||||
"""
|
||||
def unescape(s):
|
||||
# Convert common escape sequences
|
||||
return s.replace('\\n', '\n').replace('\\t', '\t').replace('\\r', '\r')
|
||||
|
||||
pattern_unescaped = unescape(pattern)
|
||||
|
||||
if pattern_unescaped == pattern:
|
||||
# No escapes to convert, skip this strategy
|
||||
return []
|
||||
|
||||
return _strategy_exact(content, pattern_unescaped)
|
||||
|
||||
|
||||
def _strategy_trimmed_boundary(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 6: Trim whitespace from first and last lines only.
|
||||
|
||||
Useful when the pattern boundaries have whitespace differences.
|
||||
"""
|
||||
pattern_lines = pattern.split('\n')
|
||||
if not pattern_lines:
|
||||
return []
|
||||
|
||||
# Trim only first and last lines
|
||||
pattern_lines[0] = pattern_lines[0].strip()
|
||||
if len(pattern_lines) > 1:
|
||||
pattern_lines[-1] = pattern_lines[-1].strip()
|
||||
|
||||
modified_pattern = '\n'.join(pattern_lines)
|
||||
|
||||
content_lines = content.split('\n')
|
||||
|
||||
# Search through content for matching block
|
||||
matches = []
|
||||
pattern_line_count = len(pattern_lines)
|
||||
|
||||
for i in range(len(content_lines) - pattern_line_count + 1):
|
||||
block_lines = content_lines[i:i + pattern_line_count]
|
||||
|
||||
# Trim first and last of this block
|
||||
check_lines = block_lines.copy()
|
||||
check_lines[0] = check_lines[0].strip()
|
||||
if len(check_lines) > 1:
|
||||
check_lines[-1] = check_lines[-1].strip()
|
||||
|
||||
if '\n'.join(check_lines) == modified_pattern:
|
||||
# Found match - calculate original positions
|
||||
start_pos, end_pos = _calculate_line_positions(
|
||||
content_lines, i, i + pattern_line_count, len(content)
|
||||
)
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def _strategy_block_anchor(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 7: Match by anchoring on first and last lines.
|
||||
Adjusted with permissive thresholds and unicode normalization.
|
||||
"""
|
||||
# Normalize both strings for comparison while keeping original content for offset calculation
|
||||
norm_pattern = _unicode_normalize(pattern)
|
||||
norm_content = _unicode_normalize(content)
|
||||
|
||||
pattern_lines = norm_pattern.split('\n')
|
||||
if len(pattern_lines) < 2:
|
||||
return []
|
||||
|
||||
first_line = pattern_lines[0].strip()
|
||||
last_line = pattern_lines[-1].strip()
|
||||
|
||||
# Use normalized lines for matching logic
|
||||
norm_content_lines = norm_content.split('\n')
|
||||
# BUT use original lines for calculating start/end positions to prevent index shift
|
||||
orig_content_lines = content.split('\n')
|
||||
|
||||
pattern_line_count = len(pattern_lines)
|
||||
|
||||
potential_matches = []
|
||||
for i in range(len(norm_content_lines) - pattern_line_count + 1):
|
||||
if (norm_content_lines[i].strip() == first_line and
|
||||
norm_content_lines[i + pattern_line_count - 1].strip() == last_line):
|
||||
potential_matches.append(i)
|
||||
|
||||
matches = []
|
||||
candidate_count = len(potential_matches)
|
||||
|
||||
# Thresholding logic: 0.10 for unique matches (max flexibility), 0.30 for multiple candidates
|
||||
threshold = 0.10 if candidate_count == 1 else 0.30
|
||||
|
||||
for i in potential_matches:
|
||||
if pattern_line_count <= 2:
|
||||
similarity = 1.0
|
||||
else:
|
||||
# Compare normalized middle sections
|
||||
content_middle = '\n'.join(norm_content_lines[i+1:i+pattern_line_count-1])
|
||||
pattern_middle = '\n'.join(pattern_lines[1:-1])
|
||||
similarity = SequenceMatcher(None, content_middle, pattern_middle).ratio()
|
||||
|
||||
if similarity >= threshold:
|
||||
# Calculate positions using ORIGINAL lines to ensure correct character offsets in the file
|
||||
start_pos, end_pos = _calculate_line_positions(
|
||||
orig_content_lines, i, i + pattern_line_count, len(content)
|
||||
)
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 8: Line-by-line similarity with 50% threshold.
|
||||
|
||||
Finds blocks where at least 50% of lines have high similarity.
|
||||
"""
|
||||
pattern_lines = pattern.split('\n')
|
||||
content_lines = content.split('\n')
|
||||
|
||||
if not pattern_lines:
|
||||
return []
|
||||
|
||||
matches = []
|
||||
pattern_line_count = len(pattern_lines)
|
||||
|
||||
for i in range(len(content_lines) - pattern_line_count + 1):
|
||||
block_lines = content_lines[i:i + pattern_line_count]
|
||||
|
||||
# Calculate line-by-line similarity
|
||||
high_similarity_count = 0
|
||||
for p_line, c_line in zip(pattern_lines, block_lines):
|
||||
sim = SequenceMatcher(None, p_line.strip(), c_line.strip()).ratio()
|
||||
if sim >= 0.80:
|
||||
high_similarity_count += 1
|
||||
|
||||
# Need at least 50% of lines to have high similarity
|
||||
if high_similarity_count >= len(pattern_lines) * 0.5:
|
||||
start_pos, end_pos = _calculate_line_positions(
|
||||
content_lines, i, i + pattern_line_count, len(content)
|
||||
)
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
def _calculate_line_positions(content_lines: List[str], start_line: int,
|
||||
end_line: int, content_length: int) -> Tuple[int, int]:
|
||||
"""Calculate start and end character positions from line indices.
|
||||
|
||||
Args:
|
||||
content_lines: List of lines (without newlines)
|
||||
start_line: Starting line index (0-based)
|
||||
end_line: Ending line index (exclusive, 0-based)
|
||||
content_length: Total length of the original content string
|
||||
|
||||
Returns:
|
||||
Tuple of (start_pos, end_pos) in the original content
|
||||
"""
|
||||
start_pos = sum(len(line) + 1 for line in content_lines[:start_line])
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[:end_line]) - 1
|
||||
if end_pos >= content_length:
|
||||
end_pos = content_length
|
||||
return start_pos, end_pos
|
||||
|
||||
|
||||
def _find_normalized_matches(content: str, content_lines: List[str],
|
||||
content_normalized_lines: List[str],
|
||||
pattern: str, pattern_normalized: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Find matches in normalized content and map back to original positions.
|
||||
|
||||
Args:
|
||||
content: Original content string
|
||||
content_lines: Original content split by lines
|
||||
content_normalized_lines: Normalized content lines
|
||||
pattern: Original pattern
|
||||
pattern_normalized: Normalized pattern
|
||||
|
||||
Returns:
|
||||
List of (start, end) positions in the original content
|
||||
"""
|
||||
pattern_norm_lines = pattern_normalized.split('\n')
|
||||
num_pattern_lines = len(pattern_norm_lines)
|
||||
|
||||
matches = []
|
||||
|
||||
for i in range(len(content_normalized_lines) - num_pattern_lines + 1):
|
||||
# Check if this block matches
|
||||
block = '\n'.join(content_normalized_lines[i:i + num_pattern_lines])
|
||||
|
||||
if block == pattern_normalized:
|
||||
# Found a match - calculate original positions
|
||||
start_pos, end_pos = _calculate_line_positions(
|
||||
content_lines, i, i + num_pattern_lines, len(content)
|
||||
)
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def _map_normalized_positions(original: str, normalized: str,
|
||||
normalized_matches: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Map positions from normalized string back to original.
|
||||
|
||||
This is a best-effort mapping that works for whitespace normalization.
|
||||
"""
|
||||
if not normalized_matches:
|
||||
return []
|
||||
|
||||
# Build character mapping from normalized to original
|
||||
orig_to_norm = [] # orig_to_norm[i] = position in normalized
|
||||
|
||||
orig_idx = 0
|
||||
norm_idx = 0
|
||||
|
||||
while orig_idx < len(original) and norm_idx < len(normalized):
|
||||
if original[orig_idx] == normalized[norm_idx]:
|
||||
orig_to_norm.append(norm_idx)
|
||||
orig_idx += 1
|
||||
norm_idx += 1
|
||||
elif original[orig_idx] in ' \t' and normalized[norm_idx] == ' ':
|
||||
# Original has space/tab, normalized collapsed to space
|
||||
orig_to_norm.append(norm_idx)
|
||||
orig_idx += 1
|
||||
# Don't advance norm_idx yet - wait until all whitespace consumed
|
||||
if orig_idx < len(original) and original[orig_idx] not in ' \t':
|
||||
norm_idx += 1
|
||||
elif original[orig_idx] in ' \t':
|
||||
# Extra whitespace in original
|
||||
orig_to_norm.append(norm_idx)
|
||||
orig_idx += 1
|
||||
else:
|
||||
# Mismatch - shouldn't happen with our normalization
|
||||
orig_to_norm.append(norm_idx)
|
||||
orig_idx += 1
|
||||
|
||||
# Fill remaining
|
||||
while orig_idx < len(original):
|
||||
orig_to_norm.append(len(normalized))
|
||||
orig_idx += 1
|
||||
|
||||
# Reverse mapping: for each normalized position, find original range
|
||||
norm_to_orig_start = {}
|
||||
norm_to_orig_end = {}
|
||||
|
||||
for orig_pos, norm_pos in enumerate(orig_to_norm):
|
||||
if norm_pos not in norm_to_orig_start:
|
||||
norm_to_orig_start[norm_pos] = orig_pos
|
||||
norm_to_orig_end[norm_pos] = orig_pos
|
||||
|
||||
# Map matches
|
||||
original_matches = []
|
||||
for norm_start, norm_end in normalized_matches:
|
||||
# Find original start
|
||||
if norm_start in norm_to_orig_start:
|
||||
orig_start = norm_to_orig_start[norm_start]
|
||||
else:
|
||||
# Find nearest
|
||||
orig_start = min(i for i, n in enumerate(orig_to_norm) if n >= norm_start)
|
||||
|
||||
# Find original end
|
||||
if norm_end - 1 in norm_to_orig_end:
|
||||
orig_end = norm_to_orig_end[norm_end - 1] + 1
|
||||
else:
|
||||
orig_end = orig_start + (norm_end - norm_start)
|
||||
|
||||
# Expand to include trailing whitespace that was normalized
|
||||
while orig_end < len(original) and original[orig_end] in ' \t':
|
||||
orig_end += 1
|
||||
|
||||
original_matches.append((orig_start, min(orig_end, len(original))))
|
||||
|
||||
return original_matches
|
||||
490
hermes_code/tools/homeassistant_tool.py
Normal file
490
hermes_code/tools/homeassistant_tool.py
Normal file
|
|
@ -0,0 +1,490 @@
|
|||
"""Home Assistant tool for controlling smart home devices via REST API.
|
||||
|
||||
Registers four LLM-callable tools:
|
||||
- ``ha_list_entities`` -- list/filter entities by domain or area
|
||||
- ``ha_get_state`` -- get detailed state of a single entity
|
||||
- ``ha_list_services`` -- list available services (actions) per domain
|
||||
- ``ha_call_service`` -- call a HA service (turn_on, turn_off, set_temperature, etc.)
|
||||
|
||||
Authentication uses a Long-Lived Access Token via ``HASS_TOKEN`` env var.
|
||||
The HA instance URL is read from ``HASS_URL`` (default: http://homeassistant.local:8123).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Kept for backward compatibility (e.g. test monkeypatching); prefer _get_config().
|
||||
_HASS_URL: str = ""
|
||||
_HASS_TOKEN: str = ""
|
||||
|
||||
|
||||
def _get_config():
|
||||
"""Return (hass_url, hass_token) from env vars at call time."""
|
||||
return (
|
||||
(_HASS_URL or os.getenv("HASS_URL", "http://homeassistant.local:8123")).rstrip("/"),
|
||||
_HASS_TOKEN or os.getenv("HASS_TOKEN", ""),
|
||||
)
|
||||
|
||||
# Regex for valid HA entity_id format (e.g. "light.living_room", "sensor.temperature_1")
|
||||
_ENTITY_ID_RE = re.compile(r"^[a-z_][a-z0-9_]*\.[a-z0-9_]+$")
|
||||
|
||||
# Service domains blocked for security -- these allow arbitrary code/command
|
||||
# execution on the HA host or enable SSRF attacks on the local network.
|
||||
# HA provides zero service-level access control; all safety must be in our layer.
|
||||
_BLOCKED_DOMAINS = frozenset({
|
||||
"shell_command", # arbitrary shell commands as root in HA container
|
||||
"command_line", # sensors/switches that execute shell commands
|
||||
"python_script", # sandboxed but can escalate via hass.services.call()
|
||||
"pyscript", # scripting integration with broader access
|
||||
"hassio", # addon control, host shutdown/reboot, stdin to containers
|
||||
"rest_command", # HTTP requests from HA server (SSRF vector)
|
||||
})
|
||||
|
||||
|
||||
def _get_headers(token: str = "") -> Dict[str, str]:
|
||||
"""Return authorization headers for HA REST API."""
|
||||
if not token:
|
||||
_, token = _get_config()
|
||||
return {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Async helpers (called from sync handlers via run_until_complete)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _filter_and_summarize(
|
||||
states: list,
|
||||
domain: Optional[str] = None,
|
||||
area: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Filter raw HA states by domain/area and return a compact summary."""
|
||||
if domain:
|
||||
states = [s for s in states if s.get("entity_id", "").startswith(f"{domain}.")]
|
||||
|
||||
if area:
|
||||
area_lower = area.lower()
|
||||
states = [
|
||||
s for s in states
|
||||
if area_lower in (s.get("attributes", {}).get("friendly_name", "") or "").lower()
|
||||
or area_lower in (s.get("attributes", {}).get("area", "") or "").lower()
|
||||
]
|
||||
|
||||
entities = []
|
||||
for s in states:
|
||||
entities.append({
|
||||
"entity_id": s["entity_id"],
|
||||
"state": s["state"],
|
||||
"friendly_name": s.get("attributes", {}).get("friendly_name", ""),
|
||||
})
|
||||
|
||||
return {"count": len(entities), "entities": entities}
|
||||
|
||||
|
||||
async def _async_list_entities(
|
||||
domain: Optional[str] = None,
|
||||
area: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Fetch entity states from HA and optionally filter by domain/area."""
|
||||
import aiohttp
|
||||
|
||||
hass_url, hass_token = _get_config()
|
||||
url = f"{hass_url}/api/states"
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=_get_headers(hass_token), timeout=aiohttp.ClientTimeout(total=15)) as resp:
|
||||
resp.raise_for_status()
|
||||
states = await resp.json()
|
||||
|
||||
return _filter_and_summarize(states, domain, area)
|
||||
|
||||
|
||||
async def _async_get_state(entity_id: str) -> Dict[str, Any]:
|
||||
"""Fetch detailed state of a single entity."""
|
||||
import aiohttp
|
||||
|
||||
hass_url, hass_token = _get_config()
|
||||
url = f"{hass_url}/api/states/{entity_id}"
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=_get_headers(hass_token), timeout=aiohttp.ClientTimeout(total=10)) as resp:
|
||||
resp.raise_for_status()
|
||||
data = await resp.json()
|
||||
|
||||
return {
|
||||
"entity_id": data["entity_id"],
|
||||
"state": data["state"],
|
||||
"attributes": data.get("attributes", {}),
|
||||
"last_changed": data.get("last_changed"),
|
||||
"last_updated": data.get("last_updated"),
|
||||
}
|
||||
|
||||
|
||||
def _build_service_payload(
|
||||
entity_id: Optional[str] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build the JSON payload for a HA service call."""
|
||||
payload: Dict[str, Any] = {}
|
||||
if data:
|
||||
payload.update(data)
|
||||
# entity_id parameter takes precedence over data["entity_id"]
|
||||
if entity_id:
|
||||
payload["entity_id"] = entity_id
|
||||
return payload
|
||||
|
||||
|
||||
def _parse_service_response(
|
||||
domain: str,
|
||||
service: str,
|
||||
result: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""Parse HA service call response into a structured result."""
|
||||
affected = []
|
||||
if isinstance(result, list):
|
||||
for s in result:
|
||||
affected.append({
|
||||
"entity_id": s.get("entity_id", ""),
|
||||
"state": s.get("state", ""),
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"service": f"{domain}.{service}",
|
||||
"affected_entities": affected,
|
||||
}
|
||||
|
||||
|
||||
async def _async_call_service(
|
||||
domain: str,
|
||||
service: str,
|
||||
entity_id: Optional[str] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Call a Home Assistant service."""
|
||||
import aiohttp
|
||||
|
||||
hass_url, hass_token = _get_config()
|
||||
url = f"{hass_url}/api/services/{domain}/{service}"
|
||||
payload = _build_service_payload(entity_id, data)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url,
|
||||
headers=_get_headers(hass_token),
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=15),
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
result = await resp.json()
|
||||
|
||||
return _parse_service_response(domain, service, result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync wrappers (handler signature: (args, **kw) -> str)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _run_async(coro):
|
||||
"""Run an async coroutine from a sync handler."""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
if loop and loop.is_running():
|
||||
# Already inside an event loop -- create a new thread
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(asyncio.run, coro)
|
||||
return future.result(timeout=30)
|
||||
else:
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _handle_list_entities(args: dict, **kw) -> str:
|
||||
"""Handler for ha_list_entities tool."""
|
||||
domain = args.get("domain")
|
||||
area = args.get("area")
|
||||
try:
|
||||
result = _run_async(_async_list_entities(domain=domain, area=area))
|
||||
return json.dumps({"result": result})
|
||||
except Exception as e:
|
||||
logger.error("ha_list_entities error: %s", e)
|
||||
return json.dumps({"error": f"Failed to list entities: {e}"})
|
||||
|
||||
|
||||
def _handle_get_state(args: dict, **kw) -> str:
|
||||
"""Handler for ha_get_state tool."""
|
||||
entity_id = args.get("entity_id", "")
|
||||
if not entity_id:
|
||||
return json.dumps({"error": "Missing required parameter: entity_id"})
|
||||
if not _ENTITY_ID_RE.match(entity_id):
|
||||
return json.dumps({"error": f"Invalid entity_id format: {entity_id}"})
|
||||
try:
|
||||
result = _run_async(_async_get_state(entity_id))
|
||||
return json.dumps({"result": result})
|
||||
except Exception as e:
|
||||
logger.error("ha_get_state error: %s", e)
|
||||
return json.dumps({"error": f"Failed to get state for {entity_id}: {e}"})
|
||||
|
||||
|
||||
def _handle_call_service(args: dict, **kw) -> str:
|
||||
"""Handler for ha_call_service tool."""
|
||||
domain = args.get("domain", "")
|
||||
service = args.get("service", "")
|
||||
if not domain or not service:
|
||||
return json.dumps({"error": "Missing required parameters: domain and service"})
|
||||
|
||||
if domain in _BLOCKED_DOMAINS:
|
||||
return json.dumps({
|
||||
"error": f"Service domain '{domain}' is blocked for security. "
|
||||
f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}"
|
||||
})
|
||||
|
||||
entity_id = args.get("entity_id")
|
||||
if entity_id and not _ENTITY_ID_RE.match(entity_id):
|
||||
return json.dumps({"error": f"Invalid entity_id format: {entity_id}"})
|
||||
|
||||
data = args.get("data")
|
||||
try:
|
||||
result = _run_async(_async_call_service(domain, service, entity_id, data))
|
||||
return json.dumps({"result": result})
|
||||
except Exception as e:
|
||||
logger.error("ha_call_service error: %s", e)
|
||||
return json.dumps({"error": f"Failed to call {domain}.{service}: {e}"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# List services
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _async_list_services(domain: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Fetch available services from HA and optionally filter by domain."""
|
||||
import aiohttp
|
||||
|
||||
hass_url, hass_token = _get_config()
|
||||
url = f"{hass_url}/api/services"
|
||||
headers = {"Authorization": f"Bearer {hass_token}", "Content-Type": "application/json"}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=15)) as resp:
|
||||
resp.raise_for_status()
|
||||
services = await resp.json()
|
||||
|
||||
if domain:
|
||||
services = [s for s in services if s.get("domain") == domain]
|
||||
|
||||
# Compact the output for context efficiency
|
||||
result = []
|
||||
for svc_domain in services:
|
||||
d = svc_domain.get("domain", "")
|
||||
domain_services = {}
|
||||
for svc_name, svc_info in svc_domain.get("services", {}).items():
|
||||
svc_entry: Dict[str, Any] = {"description": svc_info.get("description", "")}
|
||||
fields = svc_info.get("fields", {})
|
||||
if fields:
|
||||
svc_entry["fields"] = {
|
||||
k: v.get("description", "") for k, v in fields.items()
|
||||
if isinstance(v, dict)
|
||||
}
|
||||
domain_services[svc_name] = svc_entry
|
||||
result.append({"domain": d, "services": domain_services})
|
||||
|
||||
return {"count": len(result), "domains": result}
|
||||
|
||||
|
||||
def _handle_list_services(args: dict, **kw) -> str:
|
||||
"""Handler for ha_list_services tool."""
|
||||
domain = args.get("domain")
|
||||
try:
|
||||
result = _run_async(_async_list_services(domain=domain))
|
||||
return json.dumps({"result": result})
|
||||
except Exception as e:
|
||||
logger.error("ha_list_services error: %s", e)
|
||||
return json.dumps({"error": f"Failed to list services: {e}"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Availability check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _check_ha_available() -> bool:
|
||||
"""Tool is only available when HASS_TOKEN is set."""
|
||||
return bool(os.getenv("HASS_TOKEN"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
HA_LIST_ENTITIES_SCHEMA = {
|
||||
"name": "ha_list_entities",
|
||||
"description": (
|
||||
"List Home Assistant entities. Optionally filter by domain "
|
||||
"(light, switch, climate, sensor, binary_sensor, cover, fan, etc.) "
|
||||
"or by area name (living room, kitchen, bedroom, etc.)."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"domain": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Entity domain to filter by (e.g. 'light', 'switch', 'climate', "
|
||||
"'sensor', 'binary_sensor', 'cover', 'fan', 'media_player'). "
|
||||
"Omit to list all entities."
|
||||
),
|
||||
},
|
||||
"area": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Area/room name to filter by (e.g. 'living room', 'kitchen'). "
|
||||
"Matches against entity friendly names. Omit to list all."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
}
|
||||
|
||||
HA_GET_STATE_SCHEMA = {
|
||||
"name": "ha_get_state",
|
||||
"description": (
|
||||
"Get the detailed state of a single Home Assistant entity, including all "
|
||||
"attributes (brightness, color, temperature setpoint, sensor readings, etc.)."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entity_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The entity ID to query (e.g. 'light.living_room', "
|
||||
"'climate.thermostat', 'sensor.temperature')."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["entity_id"],
|
||||
},
|
||||
}
|
||||
|
||||
HA_LIST_SERVICES_SCHEMA = {
|
||||
"name": "ha_list_services",
|
||||
"description": (
|
||||
"List available Home Assistant services (actions) for device control. "
|
||||
"Shows what actions can be performed on each device type and what "
|
||||
"parameters they accept. Use this to discover how to control devices "
|
||||
"found via ha_list_entities."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"domain": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Filter by domain (e.g. 'light', 'climate', 'switch'). "
|
||||
"Omit to list services for all domains."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
}
|
||||
|
||||
HA_CALL_SERVICE_SCHEMA = {
|
||||
"name": "ha_call_service",
|
||||
"description": (
|
||||
"Call a Home Assistant service to control a device. Use ha_list_services "
|
||||
"to discover available services and their parameters for each domain."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"domain": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Service domain (e.g. 'light', 'switch', 'climate', "
|
||||
"'cover', 'media_player', 'fan', 'scene', 'script')."
|
||||
),
|
||||
},
|
||||
"service": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Service name (e.g. 'turn_on', 'turn_off', 'toggle', "
|
||||
"'set_temperature', 'set_hvac_mode', 'open_cover', "
|
||||
"'close_cover', 'set_volume_level')."
|
||||
),
|
||||
},
|
||||
"entity_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Target entity ID (e.g. 'light.living_room'). "
|
||||
"Some services (like scene.turn_on) may not need this."
|
||||
),
|
||||
},
|
||||
"data": {
|
||||
"type": "object",
|
||||
"description": (
|
||||
"Additional service data. Examples: "
|
||||
'{"brightness": 255, "color_name": "blue"} for lights, '
|
||||
'{"temperature": 22, "hvac_mode": "heat"} for climate, '
|
||||
'{"volume_level": 0.5} for media players.'
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["domain", "service"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from tools.registry import registry
|
||||
|
||||
registry.register(
|
||||
name="ha_list_entities",
|
||||
toolset="homeassistant",
|
||||
schema=HA_LIST_ENTITIES_SCHEMA,
|
||||
handler=_handle_list_entities,
|
||||
check_fn=_check_ha_available,
|
||||
emoji="🏠",
|
||||
)
|
||||
|
||||
registry.register(
|
||||
name="ha_get_state",
|
||||
toolset="homeassistant",
|
||||
schema=HA_GET_STATE_SCHEMA,
|
||||
handler=_handle_get_state,
|
||||
check_fn=_check_ha_available,
|
||||
emoji="🏠",
|
||||
)
|
||||
|
||||
registry.register(
|
||||
name="ha_list_services",
|
||||
toolset="homeassistant",
|
||||
schema=HA_LIST_SERVICES_SCHEMA,
|
||||
handler=_handle_list_services,
|
||||
check_fn=_check_ha_available,
|
||||
emoji="🏠",
|
||||
)
|
||||
|
||||
registry.register(
|
||||
name="ha_call_service",
|
||||
toolset="homeassistant",
|
||||
schema=HA_CALL_SERVICE_SCHEMA,
|
||||
handler=_handle_call_service,
|
||||
check_fn=_check_ha_available,
|
||||
emoji="🏠",
|
||||
)
|
||||
264
hermes_code/tools/honcho_tools.py
Normal file
264
hermes_code/tools/honcho_tools.py
Normal file
|
|
@ -0,0 +1,264 @@
|
|||
"""Honcho tools for user context retrieval.
|
||||
|
||||
Registers three complementary tools, ordered by capability:
|
||||
|
||||
honcho_context — dialectic Q&A (LLM-powered, direct answers)
|
||||
honcho_search — semantic search (fast, no LLM, raw excerpts)
|
||||
honcho_profile — peer card (fast, no LLM, structured facts)
|
||||
|
||||
Use honcho_context when you need Honcho to synthesize an answer.
|
||||
Use honcho_search or honcho_profile when you want raw data to reason
|
||||
over yourself.
|
||||
|
||||
The session key is injected at runtime by the agent loop via
|
||||
``set_session_context()``.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Module-level state (injected by AIAgent at init time) ──
|
||||
|
||||
_session_manager = None # HonchoSessionManager instance
|
||||
_session_key: str | None = None # Current session key (e.g., "telegram:123456")
|
||||
|
||||
|
||||
def set_session_context(session_manager, session_key: str) -> None:
|
||||
"""Register the active Honcho session manager and key.
|
||||
|
||||
Called by AIAgent.__init__ when Honcho is enabled.
|
||||
"""
|
||||
global _session_manager, _session_key
|
||||
_session_manager = session_manager
|
||||
_session_key = session_key
|
||||
|
||||
|
||||
def clear_session_context() -> None:
|
||||
"""Clear session context (for testing or shutdown)."""
|
||||
global _session_manager, _session_key
|
||||
_session_manager = None
|
||||
_session_key = None
|
||||
|
||||
|
||||
# ── Availability check ──
|
||||
|
||||
def _check_honcho_available() -> bool:
|
||||
"""Tool is only available when Honcho is active."""
|
||||
return _session_manager is not None and _session_key is not None
|
||||
|
||||
|
||||
def _resolve_session_context(**kwargs):
|
||||
"""Prefer the calling agent's session context over module-global fallback."""
|
||||
session_manager = kwargs.get("honcho_manager") or _session_manager
|
||||
session_key = kwargs.get("honcho_session_key") or _session_key
|
||||
return session_manager, session_key
|
||||
|
||||
|
||||
# ── honcho_profile ──
|
||||
|
||||
_PROFILE_SCHEMA = {
|
||||
"name": "honcho_profile",
|
||||
"description": (
|
||||
"Retrieve the user's peer card from Honcho — a curated list of key facts "
|
||||
"about them (name, role, preferences, communication style, patterns). "
|
||||
"Fast, no LLM reasoning, minimal cost. "
|
||||
"Use this at conversation start or when you need a quick factual snapshot. "
|
||||
"Use honcho_context instead when you need Honcho to synthesize an answer."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _handle_honcho_profile(args: dict, **kw) -> str:
|
||||
session_manager, session_key = _resolve_session_context(**kw)
|
||||
if not session_manager or not session_key:
|
||||
return json.dumps({"error": "Honcho is not active for this session."})
|
||||
try:
|
||||
card = session_manager.get_peer_card(session_key)
|
||||
if not card:
|
||||
return json.dumps({"result": "No profile facts available yet. The user's profile builds over time through conversations."})
|
||||
return json.dumps({"result": card})
|
||||
except Exception as e:
|
||||
logger.error("Error fetching Honcho peer card: %s", e)
|
||||
return json.dumps({"error": f"Failed to fetch profile: {e}"})
|
||||
|
||||
|
||||
# ── honcho_search ──
|
||||
|
||||
_SEARCH_SCHEMA = {
|
||||
"name": "honcho_search",
|
||||
"description": (
|
||||
"Semantic search over Honcho's stored context about the user. "
|
||||
"Returns raw excerpts ranked by relevance to your query — no LLM synthesis. "
|
||||
"Cheaper and faster than honcho_context. "
|
||||
"Good when you want to find specific past facts and reason over them yourself. "
|
||||
"Use honcho_context when you need a direct synthesized answer."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "What to search for in Honcho's memory (e.g. 'programming languages', 'past projects', 'timezone').",
|
||||
},
|
||||
"max_tokens": {
|
||||
"type": "integer",
|
||||
"description": "Token budget for returned context (default 800, max 2000).",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _handle_honcho_search(args: dict, **kw) -> str:
|
||||
query = args.get("query", "")
|
||||
if not query:
|
||||
return json.dumps({"error": "Missing required parameter: query"})
|
||||
session_manager, session_key = _resolve_session_context(**kw)
|
||||
if not session_manager or not session_key:
|
||||
return json.dumps({"error": "Honcho is not active for this session."})
|
||||
max_tokens = min(int(args.get("max_tokens", 800)), 2000)
|
||||
try:
|
||||
result = session_manager.search_context(session_key, query, max_tokens=max_tokens)
|
||||
if not result:
|
||||
return json.dumps({"result": "No relevant context found."})
|
||||
return json.dumps({"result": result})
|
||||
except Exception as e:
|
||||
logger.error("Error searching Honcho context: %s", e)
|
||||
return json.dumps({"error": f"Failed to search context: {e}"})
|
||||
|
||||
|
||||
# ── honcho_context (dialectic — LLM-powered) ──
|
||||
|
||||
_QUERY_SCHEMA = {
|
||||
"name": "honcho_context",
|
||||
"description": (
|
||||
"Ask Honcho a natural language question and get a synthesized answer. "
|
||||
"Uses Honcho's LLM (dialectic reasoning) — higher cost than honcho_profile or honcho_search. "
|
||||
"Can query about any peer: the user (default), the AI assistant, or any named peer. "
|
||||
"Examples: 'What are the user's main goals?', 'What has hermes been working on?', "
|
||||
"'What is the user's technical expertise level?'"
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "A natural language question.",
|
||||
},
|
||||
"peer": {
|
||||
"type": "string",
|
||||
"description": "Which peer to query about: 'user' (default) or 'ai'. Omit for user.",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _handle_honcho_context(args: dict, **kw) -> str:
|
||||
query = args.get("query", "")
|
||||
if not query:
|
||||
return json.dumps({"error": "Missing required parameter: query"})
|
||||
session_manager, session_key = _resolve_session_context(**kw)
|
||||
if not session_manager or not session_key:
|
||||
return json.dumps({"error": "Honcho is not active for this session."})
|
||||
peer_target = args.get("peer", "user")
|
||||
try:
|
||||
result = session_manager.dialectic_query(session_key, query, peer=peer_target)
|
||||
return json.dumps({"result": result or "No result from Honcho."})
|
||||
except Exception as e:
|
||||
logger.error("Error querying Honcho context: %s", e)
|
||||
return json.dumps({"error": f"Failed to query context: {e}"})
|
||||
|
||||
|
||||
# ── honcho_conclude ──
|
||||
|
||||
_CONCLUDE_SCHEMA = {
|
||||
"name": "honcho_conclude",
|
||||
"description": (
|
||||
"Write a conclusion about the user back to Honcho's memory. "
|
||||
"Conclusions are persistent facts that build the user's profile — "
|
||||
"preferences, corrections, clarifications, project context, or anything "
|
||||
"the user tells you that should be remembered across sessions. "
|
||||
"Use this when the user explicitly states a preference, corrects you, "
|
||||
"or shares something they want remembered. "
|
||||
"Examples: 'User prefers dark mode', 'User's project uses Python 3.11', "
|
||||
"'User corrected: their name is spelled Eri not Eric'."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"conclusion": {
|
||||
"type": "string",
|
||||
"description": "A factual statement about the user to persist in memory.",
|
||||
}
|
||||
},
|
||||
"required": ["conclusion"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _handle_honcho_conclude(args: dict, **kw) -> str:
|
||||
conclusion = args.get("conclusion", "")
|
||||
if not conclusion:
|
||||
return json.dumps({"error": "Missing required parameter: conclusion"})
|
||||
session_manager, session_key = _resolve_session_context(**kw)
|
||||
if not session_manager or not session_key:
|
||||
return json.dumps({"error": "Honcho is not active for this session."})
|
||||
try:
|
||||
ok = session_manager.create_conclusion(session_key, conclusion)
|
||||
if ok:
|
||||
return json.dumps({"result": f"Conclusion saved: {conclusion}"})
|
||||
return json.dumps({"error": "Failed to save conclusion."})
|
||||
except Exception as e:
|
||||
logger.error("Error creating Honcho conclusion: %s", e)
|
||||
return json.dumps({"error": f"Failed to save conclusion: {e}"})
|
||||
|
||||
|
||||
# ── Registration ──
|
||||
|
||||
from tools.registry import registry
|
||||
|
||||
registry.register(
|
||||
name="honcho_profile",
|
||||
toolset="honcho",
|
||||
schema=_PROFILE_SCHEMA,
|
||||
handler=_handle_honcho_profile,
|
||||
check_fn=_check_honcho_available,
|
||||
emoji="🔮",
|
||||
)
|
||||
|
||||
registry.register(
|
||||
name="honcho_search",
|
||||
toolset="honcho",
|
||||
schema=_SEARCH_SCHEMA,
|
||||
handler=_handle_honcho_search,
|
||||
check_fn=_check_honcho_available,
|
||||
emoji="🔮",
|
||||
)
|
||||
|
||||
registry.register(
|
||||
name="honcho_context",
|
||||
toolset="honcho",
|
||||
schema=_QUERY_SCHEMA,
|
||||
handler=_handle_honcho_context,
|
||||
check_fn=_check_honcho_available,
|
||||
emoji="🔮",
|
||||
)
|
||||
|
||||
registry.register(
|
||||
name="honcho_conclude",
|
||||
toolset="honcho",
|
||||
schema=_CONCLUDE_SCHEMA,
|
||||
handler=_handle_honcho_conclude,
|
||||
check_fn=_check_honcho_available,
|
||||
emoji="🔮",
|
||||
)
|
||||
562
hermes_code/tools/image_generation_tool.py
Normal file
562
hermes_code/tools/image_generation_tool.py
Normal file
|
|
@ -0,0 +1,562 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Image Generation Tools Module
|
||||
|
||||
This module provides image generation tools using FAL.ai's FLUX 2 Pro model with
|
||||
automatic upscaling via FAL.ai's Clarity Upscaler for enhanced image quality.
|
||||
|
||||
Available tools:
|
||||
- image_generate_tool: Generate images from text prompts with automatic upscaling
|
||||
|
||||
Features:
|
||||
- High-quality image generation using FLUX 2 Pro model
|
||||
- Automatic 2x upscaling using Clarity Upscaler for enhanced quality
|
||||
- Comprehensive parameter control (size, steps, guidance, etc.)
|
||||
- Proper error handling and validation with fallback to original images
|
||||
- Debug logging support
|
||||
- Sync mode for immediate results
|
||||
|
||||
Usage:
|
||||
from image_generation_tool import image_generate_tool
|
||||
import asyncio
|
||||
|
||||
# Generate and automatically upscale an image
|
||||
result = await image_generate_tool(
|
||||
prompt="A serene mountain landscape with cherry blossoms",
|
||||
image_size="landscape_4_3",
|
||||
num_images=1
|
||||
)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import datetime
|
||||
from typing import Dict, Any, Optional, Union
|
||||
import fal_client
|
||||
from tools.debug_helpers import DebugSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration for image generation
|
||||
DEFAULT_MODEL = "fal-ai/flux-2-pro"
|
||||
DEFAULT_ASPECT_RATIO = "landscape"
|
||||
DEFAULT_NUM_INFERENCE_STEPS = 50
|
||||
DEFAULT_GUIDANCE_SCALE = 4.5
|
||||
DEFAULT_NUM_IMAGES = 1
|
||||
DEFAULT_OUTPUT_FORMAT = "png"
|
||||
|
||||
# Safety settings
|
||||
ENABLE_SAFETY_CHECKER = False
|
||||
SAFETY_TOLERANCE = "5" # Maximum tolerance (1-5, where 5 is most permissive)
|
||||
|
||||
# Aspect ratio mapping - simplified choices for model to select
|
||||
ASPECT_RATIO_MAP = {
|
||||
"landscape": "landscape_16_9",
|
||||
"square": "square_hd",
|
||||
"portrait": "portrait_16_9"
|
||||
}
|
||||
VALID_ASPECT_RATIOS = list(ASPECT_RATIO_MAP.keys())
|
||||
|
||||
# Configuration for automatic upscaling
|
||||
UPSCALER_MODEL = "fal-ai/clarity-upscaler"
|
||||
UPSCALER_FACTOR = 2
|
||||
UPSCALER_SAFETY_CHECKER = False
|
||||
UPSCALER_DEFAULT_PROMPT = "masterpiece, best quality, highres"
|
||||
UPSCALER_NEGATIVE_PROMPT = "(worst quality, low quality, normal quality:2)"
|
||||
UPSCALER_CREATIVITY = 0.35
|
||||
UPSCALER_RESEMBLANCE = 0.6
|
||||
UPSCALER_GUIDANCE_SCALE = 4
|
||||
UPSCALER_NUM_INFERENCE_STEPS = 18
|
||||
|
||||
# Valid parameter values for validation based on FLUX 2 Pro documentation
|
||||
VALID_IMAGE_SIZES = [
|
||||
"square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"
|
||||
]
|
||||
VALID_OUTPUT_FORMATS = ["jpeg", "png"]
|
||||
VALID_ACCELERATION_MODES = ["none", "regular", "high"]
|
||||
|
||||
_debug = DebugSession("image_tools", env_var="IMAGE_TOOLS_DEBUG")
|
||||
|
||||
|
||||
def _validate_parameters(
|
||||
image_size: Union[str, Dict[str, int]],
|
||||
num_inference_steps: int,
|
||||
guidance_scale: float,
|
||||
num_images: int,
|
||||
output_format: str,
|
||||
acceleration: str = "none"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate and normalize image generation parameters for FLUX 2 Pro model.
|
||||
|
||||
Args:
|
||||
image_size: Either a preset string or custom size dict
|
||||
num_inference_steps: Number of inference steps
|
||||
guidance_scale: Guidance scale value
|
||||
num_images: Number of images to generate
|
||||
output_format: Output format for images
|
||||
acceleration: Acceleration mode for generation speed
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Validated and normalized parameters
|
||||
|
||||
Raises:
|
||||
ValueError: If any parameter is invalid
|
||||
"""
|
||||
validated = {}
|
||||
|
||||
# Validate image_size
|
||||
if isinstance(image_size, str):
|
||||
if image_size not in VALID_IMAGE_SIZES:
|
||||
raise ValueError(f"Invalid image_size '{image_size}'. Must be one of: {VALID_IMAGE_SIZES}")
|
||||
validated["image_size"] = image_size
|
||||
elif isinstance(image_size, dict):
|
||||
if "width" not in image_size or "height" not in image_size:
|
||||
raise ValueError("Custom image_size must contain 'width' and 'height' keys")
|
||||
if not isinstance(image_size["width"], int) or not isinstance(image_size["height"], int):
|
||||
raise ValueError("Custom image_size width and height must be integers")
|
||||
if image_size["width"] < 64 or image_size["height"] < 64:
|
||||
raise ValueError("Custom image_size dimensions must be at least 64x64")
|
||||
if image_size["width"] > 2048 or image_size["height"] > 2048:
|
||||
raise ValueError("Custom image_size dimensions must not exceed 2048x2048")
|
||||
validated["image_size"] = image_size
|
||||
else:
|
||||
raise ValueError("image_size must be either a preset string or a dict with width/height")
|
||||
|
||||
# Validate num_inference_steps
|
||||
if not isinstance(num_inference_steps, int) or num_inference_steps < 1 or num_inference_steps > 100:
|
||||
raise ValueError("num_inference_steps must be an integer between 1 and 100")
|
||||
validated["num_inference_steps"] = num_inference_steps
|
||||
|
||||
# Validate guidance_scale (FLUX 2 Pro default is 4.5)
|
||||
if not isinstance(guidance_scale, (int, float)) or guidance_scale < 0.1 or guidance_scale > 20.0:
|
||||
raise ValueError("guidance_scale must be a number between 0.1 and 20.0")
|
||||
validated["guidance_scale"] = float(guidance_scale)
|
||||
|
||||
# Validate num_images
|
||||
if not isinstance(num_images, int) or num_images < 1 or num_images > 4:
|
||||
raise ValueError("num_images must be an integer between 1 and 4")
|
||||
validated["num_images"] = num_images
|
||||
|
||||
# Validate output_format
|
||||
if output_format not in VALID_OUTPUT_FORMATS:
|
||||
raise ValueError(f"Invalid output_format '{output_format}'. Must be one of: {VALID_OUTPUT_FORMATS}")
|
||||
validated["output_format"] = output_format
|
||||
|
||||
# Validate acceleration
|
||||
if acceleration not in VALID_ACCELERATION_MODES:
|
||||
raise ValueError(f"Invalid acceleration '{acceleration}'. Must be one of: {VALID_ACCELERATION_MODES}")
|
||||
validated["acceleration"] = acceleration
|
||||
|
||||
return validated
|
||||
|
||||
|
||||
def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Upscale an image using FAL.ai's Clarity Upscaler.
|
||||
|
||||
Uses the synchronous fal_client API to avoid event loop lifecycle issues
|
||||
when called from threaded contexts (e.g. gateway thread pool).
|
||||
|
||||
Args:
|
||||
image_url (str): URL of the image to upscale
|
||||
original_prompt (str): Original prompt used to generate the image
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Upscaled image data or None if upscaling fails
|
||||
"""
|
||||
try:
|
||||
logger.info("Upscaling image with Clarity Upscaler...")
|
||||
|
||||
# Prepare arguments for upscaler
|
||||
upscaler_arguments = {
|
||||
"image_url": image_url,
|
||||
"prompt": f"{UPSCALER_DEFAULT_PROMPT}, {original_prompt}",
|
||||
"upscale_factor": UPSCALER_FACTOR,
|
||||
"negative_prompt": UPSCALER_NEGATIVE_PROMPT,
|
||||
"creativity": UPSCALER_CREATIVITY,
|
||||
"resemblance": UPSCALER_RESEMBLANCE,
|
||||
"guidance_scale": UPSCALER_GUIDANCE_SCALE,
|
||||
"num_inference_steps": UPSCALER_NUM_INFERENCE_STEPS,
|
||||
"enable_safety_checker": UPSCALER_SAFETY_CHECKER
|
||||
}
|
||||
|
||||
# Use sync API — fal_client.submit() uses httpx.Client (no event loop).
|
||||
# The async API (submit_async) caches a global httpx.AsyncClient via
|
||||
# @cached_property, which breaks when asyncio.run() destroys the loop
|
||||
# between calls (gateway thread-pool pattern).
|
||||
handler = fal_client.submit(
|
||||
UPSCALER_MODEL,
|
||||
arguments=upscaler_arguments
|
||||
)
|
||||
|
||||
# Get the upscaled result (sync — blocks until done)
|
||||
result = handler.get()
|
||||
|
||||
if result and "image" in result:
|
||||
upscaled_image = result["image"]
|
||||
logger.info("Image upscaled successfully to %sx%s", upscaled_image.get('width', 'unknown'), upscaled_image.get('height', 'unknown'))
|
||||
return {
|
||||
"url": upscaled_image["url"],
|
||||
"width": upscaled_image.get("width", 0),
|
||||
"height": upscaled_image.get("height", 0),
|
||||
"upscaled": True,
|
||||
"upscale_factor": UPSCALER_FACTOR
|
||||
}
|
||||
else:
|
||||
logger.error("Upscaler returned invalid response")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error upscaling image: %s", e, exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def image_generate_tool(
|
||||
prompt: str,
|
||||
aspect_ratio: str = DEFAULT_ASPECT_RATIO,
|
||||
num_inference_steps: int = DEFAULT_NUM_INFERENCE_STEPS,
|
||||
guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
|
||||
num_images: int = DEFAULT_NUM_IMAGES,
|
||||
output_format: str = DEFAULT_OUTPUT_FORMAT,
|
||||
seed: Optional[int] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate images from text prompts using FAL.ai's FLUX 2 Pro model with automatic upscaling.
|
||||
|
||||
Uses the synchronous fal_client API to avoid event loop lifecycle issues.
|
||||
The async API's global httpx.AsyncClient (cached via @cached_property) breaks
|
||||
when asyncio.run() destroys and recreates event loops between calls, which
|
||||
happens in the gateway's thread-pool pattern.
|
||||
|
||||
Args:
|
||||
prompt (str): The text prompt describing the desired image
|
||||
aspect_ratio (str): Image aspect ratio - "landscape", "square", or "portrait" (default: "landscape")
|
||||
num_inference_steps (int): Number of denoising steps (1-50, default: 50)
|
||||
guidance_scale (float): How closely to follow prompt (0.1-20.0, default: 4.5)
|
||||
num_images (int): Number of images to generate (1-4, default: 1)
|
||||
output_format (str): Image format "jpeg" or "png" (default: "png")
|
||||
seed (Optional[int]): Random seed for reproducible results (optional)
|
||||
|
||||
Returns:
|
||||
str: JSON string containing minimal generation results:
|
||||
{
|
||||
"success": bool,
|
||||
"image": str or None # URL of the upscaled image, or None if failed
|
||||
}
|
||||
"""
|
||||
# Validate and map aspect_ratio to actual image_size
|
||||
aspect_ratio_lower = aspect_ratio.lower().strip() if aspect_ratio else DEFAULT_ASPECT_RATIO
|
||||
if aspect_ratio_lower not in ASPECT_RATIO_MAP:
|
||||
logger.warning("Invalid aspect_ratio '%s', defaulting to '%s'", aspect_ratio, DEFAULT_ASPECT_RATIO)
|
||||
aspect_ratio_lower = DEFAULT_ASPECT_RATIO
|
||||
image_size = ASPECT_RATIO_MAP[aspect_ratio_lower]
|
||||
|
||||
debug_call_data = {
|
||||
"parameters": {
|
||||
"prompt": prompt,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"image_size": image_size,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"guidance_scale": guidance_scale,
|
||||
"num_images": num_images,
|
||||
"output_format": output_format,
|
||||
"seed": seed
|
||||
},
|
||||
"error": None,
|
||||
"success": False,
|
||||
"images_generated": 0,
|
||||
"generation_time": 0
|
||||
}
|
||||
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
try:
|
||||
logger.info("Generating %s image(s) with FLUX 2 Pro: %s", num_images, prompt[:80])
|
||||
|
||||
# Validate prompt
|
||||
if not prompt or not isinstance(prompt, str) or len(prompt.strip()) == 0:
|
||||
raise ValueError("Prompt is required and must be a non-empty string")
|
||||
|
||||
# Check API key availability
|
||||
if not os.getenv("FAL_KEY"):
|
||||
raise ValueError("FAL_KEY environment variable not set")
|
||||
|
||||
# Validate other parameters
|
||||
validated_params = _validate_parameters(
|
||||
image_size, num_inference_steps, guidance_scale, num_images, output_format, "none"
|
||||
)
|
||||
|
||||
# Prepare arguments for FAL.ai FLUX 2 Pro API
|
||||
arguments = {
|
||||
"prompt": prompt.strip(),
|
||||
"image_size": validated_params["image_size"],
|
||||
"num_inference_steps": validated_params["num_inference_steps"],
|
||||
"guidance_scale": validated_params["guidance_scale"],
|
||||
"num_images": validated_params["num_images"],
|
||||
"output_format": validated_params["output_format"],
|
||||
"enable_safety_checker": ENABLE_SAFETY_CHECKER,
|
||||
"safety_tolerance": SAFETY_TOLERANCE,
|
||||
"sync_mode": True # Use sync mode for immediate results
|
||||
}
|
||||
|
||||
# Add seed if provided
|
||||
if seed is not None and isinstance(seed, int):
|
||||
arguments["seed"] = seed
|
||||
|
||||
logger.info("Submitting generation request to FAL.ai FLUX 2 Pro...")
|
||||
logger.info(" Model: %s", DEFAULT_MODEL)
|
||||
logger.info(" Aspect Ratio: %s -> %s", aspect_ratio_lower, image_size)
|
||||
logger.info(" Steps: %s", validated_params['num_inference_steps'])
|
||||
logger.info(" Guidance: %s", validated_params['guidance_scale'])
|
||||
|
||||
# Submit request to FAL.ai using sync API (avoids cached event loop issues)
|
||||
handler = fal_client.submit(
|
||||
DEFAULT_MODEL,
|
||||
arguments=arguments
|
||||
)
|
||||
|
||||
# Get the result (sync — blocks until done)
|
||||
result = handler.get()
|
||||
|
||||
generation_time = (datetime.datetime.now() - start_time).total_seconds()
|
||||
|
||||
# Process the response
|
||||
if not result or "images" not in result:
|
||||
raise ValueError("Invalid response from FAL.ai API - no images returned")
|
||||
|
||||
images = result.get("images", [])
|
||||
if not images:
|
||||
raise ValueError("No images were generated")
|
||||
|
||||
# Format image data and upscale images
|
||||
formatted_images = []
|
||||
for img in images:
|
||||
if isinstance(img, dict) and "url" in img:
|
||||
original_image = {
|
||||
"url": img["url"],
|
||||
"width": img.get("width", 0),
|
||||
"height": img.get("height", 0)
|
||||
}
|
||||
|
||||
# Attempt to upscale the image
|
||||
upscaled_image = _upscale_image(img["url"], prompt.strip())
|
||||
|
||||
if upscaled_image:
|
||||
# Use upscaled image if successful
|
||||
formatted_images.append(upscaled_image)
|
||||
else:
|
||||
# Fall back to original image if upscaling fails
|
||||
logger.warning("Using original image as fallback")
|
||||
original_image["upscaled"] = False
|
||||
formatted_images.append(original_image)
|
||||
|
||||
if not formatted_images:
|
||||
raise ValueError("No valid image URLs returned from API")
|
||||
|
||||
upscaled_count = sum(1 for img in formatted_images if img.get("upscaled", False))
|
||||
logger.info("Generated %s image(s) in %.1fs (%s upscaled)", len(formatted_images), generation_time, upscaled_count)
|
||||
|
||||
# Prepare successful response - minimal format
|
||||
response_data = {
|
||||
"success": True,
|
||||
"image": formatted_images[0]["url"] if formatted_images else None
|
||||
}
|
||||
|
||||
debug_call_data["success"] = True
|
||||
debug_call_data["images_generated"] = len(formatted_images)
|
||||
debug_call_data["generation_time"] = generation_time
|
||||
|
||||
# Log debug information
|
||||
_debug.log_call("image_generate_tool", debug_call_data)
|
||||
_debug.save()
|
||||
|
||||
return json.dumps(response_data, indent=2, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
generation_time = (datetime.datetime.now() - start_time).total_seconds()
|
||||
error_msg = f"Error generating image: {str(e)}"
|
||||
logger.error("%s", error_msg, exc_info=True)
|
||||
|
||||
# Prepare error response - minimal format
|
||||
response_data = {
|
||||
"success": False,
|
||||
"image": None
|
||||
}
|
||||
|
||||
debug_call_data["error"] = error_msg
|
||||
debug_call_data["generation_time"] = generation_time
|
||||
_debug.log_call("image_generate_tool", debug_call_data)
|
||||
_debug.save()
|
||||
|
||||
return json.dumps(response_data, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def check_fal_api_key() -> bool:
|
||||
"""
|
||||
Check if the FAL.ai API key is available in environment variables.
|
||||
|
||||
Returns:
|
||||
bool: True if API key is set, False otherwise
|
||||
"""
|
||||
return bool(os.getenv("FAL_KEY"))
|
||||
|
||||
|
||||
def check_image_generation_requirements() -> bool:
|
||||
"""
|
||||
Check if all requirements for image generation tools are met.
|
||||
|
||||
Returns:
|
||||
bool: True if requirements are met, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Check API key
|
||||
if not check_fal_api_key():
|
||||
return False
|
||||
|
||||
# Check if fal_client is available
|
||||
import fal_client
|
||||
return True
|
||||
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def get_debug_session_info() -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about the current debug session.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary containing debug session information
|
||||
"""
|
||||
return _debug.get_session_info()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
Simple test/demo when run directly
|
||||
"""
|
||||
print("🎨 Image Generation Tools Module - FLUX 2 Pro + Auto Upscaling")
|
||||
print("=" * 60)
|
||||
|
||||
# Check if API key is available
|
||||
api_available = check_fal_api_key()
|
||||
|
||||
if not api_available:
|
||||
print("❌ FAL_KEY environment variable not set")
|
||||
print("Please set your API key: export FAL_KEY='your-key-here'")
|
||||
print("Get API key at: https://fal.ai/")
|
||||
exit(1)
|
||||
else:
|
||||
print("✅ FAL.ai API key found")
|
||||
|
||||
# Check if fal_client is available
|
||||
try:
|
||||
import fal_client
|
||||
print("✅ fal_client library available")
|
||||
except ImportError:
|
||||
print("❌ fal_client library not found")
|
||||
print("Please install: pip install fal-client")
|
||||
exit(1)
|
||||
|
||||
print("🛠️ Image generation tools ready for use!")
|
||||
print(f"🤖 Using model: {DEFAULT_MODEL}")
|
||||
print(f"🔍 Auto-upscaling with: {UPSCALER_MODEL} ({UPSCALER_FACTOR}x)")
|
||||
|
||||
# Show debug mode status
|
||||
if _debug.active:
|
||||
print(f"🐛 Debug mode ENABLED - Session ID: {_debug.session_id}")
|
||||
print(f" Debug logs will be saved to: ./logs/image_tools_debug_{_debug.session_id}.json")
|
||||
else:
|
||||
print("🐛 Debug mode disabled (set IMAGE_TOOLS_DEBUG=true to enable)")
|
||||
|
||||
print("\nBasic usage:")
|
||||
print(" from image_generation_tool import image_generate_tool")
|
||||
print(" import asyncio")
|
||||
print("")
|
||||
print(" async def main():")
|
||||
print(" # Generate image with automatic 2x upscaling")
|
||||
print(" result = await image_generate_tool(")
|
||||
print(" prompt='A serene mountain landscape with cherry blossoms',")
|
||||
print(" image_size='landscape_4_3',")
|
||||
print(" num_images=1")
|
||||
print(" )")
|
||||
print(" print(result)")
|
||||
print(" asyncio.run(main())")
|
||||
|
||||
print("\nSupported image sizes:")
|
||||
for size in VALID_IMAGE_SIZES:
|
||||
print(f" - {size}")
|
||||
print(" - Custom: {'width': 512, 'height': 768} (if needed)")
|
||||
|
||||
print("\nAcceleration modes:")
|
||||
for mode in VALID_ACCELERATION_MODES:
|
||||
print(f" - {mode}")
|
||||
|
||||
print("\nExample prompts:")
|
||||
print(" - 'A candid street photo of a woman with a pink bob and bold eyeliner'")
|
||||
print(" - 'Modern architecture building with glass facade, sunset lighting'")
|
||||
print(" - 'Abstract art with vibrant colors and geometric patterns'")
|
||||
print(" - 'Portrait of a wise old owl perched on ancient tree branch'")
|
||||
print(" - 'Futuristic cityscape with flying cars and neon lights'")
|
||||
|
||||
print("\nDebug mode:")
|
||||
print(" # Enable debug logging")
|
||||
print(" export IMAGE_TOOLS_DEBUG=true")
|
||||
print(" # Debug logs capture all image generation calls and results")
|
||||
print(" # Logs saved to: ./logs/image_tools_debug_UUID.json")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
from tools.registry import registry
|
||||
|
||||
IMAGE_GENERATE_SCHEMA = {
|
||||
"name": "image_generate",
|
||||
"description": "Generate high-quality images from text prompts using FLUX 2 Pro model with automatic 2x upscaling. Creates detailed, artistic images that are automatically upscaled for hi-rez results. Returns a single upscaled image URL. Display it using markdown: ",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "The text prompt describing the desired image. Be detailed and descriptive."
|
||||
},
|
||||
"aspect_ratio": {
|
||||
"type": "string",
|
||||
"enum": ["landscape", "square", "portrait"],
|
||||
"description": "The aspect ratio of the generated image. 'landscape' is 16:9 wide, 'portrait' is 16:9 tall, 'square' is 1:1.",
|
||||
"default": "landscape"
|
||||
}
|
||||
},
|
||||
"required": ["prompt"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _handle_image_generate(args, **kw):
|
||||
prompt = args.get("prompt", "")
|
||||
if not prompt:
|
||||
return json.dumps({"error": "prompt is required for image generation"})
|
||||
return image_generate_tool(
|
||||
prompt=prompt,
|
||||
aspect_ratio=args.get("aspect_ratio", "landscape"),
|
||||
num_inference_steps=50,
|
||||
guidance_scale=4.5,
|
||||
num_images=1,
|
||||
output_format="png",
|
||||
seed=None,
|
||||
)
|
||||
|
||||
|
||||
registry.register(
|
||||
name="image_generate",
|
||||
toolset="image_gen",
|
||||
schema=IMAGE_GENERATE_SCHEMA,
|
||||
handler=_handle_image_generate,
|
||||
check_fn=check_image_generation_requirements,
|
||||
requires_env=["FAL_KEY"],
|
||||
is_async=False, # Switched to sync fal_client API to fix "Event loop is closed" in gateway
|
||||
emoji="🎨",
|
||||
)
|
||||
28
hermes_code/tools/interrupt.py
Normal file
28
hermes_code/tools/interrupt.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
"""Shared interrupt signaling for all tools.
|
||||
|
||||
Provides a global threading.Event that any tool can check to determine
|
||||
if the user has requested an interrupt. The agent's interrupt() method
|
||||
sets this event, and tools poll it during long-running operations.
|
||||
|
||||
Usage in tools:
|
||||
from tools.interrupt import is_interrupted
|
||||
if is_interrupted():
|
||||
return {"output": "[interrupted]", "returncode": 130}
|
||||
"""
|
||||
|
||||
import threading
|
||||
|
||||
_interrupt_event = threading.Event()
|
||||
|
||||
|
||||
def set_interrupt(active: bool) -> None:
|
||||
"""Called by the agent to signal or clear the interrupt."""
|
||||
if active:
|
||||
_interrupt_event.set()
|
||||
else:
|
||||
_interrupt_event.clear()
|
||||
|
||||
|
||||
def is_interrupted() -> bool:
|
||||
"""Check if an interrupt has been requested. Safe to call from any thread."""
|
||||
return _interrupt_event.is_set()
|
||||
249
hermes_code/tools/mcp_oauth.py
Normal file
249
hermes_code/tools/mcp_oauth.py
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
"""Thin OAuth adapter for MCP HTTP servers.
|
||||
|
||||
Wraps the MCP SDK's built-in ``OAuthClientProvider`` (which implements
|
||||
``httpx.Auth``) with Hermes-specific token storage and browser-based
|
||||
authorization. The SDK handles all of the heavy lifting: PKCE generation,
|
||||
metadata discovery, dynamic client registration, token exchange, and refresh.
|
||||
|
||||
Usage in mcp_tool.py::
|
||||
|
||||
from tools.mcp_oauth import build_oauth_auth
|
||||
auth = build_oauth_auth(server_name, server_url)
|
||||
# pass ``auth`` as the httpx auth parameter
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
import webbrowser
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TOKEN_DIR_NAME = "mcp-tokens"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token storage — persists tokens + client info to ~/.hermes/mcp-tokens/
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _sanitize_server_name(name: str) -> str:
|
||||
"""Sanitize server name for safe use as a filename."""
|
||||
import re
|
||||
clean = re.sub(r"[^\w\-]", "-", name.strip().lower())
|
||||
clean = re.sub(r"-+", "-", clean).strip("-")
|
||||
return clean[:60] or "unnamed"
|
||||
|
||||
|
||||
class HermesTokenStorage:
|
||||
"""File-backed token storage implementing the MCP SDK's TokenStorage protocol."""
|
||||
|
||||
def __init__(self, server_name: str):
|
||||
self._server_name = _sanitize_server_name(server_name)
|
||||
|
||||
def _base_dir(self) -> Path:
|
||||
home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
d = home / _TOKEN_DIR_NAME
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
return d
|
||||
|
||||
def _tokens_path(self) -> Path:
|
||||
return self._base_dir() / f"{self._server_name}.json"
|
||||
|
||||
def _client_path(self) -> Path:
|
||||
return self._base_dir() / f"{self._server_name}.client.json"
|
||||
|
||||
# -- TokenStorage protocol (async) --
|
||||
|
||||
async def get_tokens(self):
|
||||
data = self._read_json(self._tokens_path())
|
||||
if not data:
|
||||
return None
|
||||
try:
|
||||
from mcp.shared.auth import OAuthToken
|
||||
return OAuthToken(**data)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def set_tokens(self, tokens) -> None:
|
||||
self._write_json(self._tokens_path(), tokens.model_dump(exclude_none=True))
|
||||
|
||||
async def get_client_info(self):
|
||||
data = self._read_json(self._client_path())
|
||||
if not data:
|
||||
return None
|
||||
try:
|
||||
from mcp.shared.auth import OAuthClientInformationFull
|
||||
return OAuthClientInformationFull(**data)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def set_client_info(self, client_info) -> None:
|
||||
self._write_json(self._client_path(), client_info.model_dump(exclude_none=True))
|
||||
|
||||
# -- helpers --
|
||||
|
||||
@staticmethod
|
||||
def _read_json(path: Path) -> dict | None:
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _write_json(path: Path, data: dict) -> None:
|
||||
path.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
||||
try:
|
||||
path.chmod(0o600)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def remove(self) -> None:
|
||||
"""Delete stored tokens and client info for this server."""
|
||||
for p in (self._tokens_path(), self._client_path()):
|
||||
try:
|
||||
p.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Browser-based callback handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def _make_callback_handler():
|
||||
"""Create a callback handler class with instance-scoped result storage."""
|
||||
result = {"auth_code": None, "state": None}
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
qs = parse_qs(urlparse(self.path).query)
|
||||
result["auth_code"] = (qs.get("code") or [None])[0]
|
||||
result["state"] = (qs.get("state") or [None])[0]
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(b"<html><body><h3>Authorization complete. You can close this tab.</h3></body></html>")
|
||||
|
||||
def log_message(self, *_args: Any) -> None:
|
||||
pass
|
||||
|
||||
return Handler, result
|
||||
|
||||
|
||||
# Port chosen at build time and shared with the callback handler via closure.
|
||||
_oauth_port: int | None = None
|
||||
|
||||
|
||||
async def _redirect_to_browser(auth_url: str) -> None:
|
||||
"""Open the authorization URL in the user's browser."""
|
||||
try:
|
||||
if _can_open_browser():
|
||||
webbrowser.open(auth_url)
|
||||
print(f" Opened browser for authorization...")
|
||||
else:
|
||||
print(f"\n Open this URL to authorize:\n {auth_url}\n")
|
||||
except Exception:
|
||||
print(f"\n Open this URL to authorize:\n {auth_url}\n")
|
||||
|
||||
|
||||
async def _wait_for_callback() -> tuple[str, str | None]:
|
||||
"""Start a local HTTP server on the pre-registered port and wait for the OAuth redirect."""
|
||||
global _oauth_port
|
||||
port = _oauth_port or _find_free_port()
|
||||
HandlerClass, result = _make_callback_handler()
|
||||
server = HTTPServer(("127.0.0.1", port), HandlerClass)
|
||||
|
||||
def _serve():
|
||||
server.timeout = 120
|
||||
server.handle_request()
|
||||
|
||||
thread = threading.Thread(target=_serve, daemon=True)
|
||||
thread.start()
|
||||
|
||||
for _ in range(1200): # 120 seconds
|
||||
await asyncio.sleep(0.1)
|
||||
if result["auth_code"] is not None:
|
||||
break
|
||||
|
||||
server.server_close()
|
||||
code = result["auth_code"] or ""
|
||||
state = result["state"]
|
||||
if not code:
|
||||
print(" Browser callback timed out. Paste the authorization code manually:")
|
||||
code = input(" Code: ").strip()
|
||||
return code, state
|
||||
|
||||
|
||||
def _can_open_browser() -> bool:
|
||||
if os.environ.get("SSH_CLIENT") or os.environ.get("SSH_TTY"):
|
||||
return False
|
||||
if not os.environ.get("DISPLAY") and os.name != "nt" and "darwin" not in os.uname().sysname.lower():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_oauth_auth(server_name: str, server_url: str):
|
||||
"""Build an ``httpx.Auth`` handler for the given MCP server using OAuth 2.1 PKCE.
|
||||
|
||||
Uses the MCP SDK's ``OAuthClientProvider`` which handles discovery,
|
||||
registration, PKCE, token exchange, and refresh automatically.
|
||||
|
||||
Returns an ``OAuthClientProvider`` instance (implements ``httpx.Auth``),
|
||||
or ``None`` if the MCP SDK auth module is not available.
|
||||
"""
|
||||
try:
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
from mcp.shared.auth import OAuthClientMetadata
|
||||
except ImportError:
|
||||
logger.warning("MCP SDK auth module not available — OAuth disabled")
|
||||
return None
|
||||
|
||||
global _oauth_port
|
||||
_oauth_port = _find_free_port()
|
||||
redirect_uri = f"http://127.0.0.1:{_oauth_port}/callback"
|
||||
|
||||
client_metadata = OAuthClientMetadata(
|
||||
client_name="Hermes Agent",
|
||||
redirect_uris=[redirect_uri],
|
||||
grant_types=["authorization_code", "refresh_token"],
|
||||
response_types=["code"],
|
||||
scope="openid profile email offline_access",
|
||||
token_endpoint_auth_method="none",
|
||||
)
|
||||
|
||||
storage = HermesTokenStorage(server_name)
|
||||
|
||||
return OAuthClientProvider(
|
||||
server_url=server_url,
|
||||
client_metadata=client_metadata,
|
||||
storage=storage,
|
||||
redirect_handler=_redirect_to_browser,
|
||||
callback_handler=_wait_for_callback,
|
||||
timeout=120.0,
|
||||
)
|
||||
|
||||
|
||||
def remove_oauth_tokens(server_name: str) -> None:
|
||||
"""Delete stored OAuth tokens and client info for a server."""
|
||||
HermesTokenStorage(server_name).remove()
|
||||
1838
hermes_code/tools/mcp_tool.py
Normal file
1838
hermes_code/tools/mcp_tool.py
Normal file
File diff suppressed because it is too large
Load diff
547
hermes_code/tools/memory_tool.py
Normal file
547
hermes_code/tools/memory_tool.py
Normal file
|
|
@ -0,0 +1,547 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Memory Tool Module - Persistent Curated Memory
|
||||
|
||||
Provides bounded, file-backed memory that persists across sessions. Two stores:
|
||||
- MEMORY.md: agent's personal notes and observations (environment facts, project
|
||||
conventions, tool quirks, things learned)
|
||||
- USER.md: what the agent knows about the user (preferences, communication style,
|
||||
expectations, workflow habits)
|
||||
|
||||
Both are injected into the system prompt as a frozen snapshot at session start.
|
||||
Mid-session writes update files on disk immediately (durable) but do NOT change
|
||||
the system prompt -- this preserves the prefix cache for the entire session.
|
||||
The snapshot refreshes on the next session start.
|
||||
|
||||
Entry delimiter: § (section sign). Entries can be multiline.
|
||||
Character limits (not tokens) because char counts are model-independent.
|
||||
|
||||
Design:
|
||||
- Single `memory` tool with action parameter: add, replace, remove, read
|
||||
- replace/remove use short unique substring matching (not full text or IDs)
|
||||
- Behavioral guidance lives in the tool schema description
|
||||
- Frozen snapshot pattern: system prompt is stable, tool responses show live state
|
||||
"""
|
||||
|
||||
import fcntl
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Where memory files live
|
||||
MEMORY_DIR = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "memories"
|
||||
|
||||
ENTRY_DELIMITER = "\n§\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Memory content scanning — lightweight check for injection/exfiltration
|
||||
# in content that gets injected into the system prompt.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_MEMORY_THREAT_PATTERNS = [
|
||||
# Prompt injection
|
||||
(r'ignore\s+(previous|all|above|prior)\s+instructions', "prompt_injection"),
|
||||
(r'you\s+are\s+now\s+', "role_hijack"),
|
||||
(r'do\s+not\s+tell\s+the\s+user', "deception_hide"),
|
||||
(r'system\s+prompt\s+override', "sys_prompt_override"),
|
||||
(r'disregard\s+(your|all|any)\s+(instructions|rules|guidelines)', "disregard_rules"),
|
||||
(r'act\s+as\s+(if|though)\s+you\s+(have\s+no|don\'t\s+have)\s+(restrictions|limits|rules)', "bypass_restrictions"),
|
||||
# Exfiltration via curl/wget with secrets
|
||||
(r'curl\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_curl"),
|
||||
(r'wget\s+[^\n]*\$\{?\w*(KEY|TOKEN|SECRET|PASSWORD|CREDENTIAL|API)', "exfil_wget"),
|
||||
(r'cat\s+[^\n]*(\.env|credentials|\.netrc|\.pgpass|\.npmrc|\.pypirc)', "read_secrets"),
|
||||
# Persistence via shell rc
|
||||
(r'authorized_keys', "ssh_backdoor"),
|
||||
(r'\$HOME/\.ssh|\~/\.ssh', "ssh_access"),
|
||||
(r'\$HOME/\.hermes/\.env|\~/\.hermes/\.env', "hermes_env"),
|
||||
]
|
||||
|
||||
# Subset of invisible chars for injection detection
|
||||
_INVISIBLE_CHARS = {
|
||||
'\u200b', '\u200c', '\u200d', '\u2060', '\ufeff',
|
||||
'\u202a', '\u202b', '\u202c', '\u202d', '\u202e',
|
||||
}
|
||||
|
||||
|
||||
def _scan_memory_content(content: str) -> Optional[str]:
|
||||
"""Scan memory content for injection/exfil patterns. Returns error string if blocked."""
|
||||
# Check invisible unicode
|
||||
for char in _INVISIBLE_CHARS:
|
||||
if char in content:
|
||||
return f"Blocked: content contains invisible unicode character U+{ord(char):04X} (possible injection)."
|
||||
|
||||
# Check threat patterns
|
||||
for pattern, pid in _MEMORY_THREAT_PATTERNS:
|
||||
if re.search(pattern, content, re.IGNORECASE):
|
||||
return f"Blocked: content matches threat pattern '{pid}'. Memory entries are injected into the system prompt and must not contain injection or exfiltration payloads."
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class MemoryStore:
|
||||
"""
|
||||
Bounded curated memory with file persistence. One instance per AIAgent.
|
||||
|
||||
Maintains two parallel states:
|
||||
- _system_prompt_snapshot: frozen at load time, used for system prompt injection.
|
||||
Never mutated mid-session. Keeps prefix cache stable.
|
||||
- memory_entries / user_entries: live state, mutated by tool calls, persisted to disk.
|
||||
Tool responses always reflect this live state.
|
||||
"""
|
||||
|
||||
def __init__(self, memory_char_limit: int = 2200, user_char_limit: int = 1375):
|
||||
self.memory_entries: List[str] = []
|
||||
self.user_entries: List[str] = []
|
||||
self.memory_char_limit = memory_char_limit
|
||||
self.user_char_limit = user_char_limit
|
||||
# Frozen snapshot for system prompt -- set once at load_from_disk()
|
||||
self._system_prompt_snapshot: Dict[str, str] = {"memory": "", "user": ""}
|
||||
|
||||
def load_from_disk(self):
|
||||
"""Load entries from MEMORY.md and USER.md, capture system prompt snapshot."""
|
||||
MEMORY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.memory_entries = self._read_file(MEMORY_DIR / "MEMORY.md")
|
||||
self.user_entries = self._read_file(MEMORY_DIR / "USER.md")
|
||||
|
||||
# Deduplicate entries (preserves order, keeps first occurrence)
|
||||
self.memory_entries = list(dict.fromkeys(self.memory_entries))
|
||||
self.user_entries = list(dict.fromkeys(self.user_entries))
|
||||
|
||||
# Capture frozen snapshot for system prompt injection
|
||||
self._system_prompt_snapshot = {
|
||||
"memory": self._render_block("memory", self.memory_entries),
|
||||
"user": self._render_block("user", self.user_entries),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def _file_lock(path: Path):
|
||||
"""Acquire an exclusive file lock for read-modify-write safety.
|
||||
|
||||
Uses a separate .lock file so the memory file itself can still be
|
||||
atomically replaced via os.replace().
|
||||
"""
|
||||
lock_path = path.with_suffix(path.suffix + ".lock")
|
||||
lock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fd = open(lock_path, "w")
|
||||
try:
|
||||
fcntl.flock(fd, fcntl.LOCK_EX)
|
||||
yield
|
||||
finally:
|
||||
fcntl.flock(fd, fcntl.LOCK_UN)
|
||||
fd.close()
|
||||
|
||||
@staticmethod
|
||||
def _path_for(target: str) -> Path:
|
||||
if target == "user":
|
||||
return MEMORY_DIR / "USER.md"
|
||||
return MEMORY_DIR / "MEMORY.md"
|
||||
|
||||
def _reload_target(self, target: str):
|
||||
"""Re-read entries from disk into in-memory state.
|
||||
|
||||
Called under file lock to get the latest state before mutating.
|
||||
"""
|
||||
fresh = self._read_file(self._path_for(target))
|
||||
fresh = list(dict.fromkeys(fresh)) # deduplicate
|
||||
self._set_entries(target, fresh)
|
||||
|
||||
def save_to_disk(self, target: str):
|
||||
"""Persist entries to the appropriate file. Called after every mutation."""
|
||||
MEMORY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
self._write_file(self._path_for(target), self._entries_for(target))
|
||||
|
||||
def _entries_for(self, target: str) -> List[str]:
|
||||
if target == "user":
|
||||
return self.user_entries
|
||||
return self.memory_entries
|
||||
|
||||
def _set_entries(self, target: str, entries: List[str]):
|
||||
if target == "user":
|
||||
self.user_entries = entries
|
||||
else:
|
||||
self.memory_entries = entries
|
||||
|
||||
def _char_count(self, target: str) -> int:
|
||||
entries = self._entries_for(target)
|
||||
if not entries:
|
||||
return 0
|
||||
return len(ENTRY_DELIMITER.join(entries))
|
||||
|
||||
def _char_limit(self, target: str) -> int:
|
||||
if target == "user":
|
||||
return self.user_char_limit
|
||||
return self.memory_char_limit
|
||||
|
||||
def add(self, target: str, content: str) -> Dict[str, Any]:
|
||||
"""Append a new entry. Returns error if it would exceed the char limit."""
|
||||
content = content.strip()
|
||||
if not content:
|
||||
return {"success": False, "error": "Content cannot be empty."}
|
||||
|
||||
# Scan for injection/exfiltration before accepting
|
||||
scan_error = _scan_memory_content(content)
|
||||
if scan_error:
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
with self._file_lock(self._path_for(target)):
|
||||
# Re-read from disk under lock to pick up writes from other sessions
|
||||
self._reload_target(target)
|
||||
|
||||
entries = self._entries_for(target)
|
||||
limit = self._char_limit(target)
|
||||
|
||||
# Reject exact duplicates
|
||||
if content in entries:
|
||||
return self._success_response(target, "Entry already exists (no duplicate added).")
|
||||
|
||||
# Calculate what the new total would be
|
||||
new_entries = entries + [content]
|
||||
new_total = len(ENTRY_DELIMITER.join(new_entries))
|
||||
|
||||
if new_total > limit:
|
||||
current = self._char_count(target)
|
||||
return {
|
||||
"success": False,
|
||||
"error": (
|
||||
f"Memory at {current:,}/{limit:,} chars. "
|
||||
f"Adding this entry ({len(content)} chars) would exceed the limit. "
|
||||
f"Replace or remove existing entries first."
|
||||
),
|
||||
"current_entries": entries,
|
||||
"usage": f"{current:,}/{limit:,}",
|
||||
}
|
||||
|
||||
entries.append(content)
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
|
||||
return self._success_response(target, "Entry added.")
|
||||
|
||||
def replace(self, target: str, old_text: str, new_content: str) -> Dict[str, Any]:
|
||||
"""Find entry containing old_text substring, replace it with new_content."""
|
||||
old_text = old_text.strip()
|
||||
new_content = new_content.strip()
|
||||
if not old_text:
|
||||
return {"success": False, "error": "old_text cannot be empty."}
|
||||
if not new_content:
|
||||
return {"success": False, "error": "new_content cannot be empty. Use 'remove' to delete entries."}
|
||||
|
||||
# Scan replacement content for injection/exfiltration
|
||||
scan_error = _scan_memory_content(new_content)
|
||||
if scan_error:
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
with self._file_lock(self._path_for(target)):
|
||||
self._reload_target(target)
|
||||
|
||||
entries = self._entries_for(target)
|
||||
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
|
||||
|
||||
if len(matches) == 0:
|
||||
return {"success": False, "error": f"No entry matched '{old_text}'."}
|
||||
|
||||
if len(matches) > 1:
|
||||
# If all matches are identical (exact duplicates), operate on the first one
|
||||
unique_texts = set(e for _, e in matches)
|
||||
if len(unique_texts) > 1:
|
||||
previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches]
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
|
||||
"matches": previews,
|
||||
}
|
||||
# All identical -- safe to replace just the first
|
||||
|
||||
idx = matches[0][0]
|
||||
limit = self._char_limit(target)
|
||||
|
||||
# Check that replacement doesn't blow the budget
|
||||
test_entries = entries.copy()
|
||||
test_entries[idx] = new_content
|
||||
new_total = len(ENTRY_DELIMITER.join(test_entries))
|
||||
|
||||
if new_total > limit:
|
||||
return {
|
||||
"success": False,
|
||||
"error": (
|
||||
f"Replacement would put memory at {new_total:,}/{limit:,} chars. "
|
||||
f"Shorten the new content or remove other entries first."
|
||||
),
|
||||
}
|
||||
|
||||
entries[idx] = new_content
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
|
||||
return self._success_response(target, "Entry replaced.")
|
||||
|
||||
def remove(self, target: str, old_text: str) -> Dict[str, Any]:
|
||||
"""Remove the entry containing old_text substring."""
|
||||
old_text = old_text.strip()
|
||||
if not old_text:
|
||||
return {"success": False, "error": "old_text cannot be empty."}
|
||||
|
||||
with self._file_lock(self._path_for(target)):
|
||||
self._reload_target(target)
|
||||
|
||||
entries = self._entries_for(target)
|
||||
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
|
||||
|
||||
if len(matches) == 0:
|
||||
return {"success": False, "error": f"No entry matched '{old_text}'."}
|
||||
|
||||
if len(matches) > 1:
|
||||
# If all matches are identical (exact duplicates), remove the first one
|
||||
unique_texts = set(e for _, e in matches)
|
||||
if len(unique_texts) > 1:
|
||||
previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches]
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
|
||||
"matches": previews,
|
||||
}
|
||||
# All identical -- safe to remove just the first
|
||||
|
||||
idx = matches[0][0]
|
||||
entries.pop(idx)
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
|
||||
return self._success_response(target, "Entry removed.")
|
||||
|
||||
def format_for_system_prompt(self, target: str) -> Optional[str]:
|
||||
"""
|
||||
Return the frozen snapshot for system prompt injection.
|
||||
|
||||
This returns the state captured at load_from_disk() time, NOT the live
|
||||
state. Mid-session writes do not affect this. This keeps the system
|
||||
prompt stable across all turns, preserving the prefix cache.
|
||||
|
||||
Returns None if the snapshot is empty (no entries at load time).
|
||||
"""
|
||||
block = self._system_prompt_snapshot.get(target, "")
|
||||
return block if block else None
|
||||
|
||||
# -- Internal helpers --
|
||||
|
||||
def _success_response(self, target: str, message: str = None) -> Dict[str, Any]:
|
||||
entries = self._entries_for(target)
|
||||
current = self._char_count(target)
|
||||
limit = self._char_limit(target)
|
||||
pct = int((current / limit) * 100) if limit > 0 else 0
|
||||
|
||||
resp = {
|
||||
"success": True,
|
||||
"target": target,
|
||||
"entries": entries,
|
||||
"usage": f"{pct}% — {current:,}/{limit:,} chars",
|
||||
"entry_count": len(entries),
|
||||
}
|
||||
if message:
|
||||
resp["message"] = message
|
||||
return resp
|
||||
|
||||
def _render_block(self, target: str, entries: List[str]) -> str:
|
||||
"""Render a system prompt block with header and usage indicator."""
|
||||
if not entries:
|
||||
return ""
|
||||
|
||||
limit = self._char_limit(target)
|
||||
content = ENTRY_DELIMITER.join(entries)
|
||||
current = len(content)
|
||||
pct = int((current / limit) * 100) if limit > 0 else 0
|
||||
|
||||
if target == "user":
|
||||
header = f"USER PROFILE (who the user is) [{pct}% — {current:,}/{limit:,} chars]"
|
||||
else:
|
||||
header = f"MEMORY (your personal notes) [{pct}% — {current:,}/{limit:,} chars]"
|
||||
|
||||
separator = "═" * 46
|
||||
return f"{separator}\n{header}\n{separator}\n{content}"
|
||||
|
||||
@staticmethod
|
||||
def _read_file(path: Path) -> List[str]:
|
||||
"""Read a memory file and split into entries.
|
||||
|
||||
No file locking needed: _write_file uses atomic rename, so readers
|
||||
always see either the previous complete file or the new complete file.
|
||||
"""
|
||||
if not path.exists():
|
||||
return []
|
||||
try:
|
||||
raw = path.read_text(encoding="utf-8")
|
||||
except (OSError, IOError):
|
||||
return []
|
||||
|
||||
if not raw.strip():
|
||||
return []
|
||||
|
||||
# Use ENTRY_DELIMITER for consistency with _write_file. Splitting by "§"
|
||||
# alone would incorrectly split entries that contain "§" in their content.
|
||||
entries = [e.strip() for e in raw.split(ENTRY_DELIMITER)]
|
||||
return [e for e in entries if e]
|
||||
|
||||
@staticmethod
|
||||
def _write_file(path: Path, entries: List[str]):
|
||||
"""Write entries to a memory file using atomic temp-file + rename.
|
||||
|
||||
Previous implementation used open("w") + flock, but "w" truncates the
|
||||
file *before* the lock is acquired, creating a race window where
|
||||
concurrent readers see an empty file. Atomic rename avoids this:
|
||||
readers always see either the old complete file or the new one.
|
||||
"""
|
||||
content = ENTRY_DELIMITER.join(entries) if entries else ""
|
||||
try:
|
||||
# Write to temp file in same directory (same filesystem for atomic rename)
|
||||
fd, tmp_path = tempfile.mkstemp(
|
||||
dir=str(path.parent), suffix=".tmp", prefix=".mem_"
|
||||
)
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
os.replace(tmp_path, str(path)) # Atomic on same filesystem
|
||||
except BaseException:
|
||||
# Clean up temp file on any failure
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
except (OSError, IOError) as e:
|
||||
raise RuntimeError(f"Failed to write memory file {path}: {e}")
|
||||
|
||||
|
||||
def memory_tool(
|
||||
action: str,
|
||||
target: str = "memory",
|
||||
content: str = None,
|
||||
old_text: str = None,
|
||||
store: Optional[MemoryStore] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Single entry point for the memory tool. Dispatches to MemoryStore methods.
|
||||
|
||||
Returns JSON string with results.
|
||||
"""
|
||||
if store is None:
|
||||
return json.dumps({"success": False, "error": "Memory is not available. It may be disabled in config or this environment."}, ensure_ascii=False)
|
||||
|
||||
if target not in ("memory", "user"):
|
||||
return json.dumps({"success": False, "error": f"Invalid target '{target}'. Use 'memory' or 'user'."}, ensure_ascii=False)
|
||||
|
||||
if action == "add":
|
||||
if not content:
|
||||
return json.dumps({"success": False, "error": "Content is required for 'add' action."}, ensure_ascii=False)
|
||||
result = store.add(target, content)
|
||||
|
||||
elif action == "replace":
|
||||
if not old_text:
|
||||
return json.dumps({"success": False, "error": "old_text is required for 'replace' action."}, ensure_ascii=False)
|
||||
if not content:
|
||||
return json.dumps({"success": False, "error": "content is required for 'replace' action."}, ensure_ascii=False)
|
||||
result = store.replace(target, old_text, content)
|
||||
|
||||
elif action == "remove":
|
||||
if not old_text:
|
||||
return json.dumps({"success": False, "error": "old_text is required for 'remove' action."}, ensure_ascii=False)
|
||||
result = store.remove(target, old_text)
|
||||
|
||||
else:
|
||||
return json.dumps({"success": False, "error": f"Unknown action '{action}'. Use: add, replace, remove"}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
|
||||
|
||||
def check_memory_requirements() -> bool:
|
||||
"""Memory tool has no external requirements -- always available."""
|
||||
return True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OpenAI Function-Calling Schema
|
||||
# =============================================================================
|
||||
|
||||
MEMORY_SCHEMA = {
|
||||
"name": "memory",
|
||||
"description": (
|
||||
"Save durable information to persistent memory that survives across sessions. "
|
||||
"Memory is injected into future turns, so keep it compact and focused on facts "
|
||||
"that will still matter later.\n\n"
|
||||
"WHEN TO SAVE (do this proactively, don't wait to be asked):\n"
|
||||
"- User corrects you or says 'remember this' / 'don't do that again'\n"
|
||||
"- User shares a preference, habit, or personal detail (name, role, timezone, coding style)\n"
|
||||
"- You discover something about the environment (OS, installed tools, project structure)\n"
|
||||
"- You learn a convention, API quirk, or workflow specific to this user's setup\n"
|
||||
"- You identify a stable fact that will be useful again in future sessions\n\n"
|
||||
"PRIORITY: User preferences and corrections > environment facts > procedural knowledge. "
|
||||
"The most valuable memory prevents the user from having to repeat themselves.\n\n"
|
||||
"Do NOT save task progress, session outcomes, completed-work logs, or temporary TODO "
|
||||
"state to memory; use session_search to recall those from past transcripts.\n"
|
||||
"If you've discovered a new way to do something, solved a problem that could be "
|
||||
"necessary later, save it as a skill with the skill tool.\n\n"
|
||||
"TWO TARGETS:\n"
|
||||
"- 'user': who the user is -- name, role, preferences, communication style, pet peeves\n"
|
||||
"- 'memory': your notes -- environment facts, project conventions, tool quirks, lessons learned\n\n"
|
||||
"ACTIONS: add (new entry), replace (update existing -- old_text identifies it), "
|
||||
"remove (delete -- old_text identifies it).\n\n"
|
||||
"SKIP: trivial/obvious info, things easily re-discovered, raw data dumps, and temporary task state."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["add", "replace", "remove"],
|
||||
"description": "The action to perform."
|
||||
},
|
||||
"target": {
|
||||
"type": "string",
|
||||
"enum": ["memory", "user"],
|
||||
"description": "Which memory store: 'memory' for personal notes, 'user' for user profile."
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The entry content. Required for 'add' and 'replace'."
|
||||
},
|
||||
"old_text": {
|
||||
"type": "string",
|
||||
"description": "Short unique substring identifying the entry to replace or remove."
|
||||
},
|
||||
},
|
||||
"required": ["action", "target"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# --- Registry ---
|
||||
from tools.registry import registry
|
||||
|
||||
registry.register(
|
||||
name="memory",
|
||||
toolset="memory",
|
||||
schema=MEMORY_SCHEMA,
|
||||
handler=lambda args, **kw: memory_tool(
|
||||
action=args.get("action", ""),
|
||||
target=args.get("target", "memory"),
|
||||
content=args.get("content"),
|
||||
old_text=args.get("old_text"),
|
||||
store=kw.get("store")),
|
||||
check_fn=check_memory_requirements,
|
||||
emoji="🧠",
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
548
hermes_code/tools/mixture_of_agents_tool.py
Normal file
548
hermes_code/tools/mixture_of_agents_tool.py
Normal file
|
|
@ -0,0 +1,548 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Mixture-of-Agents Tool Module
|
||||
|
||||
This module implements the Mixture-of-Agents (MoA) methodology that leverages
|
||||
the collective strengths of multiple LLMs through a layered architecture to
|
||||
achieve state-of-the-art performance on complex reasoning tasks.
|
||||
|
||||
Based on the research paper: "Mixture-of-Agents Enhances Large Language Model Capabilities"
|
||||
by Junlin Wang et al. (arXiv:2406.04692v1)
|
||||
|
||||
Key Features:
|
||||
- Multi-layer LLM collaboration for enhanced reasoning
|
||||
- Parallel processing of reference models for efficiency
|
||||
- Intelligent aggregation and synthesis of diverse responses
|
||||
- Specialized for extremely difficult problems requiring intense reasoning
|
||||
- Optimized for coding, mathematics, and complex analytical tasks
|
||||
|
||||
Available Tool:
|
||||
- mixture_of_agents_tool: Process complex queries using multiple frontier models
|
||||
|
||||
Architecture:
|
||||
1. Reference models generate diverse initial responses in parallel
|
||||
2. Aggregator model synthesizes responses into a high-quality output
|
||||
3. Multiple layers can be used for iterative refinement (future enhancement)
|
||||
|
||||
Models Used (via OpenRouter):
|
||||
- Reference Models: claude-opus-4.6, gemini-3-pro-preview, gpt-5.4-pro, deepseek-v3.2
|
||||
- Aggregator Model: claude-opus-4.6 (highest capability for synthesis)
|
||||
|
||||
Configuration:
|
||||
To customize the MoA setup, modify the configuration constants at the top of this file:
|
||||
- REFERENCE_MODELS: List of models for generating diverse initial responses
|
||||
- AGGREGATOR_MODEL: Model used to synthesize the final response
|
||||
- REFERENCE_TEMPERATURE/AGGREGATOR_TEMPERATURE: Sampling temperatures
|
||||
- MIN_SUCCESSFUL_REFERENCES: Minimum successful models needed to proceed
|
||||
|
||||
Usage:
|
||||
from mixture_of_agents_tool import mixture_of_agents_tool
|
||||
import asyncio
|
||||
|
||||
# Process a complex query
|
||||
result = await mixture_of_agents_tool(
|
||||
user_prompt="Solve this complex mathematical proof..."
|
||||
)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import asyncio
|
||||
import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
from tools.openrouter_client import get_async_client as _get_openrouter_client, check_api_key as check_openrouter_api_key
|
||||
from tools.debug_helpers import DebugSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration for MoA processing
|
||||
# Reference models - these generate diverse initial responses in parallel.
|
||||
# Keep this list aligned with current top-tier OpenRouter frontier options.
|
||||
REFERENCE_MODELS = [
|
||||
"anthropic/claude-opus-4.6",
|
||||
"google/gemini-3-pro-preview",
|
||||
"openai/gpt-5.4-pro",
|
||||
"deepseek/deepseek-v3.2",
|
||||
]
|
||||
|
||||
# Aggregator model - synthesizes reference responses into final output.
|
||||
# Prefer the strongest synthesis model in the current OpenRouter lineup.
|
||||
AGGREGATOR_MODEL = "anthropic/claude-opus-4.6"
|
||||
|
||||
# Temperature settings optimized for MoA performance
|
||||
REFERENCE_TEMPERATURE = 0.6 # Balanced creativity for diverse perspectives
|
||||
AGGREGATOR_TEMPERATURE = 0.4 # Focused synthesis for consistency
|
||||
|
||||
# Failure handling configuration
|
||||
MIN_SUCCESSFUL_REFERENCES = 1 # Minimum successful reference models needed to proceed
|
||||
|
||||
# System prompt for the aggregator model (from the research paper)
|
||||
AGGREGATOR_SYSTEM_PROMPT = """You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
|
||||
|
||||
Responses from models:"""
|
||||
|
||||
_debug = DebugSession("moa_tools", env_var="MOA_TOOLS_DEBUG")
|
||||
|
||||
|
||||
def _construct_aggregator_prompt(system_prompt: str, responses: List[str]) -> str:
|
||||
"""
|
||||
Construct the final system prompt for the aggregator including all model responses.
|
||||
|
||||
Args:
|
||||
system_prompt (str): Base system prompt for aggregation
|
||||
responses (List[str]): List of responses from reference models
|
||||
|
||||
Returns:
|
||||
str: Complete system prompt with enumerated responses
|
||||
"""
|
||||
response_text = "\n".join([f"{i+1}. {response}" for i, response in enumerate(responses)])
|
||||
return f"{system_prompt}\n\n{response_text}"
|
||||
|
||||
|
||||
async def _run_reference_model_safe(
|
||||
model: str,
|
||||
user_prompt: str,
|
||||
temperature: float = REFERENCE_TEMPERATURE,
|
||||
max_tokens: int = 32000,
|
||||
max_retries: int = 6
|
||||
) -> tuple[str, str, bool]:
|
||||
"""
|
||||
Run a single reference model with retry logic and graceful failure handling.
|
||||
|
||||
Args:
|
||||
model (str): Model identifier to use
|
||||
user_prompt (str): The user's query
|
||||
temperature (float): Sampling temperature for response generation
|
||||
max_tokens (int): Maximum tokens in response
|
||||
max_retries (int): Maximum number of retry attempts
|
||||
|
||||
Returns:
|
||||
tuple[str, str, bool]: (model_name, response_content_or_error, success_flag)
|
||||
"""
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
logger.info("Querying %s (attempt %s/%s)", model, attempt + 1, max_retries)
|
||||
|
||||
# Build parameters for the API call
|
||||
api_params = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": user_prompt}],
|
||||
"extra_body": {
|
||||
"reasoning": {
|
||||
"enabled": True,
|
||||
"effort": "xhigh"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# GPT models (especially gpt-4o-mini) don't support custom temperature values
|
||||
# Only include temperature for non-GPT models
|
||||
if not model.lower().startswith('gpt-'):
|
||||
api_params["temperature"] = temperature
|
||||
|
||||
response = await _get_openrouter_client().chat.completions.create(**api_params)
|
||||
|
||||
content = response.choices[0].message.content.strip()
|
||||
logger.info("%s responded (%s characters)", model, len(content))
|
||||
return model, content, True
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e)
|
||||
# Keep retry-path logging concise; full tracebacks are reserved for
|
||||
# terminal failure paths so long-running MoA retries don't flood logs.
|
||||
if "invalid" in error_str.lower():
|
||||
logger.warning("%s invalid request error (attempt %s): %s", model, attempt + 1, error_str)
|
||||
elif "rate" in error_str.lower() or "limit" in error_str.lower():
|
||||
logger.warning("%s rate limit error (attempt %s): %s", model, attempt + 1, error_str)
|
||||
else:
|
||||
logger.warning("%s unknown error (attempt %s): %s", model, attempt + 1, error_str)
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
# Exponential backoff for rate limiting: 2s, 4s, 8s, 16s, 32s, 60s
|
||||
sleep_time = min(2 ** (attempt + 1), 60)
|
||||
logger.info("Retrying in %ss...", sleep_time)
|
||||
await asyncio.sleep(sleep_time)
|
||||
else:
|
||||
error_msg = f"{model} failed after {max_retries} attempts: {error_str}"
|
||||
logger.error("%s", error_msg, exc_info=True)
|
||||
return model, error_msg, False
|
||||
|
||||
|
||||
async def _run_aggregator_model(
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
temperature: float = AGGREGATOR_TEMPERATURE,
|
||||
max_tokens: int = None
|
||||
) -> str:
|
||||
"""
|
||||
Run the aggregator model to synthesize the final response.
|
||||
|
||||
Args:
|
||||
system_prompt (str): System prompt with all reference responses
|
||||
user_prompt (str): Original user query
|
||||
temperature (float): Focused temperature for consistent aggregation
|
||||
max_tokens (int): Maximum tokens in final response
|
||||
|
||||
Returns:
|
||||
str: Synthesized final response
|
||||
"""
|
||||
logger.info("Running aggregator model: %s", AGGREGATOR_MODEL)
|
||||
|
||||
# Build parameters for the API call
|
||||
api_params = {
|
||||
"model": AGGREGATOR_MODEL,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
"extra_body": {
|
||||
"reasoning": {
|
||||
"enabled": True,
|
||||
"effort": "xhigh"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# GPT models (especially gpt-4o-mini) don't support custom temperature values
|
||||
# Only include temperature for non-GPT models
|
||||
if not AGGREGATOR_MODEL.lower().startswith('gpt-'):
|
||||
api_params["temperature"] = temperature
|
||||
|
||||
response = await _get_openrouter_client().chat.completions.create(**api_params)
|
||||
|
||||
content = response.choices[0].message.content.strip()
|
||||
logger.info("Aggregation complete (%s characters)", len(content))
|
||||
return content
|
||||
|
||||
|
||||
async def mixture_of_agents_tool(
|
||||
user_prompt: str,
|
||||
reference_models: Optional[List[str]] = None,
|
||||
aggregator_model: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Process a complex query using the Mixture-of-Agents methodology.
|
||||
|
||||
This tool leverages multiple frontier language models to collaboratively solve
|
||||
extremely difficult problems requiring intense reasoning. It's particularly
|
||||
effective for:
|
||||
- Complex mathematical proofs and calculations
|
||||
- Advanced coding problems and algorithm design
|
||||
- Multi-step analytical reasoning tasks
|
||||
- Problems requiring diverse domain expertise
|
||||
- Tasks where single models show limitations
|
||||
|
||||
The MoA approach uses a fixed 2-layer architecture:
|
||||
1. Layer 1: Multiple reference models generate diverse responses in parallel (temp=0.6)
|
||||
2. Layer 2: Aggregator model synthesizes the best elements into final response (temp=0.4)
|
||||
|
||||
Args:
|
||||
user_prompt (str): The complex query or problem to solve
|
||||
reference_models (Optional[List[str]]): Custom reference models to use
|
||||
aggregator_model (Optional[str]): Custom aggregator model to use
|
||||
|
||||
Returns:
|
||||
str: JSON string containing the MoA results with the following structure:
|
||||
{
|
||||
"success": bool,
|
||||
"response": str,
|
||||
"models_used": {
|
||||
"reference_models": List[str],
|
||||
"aggregator_model": str
|
||||
},
|
||||
"processing_time": float
|
||||
}
|
||||
|
||||
Raises:
|
||||
Exception: If MoA processing fails or API key is not set
|
||||
"""
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
debug_call_data = {
|
||||
"parameters": {
|
||||
"user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt,
|
||||
"reference_models": reference_models or REFERENCE_MODELS,
|
||||
"aggregator_model": aggregator_model or AGGREGATOR_MODEL,
|
||||
"reference_temperature": REFERENCE_TEMPERATURE,
|
||||
"aggregator_temperature": AGGREGATOR_TEMPERATURE,
|
||||
"min_successful_references": MIN_SUCCESSFUL_REFERENCES
|
||||
},
|
||||
"error": None,
|
||||
"success": False,
|
||||
"reference_responses_count": 0,
|
||||
"failed_models_count": 0,
|
||||
"failed_models": [],
|
||||
"final_response_length": 0,
|
||||
"processing_time_seconds": 0,
|
||||
"models_used": {}
|
||||
}
|
||||
|
||||
try:
|
||||
logger.info("Starting Mixture-of-Agents processing...")
|
||||
logger.info("Query: %s", user_prompt[:100])
|
||||
|
||||
# Validate API key availability
|
||||
if not os.getenv("OPENROUTER_API_KEY"):
|
||||
raise ValueError("OPENROUTER_API_KEY environment variable not set")
|
||||
|
||||
# Use provided models or defaults
|
||||
ref_models = reference_models or REFERENCE_MODELS
|
||||
agg_model = aggregator_model or AGGREGATOR_MODEL
|
||||
|
||||
logger.info("Using %s reference models in 2-layer MoA architecture", len(ref_models))
|
||||
|
||||
# Layer 1: Generate diverse responses from reference models (with failure handling)
|
||||
logger.info("Layer 1: Generating reference responses...")
|
||||
model_results = await asyncio.gather(*[
|
||||
_run_reference_model_safe(model, user_prompt, REFERENCE_TEMPERATURE)
|
||||
for model in ref_models
|
||||
])
|
||||
|
||||
# Separate successful and failed responses
|
||||
successful_responses = []
|
||||
failed_models = []
|
||||
|
||||
for model_name, content, success in model_results:
|
||||
if success:
|
||||
successful_responses.append(content)
|
||||
else:
|
||||
failed_models.append(model_name)
|
||||
|
||||
successful_count = len(successful_responses)
|
||||
failed_count = len(failed_models)
|
||||
|
||||
logger.info("Reference model results: %s successful, %s failed", successful_count, failed_count)
|
||||
|
||||
if failed_models:
|
||||
logger.warning("Failed models: %s", ', '.join(failed_models))
|
||||
|
||||
# Check if we have enough successful responses to proceed
|
||||
if successful_count < MIN_SUCCESSFUL_REFERENCES:
|
||||
raise ValueError(f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses.")
|
||||
|
||||
debug_call_data["reference_responses_count"] = successful_count
|
||||
debug_call_data["failed_models_count"] = failed_count
|
||||
debug_call_data["failed_models"] = failed_models
|
||||
|
||||
# Layer 2: Aggregate responses using the aggregator model
|
||||
logger.info("Layer 2: Synthesizing final response...")
|
||||
aggregator_system_prompt = _construct_aggregator_prompt(
|
||||
AGGREGATOR_SYSTEM_PROMPT,
|
||||
successful_responses
|
||||
)
|
||||
|
||||
final_response = await _run_aggregator_model(
|
||||
aggregator_system_prompt,
|
||||
user_prompt,
|
||||
AGGREGATOR_TEMPERATURE
|
||||
)
|
||||
|
||||
# Calculate processing time
|
||||
end_time = datetime.datetime.now()
|
||||
processing_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.info("MoA processing completed in %.2f seconds", processing_time)
|
||||
|
||||
# Prepare successful response (only final aggregated result, minimal fields)
|
||||
result = {
|
||||
"success": True,
|
||||
"response": final_response,
|
||||
"models_used": {
|
||||
"reference_models": ref_models,
|
||||
"aggregator_model": agg_model
|
||||
}
|
||||
}
|
||||
|
||||
debug_call_data["success"] = True
|
||||
debug_call_data["final_response_length"] = len(final_response)
|
||||
debug_call_data["processing_time_seconds"] = processing_time
|
||||
debug_call_data["models_used"] = result["models_used"]
|
||||
|
||||
# Log debug information
|
||||
_debug.log_call("mixture_of_agents_tool", debug_call_data)
|
||||
_debug.save()
|
||||
|
||||
return json.dumps(result, indent=2, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error in MoA processing: {str(e)}"
|
||||
logger.error("%s", error_msg, exc_info=True)
|
||||
|
||||
# Calculate processing time even for errors
|
||||
end_time = datetime.datetime.now()
|
||||
processing_time = (end_time - start_time).total_seconds()
|
||||
|
||||
# Prepare error response (minimal fields)
|
||||
result = {
|
||||
"success": False,
|
||||
"response": "MoA processing failed. Please try again or use a single model for this query.",
|
||||
"models_used": {
|
||||
"reference_models": reference_models or REFERENCE_MODELS,
|
||||
"aggregator_model": aggregator_model or AGGREGATOR_MODEL
|
||||
},
|
||||
"error": error_msg
|
||||
}
|
||||
|
||||
debug_call_data["error"] = error_msg
|
||||
debug_call_data["processing_time_seconds"] = processing_time
|
||||
_debug.log_call("mixture_of_agents_tool", debug_call_data)
|
||||
_debug.save()
|
||||
|
||||
return json.dumps(result, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def check_moa_requirements() -> bool:
|
||||
"""
|
||||
Check if all requirements for MoA tools are met.
|
||||
|
||||
Returns:
|
||||
bool: True if requirements are met, False otherwise
|
||||
"""
|
||||
return check_openrouter_api_key()
|
||||
|
||||
|
||||
def get_debug_session_info() -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about the current debug session.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary containing debug session information
|
||||
"""
|
||||
return _debug.get_session_info()
|
||||
|
||||
|
||||
def get_available_models() -> Dict[str, List[str]]:
|
||||
"""
|
||||
Get information about available models for MoA processing.
|
||||
|
||||
Returns:
|
||||
Dict[str, List[str]]: Dictionary with reference and aggregator models
|
||||
"""
|
||||
return {
|
||||
"reference_models": REFERENCE_MODELS,
|
||||
"aggregator_models": [AGGREGATOR_MODEL],
|
||||
"supported_models": REFERENCE_MODELS + [AGGREGATOR_MODEL]
|
||||
}
|
||||
|
||||
|
||||
def get_moa_configuration() -> Dict[str, Any]:
|
||||
"""
|
||||
Get the current MoA configuration settings.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary containing all configuration parameters
|
||||
"""
|
||||
return {
|
||||
"reference_models": REFERENCE_MODELS,
|
||||
"aggregator_model": AGGREGATOR_MODEL,
|
||||
"reference_temperature": REFERENCE_TEMPERATURE,
|
||||
"aggregator_temperature": AGGREGATOR_TEMPERATURE,
|
||||
"min_successful_references": MIN_SUCCESSFUL_REFERENCES,
|
||||
"total_reference_models": len(REFERENCE_MODELS),
|
||||
"failure_tolerance": f"{len(REFERENCE_MODELS) - MIN_SUCCESSFUL_REFERENCES}/{len(REFERENCE_MODELS)} models can fail"
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
Simple test/demo when run directly
|
||||
"""
|
||||
print("🤖 Mixture-of-Agents Tool Module")
|
||||
print("=" * 50)
|
||||
|
||||
# Check if API key is available
|
||||
api_available = check_openrouter_api_key()
|
||||
|
||||
if not api_available:
|
||||
print("❌ OPENROUTER_API_KEY environment variable not set")
|
||||
print("Please set your API key: export OPENROUTER_API_KEY='your-key-here'")
|
||||
print("Get API key at: https://openrouter.ai/")
|
||||
exit(1)
|
||||
else:
|
||||
print("✅ OpenRouter API key found")
|
||||
|
||||
print("🛠️ MoA tools ready for use!")
|
||||
|
||||
# Show current configuration
|
||||
config = get_moa_configuration()
|
||||
print(f"\n⚙️ Current Configuration:")
|
||||
print(f" 🤖 Reference models ({len(config['reference_models'])}): {', '.join(config['reference_models'])}")
|
||||
print(f" 🧠 Aggregator model: {config['aggregator_model']}")
|
||||
print(f" 🌡️ Reference temperature: {config['reference_temperature']}")
|
||||
print(f" 🌡️ Aggregator temperature: {config['aggregator_temperature']}")
|
||||
print(f" 🛡️ Failure tolerance: {config['failure_tolerance']}")
|
||||
print(f" 📊 Minimum successful models: {config['min_successful_references']}")
|
||||
|
||||
# Show debug mode status
|
||||
if _debug.active:
|
||||
print(f"\n🐛 Debug mode ENABLED - Session ID: {_debug.session_id}")
|
||||
print(f" Debug logs will be saved to: ./logs/moa_tools_debug_{_debug.session_id}.json")
|
||||
else:
|
||||
print("\n🐛 Debug mode disabled (set MOA_TOOLS_DEBUG=true to enable)")
|
||||
|
||||
print("\nBasic usage:")
|
||||
print(" from mixture_of_agents_tool import mixture_of_agents_tool")
|
||||
print(" import asyncio")
|
||||
print("")
|
||||
print(" async def main():")
|
||||
print(" result = await mixture_of_agents_tool(")
|
||||
print(" user_prompt='Solve this complex mathematical proof...'")
|
||||
print(" )")
|
||||
print(" print(result)")
|
||||
print(" asyncio.run(main())")
|
||||
|
||||
print("\nBest use cases:")
|
||||
print(" - Complex mathematical proofs and calculations")
|
||||
print(" - Advanced coding problems and algorithm design")
|
||||
print(" - Multi-step analytical reasoning tasks")
|
||||
print(" - Problems requiring diverse domain expertise")
|
||||
print(" - Tasks where single models show limitations")
|
||||
|
||||
print("\nPerformance characteristics:")
|
||||
print(" - Higher latency due to multiple model calls")
|
||||
print(" - Significantly improved quality for complex tasks")
|
||||
print(" - Parallel processing for efficiency")
|
||||
print(f" - Optimized temperatures: {REFERENCE_TEMPERATURE} for reference models, {AGGREGATOR_TEMPERATURE} for aggregation")
|
||||
print(" - Token-efficient: only returns final aggregated response")
|
||||
print(" - Resilient: continues with partial model failures")
|
||||
print(f" - Configurable: easy to modify models and settings at top of file")
|
||||
print(" - State-of-the-art results on challenging benchmarks")
|
||||
|
||||
print("\nDebug mode:")
|
||||
print(" # Enable debug logging")
|
||||
print(" export MOA_TOOLS_DEBUG=true")
|
||||
print(" # Debug logs capture all MoA processing steps and metrics")
|
||||
print(" # Logs saved to: ./logs/moa_tools_debug_UUID.json")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
from tools.registry import registry
|
||||
|
||||
MOA_SCHEMA = {
|
||||
"name": "mixture_of_agents",
|
||||
"description": "Route a hard problem through multiple frontier LLMs collaboratively. Makes 5 API calls (4 reference models + 1 aggregator) with maximum reasoning effort — use sparingly for genuinely difficult problems. Best for: complex math, advanced algorithms, multi-step analytical reasoning, problems benefiting from diverse perspectives.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user_prompt": {
|
||||
"type": "string",
|
||||
"description": "The complex query or problem to solve using multiple AI models. Should be a challenging problem that benefits from diverse perspectives and collaborative reasoning."
|
||||
}
|
||||
},
|
||||
"required": ["user_prompt"]
|
||||
}
|
||||
}
|
||||
|
||||
registry.register(
|
||||
name="mixture_of_agents",
|
||||
toolset="moa",
|
||||
schema=MOA_SCHEMA,
|
||||
handler=lambda args, **kw: mixture_of_agents_tool(user_prompt=args.get("user_prompt", "")),
|
||||
check_fn=check_moa_requirements,
|
||||
requires_env=["OPENROUTER_API_KEY"],
|
||||
is_async=True,
|
||||
emoji="🧠",
|
||||
)
|
||||
1
hermes_code/tools/neutts_samples/jo.txt
Normal file
1
hermes_code/tools/neutts_samples/jo.txt
Normal file
|
|
@ -0,0 +1 @@
|
|||
So I just tried Neuphonic and I’m genuinely impressed. It's super responsive, it sounds clean, supports voice cloning, and the agent feature is fun to play with too. Highly recommend it for podcasts, conversations, or even just messing around with voiceovers.
|
||||
BIN
hermes_code/tools/neutts_samples/jo.wav
Normal file
BIN
hermes_code/tools/neutts_samples/jo.wav
Normal file
Binary file not shown.
104
hermes_code/tools/neutts_synth.py
Normal file
104
hermes_code/tools/neutts_synth.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Standalone NeuTTS synthesis helper.
|
||||
|
||||
Called by tts_tool.py via subprocess to keep the TTS model (~500MB)
|
||||
in a separate process that exits after synthesis — no lingering memory.
|
||||
|
||||
Usage:
|
||||
python -m tools.neutts_synth --text "Hello" --out output.wav \
|
||||
--ref-audio samples/jo.wav --ref-text samples/jo.txt
|
||||
|
||||
Requires: python -m pip install -U neutts[all]
|
||||
System: apt install espeak-ng (or brew install espeak-ng)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _write_wav(path: str, samples, sample_rate: int = 24000) -> None:
|
||||
"""Write a WAV file from float32 samples (no soundfile dependency)."""
|
||||
import numpy as np
|
||||
|
||||
if not isinstance(samples, np.ndarray):
|
||||
samples = np.array(samples, dtype=np.float32)
|
||||
samples = samples.flatten()
|
||||
|
||||
# Clamp and convert to int16
|
||||
samples = np.clip(samples, -1.0, 1.0)
|
||||
pcm = (samples * 32767).astype(np.int16)
|
||||
|
||||
num_channels = 1
|
||||
bits_per_sample = 16
|
||||
byte_rate = sample_rate * num_channels * (bits_per_sample // 8)
|
||||
block_align = num_channels * (bits_per_sample // 8)
|
||||
data_size = len(pcm) * (bits_per_sample // 8)
|
||||
|
||||
with open(path, "wb") as f:
|
||||
f.write(b"RIFF")
|
||||
f.write(struct.pack("<I", 36 + data_size))
|
||||
f.write(b"WAVE")
|
||||
f.write(b"fmt ")
|
||||
f.write(struct.pack("<IHHIIHH", 16, 1, num_channels, sample_rate,
|
||||
byte_rate, block_align, bits_per_sample))
|
||||
f.write(b"data")
|
||||
f.write(struct.pack("<I", data_size))
|
||||
f.write(pcm.tobytes())
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="NeuTTS synthesis helper")
|
||||
parser.add_argument("--text", required=True, help="Text to synthesize")
|
||||
parser.add_argument("--out", required=True, help="Output WAV path")
|
||||
parser.add_argument("--ref-audio", required=True, help="Reference voice audio path")
|
||||
parser.add_argument("--ref-text", required=True, help="Reference voice transcript path")
|
||||
parser.add_argument("--model", default="neuphonic/neutts-air-q4-gguf",
|
||||
help="HuggingFace backbone model repo")
|
||||
parser.add_argument("--device", default="cpu", help="Device (cpu/cuda/mps)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate inputs
|
||||
ref_audio = Path(args.ref_audio).expanduser()
|
||||
ref_text_path = Path(args.ref_text).expanduser()
|
||||
if not ref_audio.exists():
|
||||
print(f"Error: reference audio not found: {ref_audio}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
if not ref_text_path.exists():
|
||||
print(f"Error: reference text not found: {ref_text_path}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
ref_text = ref_text_path.read_text(encoding="utf-8").strip()
|
||||
|
||||
# Import and run NeuTTS
|
||||
try:
|
||||
from neutts import NeuTTS
|
||||
except ImportError:
|
||||
print("Error: neutts not installed. Run: python -m pip install -U neutts[all]", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
tts = NeuTTS(
|
||||
backbone_repo=args.model,
|
||||
backbone_device=args.device,
|
||||
codec_repo="neuphonic/neucodec",
|
||||
codec_device=args.device,
|
||||
)
|
||||
ref_codes = tts.encode_reference(str(ref_audio))
|
||||
wav = tts.infer(args.text, ref_codes, ref_text)
|
||||
|
||||
# Write output
|
||||
out_path = Path(args.out)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
import soundfile as sf
|
||||
sf.write(str(out_path), wav, 24000)
|
||||
except ImportError:
|
||||
_write_wav(str(out_path), wav, 24000)
|
||||
|
||||
print(f"OK: {out_path}", file=sys.stderr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
33
hermes_code/tools/openrouter_client.py
Normal file
33
hermes_code/tools/openrouter_client.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
"""Shared OpenRouter API client for Hermes tools.
|
||||
|
||||
Provides a single lazy-initialized AsyncOpenAI client that all tool modules
|
||||
can share. Routes through the centralized provider router in
|
||||
agent/auxiliary_client.py so auth, headers, and API format are handled
|
||||
consistently.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
_client = None
|
||||
|
||||
|
||||
def get_async_client():
|
||||
"""Return a shared async OpenAI-compatible client for OpenRouter.
|
||||
|
||||
The client is created lazily on first call and reused thereafter.
|
||||
Uses the centralized provider router for auth and client construction.
|
||||
Raises ValueError if OPENROUTER_API_KEY is not set.
|
||||
"""
|
||||
global _client
|
||||
if _client is None:
|
||||
from agent.auxiliary_client import resolve_provider_client
|
||||
client, _model = resolve_provider_client("openrouter", async_mode=True)
|
||||
if client is None:
|
||||
raise ValueError("OPENROUTER_API_KEY environment variable not set")
|
||||
_client = client
|
||||
return _client
|
||||
|
||||
|
||||
def check_api_key() -> bool:
|
||||
"""Check whether the OpenRouter API key is present."""
|
||||
return bool(os.getenv("OPENROUTER_API_KEY"))
|
||||
438
hermes_code/tools/patch_parser.py
Normal file
438
hermes_code/tools/patch_parser.py
Normal file
|
|
@ -0,0 +1,438 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
V4A Patch Format Parser
|
||||
|
||||
Parses the V4A patch format used by codex, cline, and other coding agents.
|
||||
|
||||
V4A Format:
|
||||
*** Begin Patch
|
||||
*** Update File: path/to/file.py
|
||||
@@ optional context hint @@
|
||||
context line (space prefix)
|
||||
-removed line (minus prefix)
|
||||
+added line (plus prefix)
|
||||
*** Add File: path/to/new.py
|
||||
+new file content
|
||||
+line 2
|
||||
*** Delete File: path/to/old.py
|
||||
*** Move File: old/path.py -> new/path.py
|
||||
*** End Patch
|
||||
|
||||
Usage:
|
||||
from tools.patch_parser import parse_v4a_patch, apply_v4a_operations
|
||||
|
||||
operations, error = parse_v4a_patch(patch_content)
|
||||
if error:
|
||||
print(f"Parse error: {error}")
|
||||
else:
|
||||
result = apply_v4a_operations(operations, file_ops)
|
||||
"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Tuple, Any
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class OperationType(Enum):
|
||||
ADD = "add"
|
||||
UPDATE = "update"
|
||||
DELETE = "delete"
|
||||
MOVE = "move"
|
||||
|
||||
|
||||
@dataclass
|
||||
class HunkLine:
|
||||
"""A single line in a patch hunk."""
|
||||
prefix: str # ' ', '-', or '+'
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Hunk:
|
||||
"""A group of changes within a file."""
|
||||
context_hint: Optional[str] = None
|
||||
lines: List[HunkLine] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PatchOperation:
|
||||
"""A single operation in a V4A patch."""
|
||||
operation: OperationType
|
||||
file_path: str
|
||||
new_path: Optional[str] = None # For move operations
|
||||
hunks: List[Hunk] = field(default_factory=list)
|
||||
content: Optional[str] = None # For add file operations
|
||||
|
||||
|
||||
def parse_v4a_patch(patch_content: str) -> Tuple[List[PatchOperation], Optional[str]]:
|
||||
"""
|
||||
Parse a V4A format patch.
|
||||
|
||||
Args:
|
||||
patch_content: The patch text in V4A format
|
||||
|
||||
Returns:
|
||||
Tuple of (operations, error_message)
|
||||
- If successful: (list_of_operations, None)
|
||||
- If failed: ([], error_description)
|
||||
"""
|
||||
lines = patch_content.split('\n')
|
||||
operations: List[PatchOperation] = []
|
||||
|
||||
# Find patch boundaries
|
||||
start_idx = None
|
||||
end_idx = None
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if '*** Begin Patch' in line or '***Begin Patch' in line:
|
||||
start_idx = i
|
||||
elif '*** End Patch' in line or '***End Patch' in line:
|
||||
end_idx = i
|
||||
break
|
||||
|
||||
if start_idx is None:
|
||||
# Try to parse without explicit begin marker
|
||||
start_idx = -1
|
||||
|
||||
if end_idx is None:
|
||||
end_idx = len(lines)
|
||||
|
||||
# Parse operations between boundaries
|
||||
i = start_idx + 1
|
||||
current_op: Optional[PatchOperation] = None
|
||||
current_hunk: Optional[Hunk] = None
|
||||
|
||||
while i < end_idx:
|
||||
line = lines[i]
|
||||
|
||||
# Check for file operation markers
|
||||
update_match = re.match(r'\*\*\*\s*Update\s+File:\s*(.+)', line)
|
||||
add_match = re.match(r'\*\*\*\s*Add\s+File:\s*(.+)', line)
|
||||
delete_match = re.match(r'\*\*\*\s*Delete\s+File:\s*(.+)', line)
|
||||
move_match = re.match(r'\*\*\*\s*Move\s+File:\s*(.+?)\s*->\s*(.+)', line)
|
||||
|
||||
if update_match:
|
||||
# Save previous operation
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
current_op = PatchOperation(
|
||||
operation=OperationType.UPDATE,
|
||||
file_path=update_match.group(1).strip()
|
||||
)
|
||||
current_hunk = None
|
||||
|
||||
elif add_match:
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
current_op = PatchOperation(
|
||||
operation=OperationType.ADD,
|
||||
file_path=add_match.group(1).strip()
|
||||
)
|
||||
current_hunk = Hunk()
|
||||
|
||||
elif delete_match:
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
current_op = PatchOperation(
|
||||
operation=OperationType.DELETE,
|
||||
file_path=delete_match.group(1).strip()
|
||||
)
|
||||
operations.append(current_op)
|
||||
current_op = None
|
||||
current_hunk = None
|
||||
|
||||
elif move_match:
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
current_op = PatchOperation(
|
||||
operation=OperationType.MOVE,
|
||||
file_path=move_match.group(1).strip(),
|
||||
new_path=move_match.group(2).strip()
|
||||
)
|
||||
operations.append(current_op)
|
||||
current_op = None
|
||||
current_hunk = None
|
||||
|
||||
elif line.startswith('@@'):
|
||||
# Context hint / hunk marker
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
|
||||
# Extract context hint
|
||||
hint_match = re.match(r'@@\s*(.+?)\s*@@', line)
|
||||
hint = hint_match.group(1) if hint_match else None
|
||||
current_hunk = Hunk(context_hint=hint)
|
||||
|
||||
elif current_op and line:
|
||||
# Parse hunk line
|
||||
if current_hunk is None:
|
||||
current_hunk = Hunk()
|
||||
|
||||
if line.startswith('+'):
|
||||
current_hunk.lines.append(HunkLine('+', line[1:]))
|
||||
elif line.startswith('-'):
|
||||
current_hunk.lines.append(HunkLine('-', line[1:]))
|
||||
elif line.startswith(' '):
|
||||
current_hunk.lines.append(HunkLine(' ', line[1:]))
|
||||
elif line.startswith('\\'):
|
||||
# "\ No newline at end of file" marker - skip
|
||||
pass
|
||||
else:
|
||||
# Treat as context line (implicit space prefix)
|
||||
current_hunk.lines.append(HunkLine(' ', line))
|
||||
|
||||
i += 1
|
||||
|
||||
# Don't forget the last operation
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
return operations, None
|
||||
|
||||
|
||||
def apply_v4a_operations(operations: List[PatchOperation],
|
||||
file_ops: Any) -> 'PatchResult':
|
||||
"""
|
||||
Apply V4A patch operations using a file operations interface.
|
||||
|
||||
Args:
|
||||
operations: List of PatchOperation from parse_v4a_patch
|
||||
file_ops: Object with read_file, write_file methods
|
||||
|
||||
Returns:
|
||||
PatchResult with results of all operations
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from tools.file_operations import PatchResult
|
||||
|
||||
files_modified = []
|
||||
files_created = []
|
||||
files_deleted = []
|
||||
all_diffs = []
|
||||
errors = []
|
||||
|
||||
for op in operations:
|
||||
try:
|
||||
if op.operation == OperationType.ADD:
|
||||
result = _apply_add(op, file_ops)
|
||||
if result[0]:
|
||||
files_created.append(op.file_path)
|
||||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to add {op.file_path}: {result[1]}")
|
||||
|
||||
elif op.operation == OperationType.DELETE:
|
||||
result = _apply_delete(op, file_ops)
|
||||
if result[0]:
|
||||
files_deleted.append(op.file_path)
|
||||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to delete {op.file_path}: {result[1]}")
|
||||
|
||||
elif op.operation == OperationType.MOVE:
|
||||
result = _apply_move(op, file_ops)
|
||||
if result[0]:
|
||||
files_modified.append(f"{op.file_path} -> {op.new_path}")
|
||||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to move {op.file_path}: {result[1]}")
|
||||
|
||||
elif op.operation == OperationType.UPDATE:
|
||||
result = _apply_update(op, file_ops)
|
||||
if result[0]:
|
||||
files_modified.append(op.file_path)
|
||||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to update {op.file_path}: {result[1]}")
|
||||
|
||||
except Exception as e:
|
||||
errors.append(f"Error processing {op.file_path}: {str(e)}")
|
||||
|
||||
# Run lint on all modified/created files
|
||||
lint_results = {}
|
||||
for f in files_modified + files_created:
|
||||
if hasattr(file_ops, '_check_lint'):
|
||||
lint_result = file_ops._check_lint(f)
|
||||
lint_results[f] = lint_result.to_dict()
|
||||
|
||||
combined_diff = '\n'.join(all_diffs)
|
||||
|
||||
if errors:
|
||||
return PatchResult(
|
||||
success=False,
|
||||
diff=combined_diff,
|
||||
files_modified=files_modified,
|
||||
files_created=files_created,
|
||||
files_deleted=files_deleted,
|
||||
lint=lint_results if lint_results else None,
|
||||
error='; '.join(errors)
|
||||
)
|
||||
|
||||
return PatchResult(
|
||||
success=True,
|
||||
diff=combined_diff,
|
||||
files_modified=files_modified,
|
||||
files_created=files_created,
|
||||
files_deleted=files_deleted,
|
||||
lint=lint_results if lint_results else None
|
||||
)
|
||||
|
||||
|
||||
def _apply_add(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
"""Apply an add file operation."""
|
||||
# Extract content from hunks (all + lines)
|
||||
content_lines = []
|
||||
for hunk in op.hunks:
|
||||
for line in hunk.lines:
|
||||
if line.prefix == '+':
|
||||
content_lines.append(line.content)
|
||||
|
||||
content = '\n'.join(content_lines)
|
||||
|
||||
result = file_ops.write_file(op.file_path, content)
|
||||
if result.error:
|
||||
return False, result.error
|
||||
|
||||
diff = f"--- /dev/null\n+++ b/{op.file_path}\n"
|
||||
diff += '\n'.join(f"+{line}" for line in content_lines)
|
||||
|
||||
return True, diff
|
||||
|
||||
|
||||
def _apply_delete(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
"""Apply a delete file operation."""
|
||||
# Read file first for diff
|
||||
read_result = file_ops.read_file(op.file_path)
|
||||
|
||||
if read_result.error and "not found" in read_result.error.lower():
|
||||
# File doesn't exist, nothing to delete
|
||||
return True, f"# {op.file_path} already deleted or doesn't exist"
|
||||
|
||||
# Delete directly via shell command using the underlying environment
|
||||
rm_result = file_ops._exec(f"rm -f {file_ops._escape_shell_arg(op.file_path)}")
|
||||
|
||||
if rm_result.exit_code != 0:
|
||||
return False, rm_result.stdout
|
||||
|
||||
diff = f"--- a/{op.file_path}\n+++ /dev/null\n# File deleted"
|
||||
return True, diff
|
||||
|
||||
|
||||
def _apply_move(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
"""Apply a move file operation."""
|
||||
# Use shell mv command
|
||||
mv_result = file_ops._exec(
|
||||
f"mv {file_ops._escape_shell_arg(op.file_path)} {file_ops._escape_shell_arg(op.new_path)}"
|
||||
)
|
||||
|
||||
if mv_result.exit_code != 0:
|
||||
return False, mv_result.stdout
|
||||
|
||||
diff = f"# Moved: {op.file_path} -> {op.new_path}"
|
||||
return True, diff
|
||||
|
||||
|
||||
def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
"""Apply an update file operation."""
|
||||
# Read current content
|
||||
read_result = file_ops.read_file(op.file_path, limit=10000)
|
||||
|
||||
if read_result.error:
|
||||
return False, f"Cannot read file: {read_result.error}"
|
||||
|
||||
# Parse content (remove line numbers)
|
||||
current_lines = []
|
||||
for line in read_result.content.split('\n'):
|
||||
if re.match(r'^\s*\d+\|', line):
|
||||
# Line format: " 123|content"
|
||||
parts = line.split('|', 1)
|
||||
if len(parts) == 2:
|
||||
current_lines.append(parts[1])
|
||||
else:
|
||||
current_lines.append(line)
|
||||
else:
|
||||
current_lines.append(line)
|
||||
|
||||
current_content = '\n'.join(current_lines)
|
||||
|
||||
# Apply each hunk
|
||||
new_content = current_content
|
||||
|
||||
for hunk in op.hunks:
|
||||
# Build search pattern from context and removed lines
|
||||
search_lines = []
|
||||
replace_lines = []
|
||||
|
||||
for line in hunk.lines:
|
||||
if line.prefix == ' ':
|
||||
search_lines.append(line.content)
|
||||
replace_lines.append(line.content)
|
||||
elif line.prefix == '-':
|
||||
search_lines.append(line.content)
|
||||
elif line.prefix == '+':
|
||||
replace_lines.append(line.content)
|
||||
|
||||
if search_lines:
|
||||
search_pattern = '\n'.join(search_lines)
|
||||
replacement = '\n'.join(replace_lines)
|
||||
|
||||
# Use fuzzy matching
|
||||
from tools.fuzzy_match import fuzzy_find_and_replace
|
||||
new_content, count, error = fuzzy_find_and_replace(
|
||||
new_content, search_pattern, replacement, replace_all=False
|
||||
)
|
||||
|
||||
if error and count == 0:
|
||||
# Try with context hint if available
|
||||
if hunk.context_hint:
|
||||
# Find the context hint location and search nearby
|
||||
hint_pos = new_content.find(hunk.context_hint)
|
||||
if hint_pos != -1:
|
||||
# Search in a window around the hint
|
||||
window_start = max(0, hint_pos - 500)
|
||||
window_end = min(len(new_content), hint_pos + 2000)
|
||||
window = new_content[window_start:window_end]
|
||||
|
||||
window_new, count, error = fuzzy_find_and_replace(
|
||||
window, search_pattern, replacement, replace_all=False
|
||||
)
|
||||
|
||||
if count > 0:
|
||||
new_content = new_content[:window_start] + window_new + new_content[window_end:]
|
||||
error = None
|
||||
|
||||
if error:
|
||||
return False, f"Could not apply hunk: {error}"
|
||||
|
||||
# Write new content
|
||||
write_result = file_ops.write_file(op.file_path, new_content)
|
||||
if write_result.error:
|
||||
return False, write_result.error
|
||||
|
||||
# Generate diff
|
||||
import difflib
|
||||
diff_lines = difflib.unified_diff(
|
||||
current_content.splitlines(keepends=True),
|
||||
new_content.splitlines(keepends=True),
|
||||
fromfile=f"a/{op.file_path}",
|
||||
tofile=f"b/{op.file_path}"
|
||||
)
|
||||
diff = ''.join(diff_lines)
|
||||
|
||||
return True, diff
|
||||
891
hermes_code/tools/process_registry.py
Normal file
891
hermes_code/tools/process_registry.py
Normal file
|
|
@ -0,0 +1,891 @@
|
|||
"""
|
||||
Process Registry -- In-memory registry for managed background processes.
|
||||
|
||||
Tracks processes spawned via terminal(background=true), providing:
|
||||
- Output buffering (rolling 200KB window)
|
||||
- Status polling and log retrieval
|
||||
- Blocking wait with interrupt support
|
||||
- Process killing
|
||||
- Crash recovery via JSON checkpoint file
|
||||
- Session-scoped tracking for gateway reset protection
|
||||
|
||||
Background processes execute THROUGH the environment interface -- nothing
|
||||
runs on the host machine unless TERMINAL_ENV=local. For Docker, Singularity,
|
||||
Modal, Daytona, and SSH backends, the command runs inside the sandbox.
|
||||
|
||||
Usage:
|
||||
from tools.process_registry import process_registry
|
||||
|
||||
# Spawn a background process (called from terminal_tool)
|
||||
session = process_registry.spawn(env, "pytest -v", task_id="task_123")
|
||||
|
||||
# Poll for status
|
||||
result = process_registry.poll(session.id)
|
||||
|
||||
# Block until done
|
||||
result = process_registry.wait(session.id, timeout=300)
|
||||
|
||||
# Kill it
|
||||
process_registry.kill(session.id)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import shlex
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
from tools.environments.local import _find_shell, _sanitize_subprocess_env
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from hermes_cli.config import get_hermes_home
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Checkpoint file for crash recovery (gateway only)
|
||||
CHECKPOINT_PATH = get_hermes_home() / "processes.json"
|
||||
|
||||
# Limits
|
||||
MAX_OUTPUT_CHARS = 200_000 # 200KB rolling output buffer
|
||||
FINISHED_TTL_SECONDS = 1800 # Keep finished processes for 30 minutes
|
||||
MAX_PROCESSES = 64 # Max concurrent tracked processes (LRU pruning)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessSession:
|
||||
"""A tracked background process with output buffering."""
|
||||
id: str # Unique session ID ("proc_xxxxxxxxxxxx")
|
||||
command: str # Original command string
|
||||
task_id: str = "" # Task/sandbox isolation key
|
||||
session_key: str = "" # Gateway session key (for reset protection)
|
||||
pid: Optional[int] = None # OS process ID
|
||||
process: Optional[subprocess.Popen] = None # Popen handle (local only)
|
||||
env_ref: Any = None # Reference to the environment object
|
||||
cwd: Optional[str] = None # Working directory
|
||||
started_at: float = 0.0 # time.time() of spawn
|
||||
exited: bool = False # Whether the process has finished
|
||||
exit_code: Optional[int] = None # Exit code (None if still running)
|
||||
output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS)
|
||||
max_output_chars: int = MAX_OUTPUT_CHARS
|
||||
detached: bool = False # True if recovered from crash (no pipe)
|
||||
# Watcher/notification metadata (persisted for crash recovery)
|
||||
watcher_platform: str = ""
|
||||
watcher_chat_id: str = ""
|
||||
watcher_thread_id: str = ""
|
||||
watcher_interval: int = 0 # 0 = no watcher configured
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
_reader_thread: Optional[threading.Thread] = field(default=None, repr=False)
|
||||
_pty: Any = field(default=None, repr=False) # ptyprocess handle (when use_pty=True)
|
||||
|
||||
|
||||
class ProcessRegistry:
|
||||
"""
|
||||
In-memory registry of running and finished background processes.
|
||||
|
||||
Thread-safe. Accessed from:
|
||||
- Executor threads (terminal_tool, process tool handlers)
|
||||
- Gateway asyncio loop (watcher tasks, session reset checks)
|
||||
- Cleanup thread (sandbox reaping coordination)
|
||||
"""
|
||||
|
||||
_SHELL_NOISE_SUBSTRINGS = (
|
||||
"bash: cannot set terminal process group",
|
||||
"bash: no job control in this shell",
|
||||
"no job control in this shell",
|
||||
"cannot set terminal process group",
|
||||
"tcsetattr: Inappropriate ioctl for device",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
self._running: Dict[str, ProcessSession] = {}
|
||||
self._finished: Dict[str, ProcessSession] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Side-channel for check_interval watchers (gateway reads after agent run)
|
||||
self.pending_watchers: List[Dict[str, Any]] = []
|
||||
|
||||
@staticmethod
|
||||
def _clean_shell_noise(text: str) -> str:
|
||||
"""Strip shell startup warnings from the beginning of output."""
|
||||
lines = text.split("\n")
|
||||
while lines and any(noise in lines[0] for noise in ProcessRegistry._SHELL_NOISE_SUBSTRINGS):
|
||||
lines.pop(0)
|
||||
return "\n".join(lines)
|
||||
|
||||
# ----- Spawn -----
|
||||
|
||||
def spawn_local(
|
||||
self,
|
||||
command: str,
|
||||
cwd: str = None,
|
||||
task_id: str = "",
|
||||
session_key: str = "",
|
||||
env_vars: dict = None,
|
||||
use_pty: bool = False,
|
||||
) -> ProcessSession:
|
||||
"""
|
||||
Spawn a background process locally.
|
||||
|
||||
Only for TERMINAL_ENV=local. Other backends use spawn_via_env().
|
||||
|
||||
Args:
|
||||
use_pty: If True, use a pseudo-terminal via ptyprocess for interactive
|
||||
CLI tools (Codex, Claude Code, Python REPL). Falls back to
|
||||
subprocess.Popen if ptyprocess is not installed.
|
||||
"""
|
||||
session = ProcessSession(
|
||||
id=f"proc_{uuid.uuid4().hex[:12]}",
|
||||
command=command,
|
||||
task_id=task_id,
|
||||
session_key=session_key,
|
||||
cwd=cwd or os.getcwd(),
|
||||
started_at=time.time(),
|
||||
)
|
||||
|
||||
if use_pty:
|
||||
# Try PTY mode for interactive CLI tools
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
from winpty import PtyProcess as _PtyProcessCls
|
||||
else:
|
||||
from ptyprocess import PtyProcess as _PtyProcessCls
|
||||
user_shell = _find_shell()
|
||||
pty_env = _sanitize_subprocess_env(os.environ, env_vars)
|
||||
pty_env["PYTHONUNBUFFERED"] = "1"
|
||||
pty_proc = _PtyProcessCls.spawn(
|
||||
[user_shell, "-lic", command],
|
||||
cwd=session.cwd,
|
||||
env=pty_env,
|
||||
dimensions=(30, 120),
|
||||
)
|
||||
session.pid = pty_proc.pid
|
||||
# Store the pty handle on the session for read/write
|
||||
session._pty = pty_proc
|
||||
|
||||
# PTY reader thread
|
||||
reader = threading.Thread(
|
||||
target=self._pty_reader_loop,
|
||||
args=(session,),
|
||||
daemon=True,
|
||||
name=f"proc-pty-reader-{session.id}",
|
||||
)
|
||||
session._reader_thread = reader
|
||||
reader.start()
|
||||
|
||||
with self._lock:
|
||||
self._prune_if_needed()
|
||||
self._running[session.id] = session
|
||||
|
||||
self._write_checkpoint()
|
||||
return session
|
||||
|
||||
except ImportError:
|
||||
logger.warning("ptyprocess not installed, falling back to pipe mode")
|
||||
except Exception as e:
|
||||
logger.warning("PTY spawn failed (%s), falling back to pipe mode", e)
|
||||
|
||||
# Standard Popen path (non-PTY or PTY fallback)
|
||||
# Use the user's login shell for consistency with LocalEnvironment --
|
||||
# ensures rc files are sourced and user tools are available.
|
||||
user_shell = _find_shell()
|
||||
# Force unbuffered output for Python scripts so progress is visible
|
||||
# during background execution (libraries like tqdm/datasets buffer when
|
||||
# stdout is a pipe, hiding output from process(action="poll")).
|
||||
bg_env = _sanitize_subprocess_env(os.environ, env_vars)
|
||||
bg_env["PYTHONUNBUFFERED"] = "1"
|
||||
proc = subprocess.Popen(
|
||||
[user_shell, "-lic", command],
|
||||
text=True,
|
||||
cwd=session.cwd,
|
||||
env=bg_env,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE,
|
||||
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
||||
)
|
||||
|
||||
session.process = proc
|
||||
session.pid = proc.pid
|
||||
|
||||
# Start output reader thread
|
||||
reader = threading.Thread(
|
||||
target=self._reader_loop,
|
||||
args=(session,),
|
||||
daemon=True,
|
||||
name=f"proc-reader-{session.id}",
|
||||
)
|
||||
session._reader_thread = reader
|
||||
reader.start()
|
||||
|
||||
with self._lock:
|
||||
self._prune_if_needed()
|
||||
self._running[session.id] = session
|
||||
|
||||
self._write_checkpoint()
|
||||
return session
|
||||
|
||||
def spawn_via_env(
|
||||
self,
|
||||
env: Any,
|
||||
command: str,
|
||||
cwd: str = None,
|
||||
task_id: str = "",
|
||||
session_key: str = "",
|
||||
timeout: int = 10,
|
||||
) -> ProcessSession:
|
||||
"""
|
||||
Spawn a background process through a non-local environment backend.
|
||||
|
||||
For Docker/Singularity/Modal/Daytona/SSH: runs the command inside the sandbox
|
||||
using the environment's execute() interface. We wrap the command to
|
||||
capture the in-sandbox PID and redirect output to a log file inside
|
||||
the sandbox, then poll the log via subsequent execute() calls.
|
||||
|
||||
This is less capable than local spawn (no live stdout pipe, no stdin),
|
||||
but it ensures the command runs in the correct sandbox context.
|
||||
"""
|
||||
session = ProcessSession(
|
||||
id=f"proc_{uuid.uuid4().hex[:12]}",
|
||||
command=command,
|
||||
task_id=task_id,
|
||||
session_key=session_key,
|
||||
cwd=cwd,
|
||||
started_at=time.time(),
|
||||
env_ref=env,
|
||||
)
|
||||
|
||||
# Run the command in the sandbox with output capture
|
||||
log_path = f"/tmp/hermes_bg_{session.id}.log"
|
||||
pid_path = f"/tmp/hermes_bg_{session.id}.pid"
|
||||
quoted_command = shlex.quote(command)
|
||||
bg_command = (
|
||||
f"nohup bash -c {quoted_command} > {log_path} 2>&1 & "
|
||||
f"echo $! > {pid_path} && cat {pid_path}"
|
||||
)
|
||||
|
||||
try:
|
||||
result = env.execute(bg_command, timeout=timeout)
|
||||
output = result.get("output", "").strip()
|
||||
# Try to extract the PID from the output
|
||||
for line in output.splitlines():
|
||||
line = line.strip()
|
||||
if line.isdigit():
|
||||
session.pid = int(line)
|
||||
break
|
||||
except Exception as e:
|
||||
session.exited = True
|
||||
session.exit_code = -1
|
||||
session.output_buffer = f"Failed to start: {e}"
|
||||
|
||||
if not session.exited:
|
||||
# Start a poller thread that periodically reads the log file
|
||||
reader = threading.Thread(
|
||||
target=self._env_poller_loop,
|
||||
args=(session, env, log_path, pid_path),
|
||||
daemon=True,
|
||||
name=f"proc-poller-{session.id}",
|
||||
)
|
||||
session._reader_thread = reader
|
||||
reader.start()
|
||||
|
||||
with self._lock:
|
||||
self._prune_if_needed()
|
||||
self._running[session.id] = session
|
||||
|
||||
self._write_checkpoint()
|
||||
return session
|
||||
|
||||
# ----- Reader / Poller Threads -----
|
||||
|
||||
def _reader_loop(self, session: ProcessSession):
|
||||
"""Background thread: read stdout from a local Popen process."""
|
||||
first_chunk = True
|
||||
try:
|
||||
while True:
|
||||
chunk = session.process.stdout.read(4096)
|
||||
if not chunk:
|
||||
break
|
||||
if first_chunk:
|
||||
chunk = self._clean_shell_noise(chunk)
|
||||
first_chunk = False
|
||||
with session._lock:
|
||||
session.output_buffer += chunk
|
||||
if len(session.output_buffer) > session.max_output_chars:
|
||||
session.output_buffer = session.output_buffer[-session.max_output_chars:]
|
||||
except Exception as e:
|
||||
logger.debug("Process stdout reader ended: %s", e)
|
||||
|
||||
# Process exited
|
||||
try:
|
||||
session.process.wait(timeout=5)
|
||||
except Exception as e:
|
||||
logger.debug("Process wait timed out or failed: %s", e)
|
||||
session.exited = True
|
||||
session.exit_code = session.process.returncode
|
||||
self._move_to_finished(session)
|
||||
|
||||
def _env_poller_loop(
|
||||
self, session: ProcessSession, env: Any, log_path: str, pid_path: str
|
||||
):
|
||||
"""Background thread: poll a sandbox log file for non-local backends."""
|
||||
while not session.exited:
|
||||
time.sleep(2) # Poll every 2 seconds
|
||||
try:
|
||||
# Read new output from the log file
|
||||
result = env.execute(f"cat {log_path} 2>/dev/null", timeout=10)
|
||||
new_output = result.get("output", "")
|
||||
if new_output:
|
||||
with session._lock:
|
||||
session.output_buffer = new_output
|
||||
if len(session.output_buffer) > session.max_output_chars:
|
||||
session.output_buffer = session.output_buffer[-session.max_output_chars:]
|
||||
|
||||
# Check if process is still running
|
||||
check = env.execute(
|
||||
f"kill -0 $(cat {pid_path} 2>/dev/null) 2>/dev/null; echo $?",
|
||||
timeout=5,
|
||||
)
|
||||
check_output = check.get("output", "").strip()
|
||||
if check_output and check_output.splitlines()[-1].strip() != "0":
|
||||
# Process has exited -- get exit code
|
||||
exit_result = env.execute(
|
||||
f"wait $(cat {pid_path} 2>/dev/null) 2>/dev/null; echo $?",
|
||||
timeout=5,
|
||||
)
|
||||
exit_str = exit_result.get("output", "").strip()
|
||||
try:
|
||||
session.exit_code = int(exit_str.splitlines()[-1].strip())
|
||||
except (ValueError, IndexError):
|
||||
session.exit_code = -1
|
||||
session.exited = True
|
||||
self._move_to_finished(session)
|
||||
return
|
||||
|
||||
except Exception:
|
||||
# Environment might be gone (sandbox reaped, etc.)
|
||||
session.exited = True
|
||||
session.exit_code = -1
|
||||
self._move_to_finished(session)
|
||||
return
|
||||
|
||||
def _pty_reader_loop(self, session: ProcessSession):
|
||||
"""Background thread: read output from a PTY process."""
|
||||
pty = session._pty
|
||||
try:
|
||||
while pty.isalive():
|
||||
try:
|
||||
chunk = pty.read(4096)
|
||||
if chunk:
|
||||
# ptyprocess returns bytes
|
||||
text = chunk if isinstance(chunk, str) else chunk.decode("utf-8", errors="replace")
|
||||
with session._lock:
|
||||
session.output_buffer += text
|
||||
if len(session.output_buffer) > session.max_output_chars:
|
||||
session.output_buffer = session.output_buffer[-session.max_output_chars:]
|
||||
except EOFError:
|
||||
break
|
||||
except Exception:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.debug("PTY stdout reader ended: %s", e)
|
||||
|
||||
# Process exited
|
||||
try:
|
||||
pty.wait()
|
||||
except Exception as e:
|
||||
logger.debug("PTY wait timed out or failed: %s", e)
|
||||
session.exited = True
|
||||
session.exit_code = pty.exitstatus if hasattr(pty, 'exitstatus') else -1
|
||||
self._move_to_finished(session)
|
||||
|
||||
def _move_to_finished(self, session: ProcessSession):
|
||||
"""Move a session from running to finished."""
|
||||
with self._lock:
|
||||
self._running.pop(session.id, None)
|
||||
self._finished[session.id] = session
|
||||
self._write_checkpoint()
|
||||
|
||||
# ----- Query Methods -----
|
||||
|
||||
def get(self, session_id: str) -> Optional[ProcessSession]:
|
||||
"""Get a session by ID (running or finished)."""
|
||||
with self._lock:
|
||||
return self._running.get(session_id) or self._finished.get(session_id)
|
||||
|
||||
def poll(self, session_id: str) -> dict:
|
||||
"""Check status and get new output for a background process."""
|
||||
from tools.ansi_strip import strip_ansi
|
||||
|
||||
session = self.get(session_id)
|
||||
if session is None:
|
||||
return {"status": "not_found", "error": f"No process with ID {session_id}"}
|
||||
|
||||
with session._lock:
|
||||
output_preview = strip_ansi(session.output_buffer[-1000:]) if session.output_buffer else ""
|
||||
|
||||
result = {
|
||||
"session_id": session.id,
|
||||
"command": session.command,
|
||||
"status": "exited" if session.exited else "running",
|
||||
"pid": session.pid,
|
||||
"uptime_seconds": int(time.time() - session.started_at),
|
||||
"output_preview": output_preview,
|
||||
}
|
||||
if session.exited:
|
||||
result["exit_code"] = session.exit_code
|
||||
if session.detached:
|
||||
result["detached"] = True
|
||||
result["note"] = "Process recovered after restart -- output history unavailable"
|
||||
return result
|
||||
|
||||
def read_log(self, session_id: str, offset: int = 0, limit: int = 200) -> dict:
|
||||
"""Read the full output log with optional pagination by lines."""
|
||||
from tools.ansi_strip import strip_ansi
|
||||
|
||||
session = self.get(session_id)
|
||||
if session is None:
|
||||
return {"status": "not_found", "error": f"No process with ID {session_id}"}
|
||||
|
||||
with session._lock:
|
||||
full_output = strip_ansi(session.output_buffer)
|
||||
|
||||
lines = full_output.splitlines()
|
||||
total_lines = len(lines)
|
||||
|
||||
# Default: last N lines
|
||||
if offset == 0 and limit > 0:
|
||||
selected = lines[-limit:]
|
||||
else:
|
||||
selected = lines[offset:offset + limit]
|
||||
|
||||
return {
|
||||
"session_id": session.id,
|
||||
"status": "exited" if session.exited else "running",
|
||||
"output": "\n".join(selected),
|
||||
"total_lines": total_lines,
|
||||
"showing": f"{len(selected)} lines",
|
||||
}
|
||||
|
||||
def wait(self, session_id: str, timeout: int = None) -> dict:
|
||||
"""
|
||||
Block until a process exits, timeout, or interrupt.
|
||||
|
||||
Args:
|
||||
session_id: The process to wait for.
|
||||
timeout: Max seconds to block. Falls back to TERMINAL_TIMEOUT config.
|
||||
|
||||
Returns:
|
||||
dict with status ("exited", "timeout", "interrupted", "not_found")
|
||||
and output snapshot.
|
||||
"""
|
||||
from tools.ansi_strip import strip_ansi
|
||||
from tools.terminal_tool import _interrupt_event
|
||||
|
||||
default_timeout = int(os.getenv("TERMINAL_TIMEOUT", "180"))
|
||||
max_timeout = default_timeout
|
||||
requested_timeout = timeout
|
||||
timeout_note = None
|
||||
|
||||
if requested_timeout and requested_timeout > max_timeout:
|
||||
effective_timeout = max_timeout
|
||||
timeout_note = (
|
||||
f"Requested wait of {requested_timeout}s was clamped "
|
||||
f"to configured limit of {max_timeout}s"
|
||||
)
|
||||
else:
|
||||
effective_timeout = requested_timeout or max_timeout
|
||||
|
||||
session = self.get(session_id)
|
||||
if session is None:
|
||||
return {"status": "not_found", "error": f"No process with ID {session_id}"}
|
||||
|
||||
deadline = time.monotonic() + effective_timeout
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
if session.exited:
|
||||
result = {
|
||||
"status": "exited",
|
||||
"exit_code": session.exit_code,
|
||||
"output": strip_ansi(session.output_buffer[-2000:]),
|
||||
}
|
||||
if timeout_note:
|
||||
result["timeout_note"] = timeout_note
|
||||
return result
|
||||
|
||||
if _interrupt_event.is_set():
|
||||
result = {
|
||||
"status": "interrupted",
|
||||
"output": strip_ansi(session.output_buffer[-1000:]),
|
||||
"note": "User sent a new message -- wait interrupted",
|
||||
}
|
||||
if timeout_note:
|
||||
result["timeout_note"] = timeout_note
|
||||
return result
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
result = {
|
||||
"status": "timeout",
|
||||
"output": strip_ansi(session.output_buffer[-1000:]),
|
||||
}
|
||||
if timeout_note:
|
||||
result["timeout_note"] = timeout_note
|
||||
else:
|
||||
result["timeout_note"] = f"Waited {effective_timeout}s, process still running"
|
||||
return result
|
||||
|
||||
def kill_process(self, session_id: str) -> dict:
|
||||
"""Kill a background process."""
|
||||
session = self.get(session_id)
|
||||
if session is None:
|
||||
return {"status": "not_found", "error": f"No process with ID {session_id}"}
|
||||
|
||||
if session.exited:
|
||||
return {
|
||||
"status": "already_exited",
|
||||
"exit_code": session.exit_code,
|
||||
}
|
||||
|
||||
# Kill via PTY, Popen (local), or env execute (non-local)
|
||||
try:
|
||||
if session._pty:
|
||||
# PTY process -- terminate via ptyprocess
|
||||
try:
|
||||
session._pty.terminate(force=True)
|
||||
except Exception:
|
||||
if session.pid:
|
||||
os.kill(session.pid, signal.SIGTERM)
|
||||
elif session.process:
|
||||
# Local process -- kill the process group
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
session.process.terminate()
|
||||
else:
|
||||
os.killpg(os.getpgid(session.process.pid), signal.SIGTERM)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
session.process.kill()
|
||||
elif session.env_ref and session.pid:
|
||||
# Non-local -- kill inside sandbox
|
||||
session.env_ref.execute(f"kill {session.pid} 2>/dev/null", timeout=5)
|
||||
session.exited = True
|
||||
session.exit_code = -15 # SIGTERM
|
||||
self._move_to_finished(session)
|
||||
self._write_checkpoint()
|
||||
return {"status": "killed", "session_id": session.id}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def write_stdin(self, session_id: str, data: str) -> dict:
|
||||
"""Send raw data to a running process's stdin (no newline appended)."""
|
||||
session = self.get(session_id)
|
||||
if session is None:
|
||||
return {"status": "not_found", "error": f"No process with ID {session_id}"}
|
||||
if session.exited:
|
||||
return {"status": "already_exited", "error": "Process has already finished"}
|
||||
|
||||
# PTY mode -- write through pty handle (expects bytes)
|
||||
if hasattr(session, '_pty') and session._pty:
|
||||
try:
|
||||
pty_data = data.encode("utf-8") if isinstance(data, str) else data
|
||||
session._pty.write(pty_data)
|
||||
return {"status": "ok", "bytes_written": len(data)}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
# Popen mode -- write through stdin pipe
|
||||
if not session.process or not session.process.stdin:
|
||||
return {"status": "error", "error": "Process stdin not available (non-local backend or stdin closed)"}
|
||||
try:
|
||||
session.process.stdin.write(data)
|
||||
session.process.stdin.flush()
|
||||
return {"status": "ok", "bytes_written": len(data)}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def submit_stdin(self, session_id: str, data: str = "") -> dict:
|
||||
"""Send data + newline to a running process's stdin (like pressing Enter)."""
|
||||
return self.write_stdin(session_id, data + "\n")
|
||||
|
||||
def list_sessions(self, task_id: str = None) -> list:
|
||||
"""List all running and recently-finished processes."""
|
||||
with self._lock:
|
||||
all_sessions = list(self._running.values()) + list(self._finished.values())
|
||||
|
||||
if task_id:
|
||||
all_sessions = [s for s in all_sessions if s.task_id == task_id]
|
||||
|
||||
result = []
|
||||
for s in all_sessions:
|
||||
entry = {
|
||||
"session_id": s.id,
|
||||
"command": s.command[:200],
|
||||
"cwd": s.cwd,
|
||||
"pid": s.pid,
|
||||
"started_at": time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime(s.started_at)),
|
||||
"uptime_seconds": int(time.time() - s.started_at),
|
||||
"status": "exited" if s.exited else "running",
|
||||
"output_preview": s.output_buffer[-200:] if s.output_buffer else "",
|
||||
}
|
||||
if s.exited:
|
||||
entry["exit_code"] = s.exit_code
|
||||
if s.detached:
|
||||
entry["detached"] = True
|
||||
result.append(entry)
|
||||
return result
|
||||
|
||||
# ----- Session/Task Queries (for gateway integration) -----
|
||||
|
||||
def has_active_processes(self, task_id: str) -> bool:
|
||||
"""Check if there are active (running) processes for a task_id."""
|
||||
with self._lock:
|
||||
return any(
|
||||
s.task_id == task_id and not s.exited
|
||||
for s in self._running.values()
|
||||
)
|
||||
|
||||
def has_active_for_session(self, session_key: str) -> bool:
|
||||
"""Check if there are active processes for a gateway session key."""
|
||||
with self._lock:
|
||||
return any(
|
||||
s.session_key == session_key and not s.exited
|
||||
for s in self._running.values()
|
||||
)
|
||||
|
||||
def kill_all(self, task_id: str = None) -> int:
|
||||
"""Kill all running processes, optionally filtered by task_id. Returns count killed."""
|
||||
with self._lock:
|
||||
targets = [
|
||||
s for s in self._running.values()
|
||||
if (task_id is None or s.task_id == task_id) and not s.exited
|
||||
]
|
||||
|
||||
killed = 0
|
||||
for session in targets:
|
||||
result = self.kill_process(session.id)
|
||||
if result.get("status") in ("killed", "already_exited"):
|
||||
killed += 1
|
||||
return killed
|
||||
|
||||
# ----- Cleanup / Pruning -----
|
||||
|
||||
def _prune_if_needed(self):
|
||||
"""Remove oldest finished sessions if over MAX_PROCESSES. Must hold _lock."""
|
||||
# First prune expired finished sessions
|
||||
now = time.time()
|
||||
expired = [
|
||||
sid for sid, s in self._finished.items()
|
||||
if (now - s.started_at) > FINISHED_TTL_SECONDS
|
||||
]
|
||||
for sid in expired:
|
||||
del self._finished[sid]
|
||||
|
||||
# If still over limit, remove oldest finished
|
||||
total = len(self._running) + len(self._finished)
|
||||
if total >= MAX_PROCESSES and self._finished:
|
||||
oldest_id = min(self._finished, key=lambda sid: self._finished[sid].started_at)
|
||||
del self._finished[oldest_id]
|
||||
|
||||
def cleanup_expired(self):
|
||||
"""Public method to prune expired finished sessions."""
|
||||
with self._lock:
|
||||
self._prune_if_needed()
|
||||
|
||||
# ----- Checkpoint (crash recovery) -----
|
||||
|
||||
def _write_checkpoint(self):
|
||||
"""Write running process metadata to checkpoint file atomically."""
|
||||
try:
|
||||
with self._lock:
|
||||
entries = []
|
||||
for s in self._running.values():
|
||||
if not s.exited:
|
||||
entries.append({
|
||||
"session_id": s.id,
|
||||
"command": s.command,
|
||||
"pid": s.pid,
|
||||
"cwd": s.cwd,
|
||||
"started_at": s.started_at,
|
||||
"task_id": s.task_id,
|
||||
"session_key": s.session_key,
|
||||
"watcher_platform": s.watcher_platform,
|
||||
"watcher_chat_id": s.watcher_chat_id,
|
||||
"watcher_thread_id": s.watcher_thread_id,
|
||||
"watcher_interval": s.watcher_interval,
|
||||
})
|
||||
|
||||
# Atomic write to avoid corruption on crash
|
||||
from utils import atomic_json_write
|
||||
atomic_json_write(CHECKPOINT_PATH, entries)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to write checkpoint file: %s", e, exc_info=True)
|
||||
|
||||
def recover_from_checkpoint(self) -> int:
|
||||
"""
|
||||
On gateway startup, probe PIDs from checkpoint file.
|
||||
|
||||
Returns the number of processes recovered as detached.
|
||||
"""
|
||||
if not CHECKPOINT_PATH.exists():
|
||||
return 0
|
||||
|
||||
try:
|
||||
entries = json.loads(CHECKPOINT_PATH.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
recovered = 0
|
||||
for entry in entries:
|
||||
pid = entry.get("pid")
|
||||
if not pid:
|
||||
continue
|
||||
|
||||
# Check if PID is still alive
|
||||
alive = False
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
alive = True
|
||||
except (ProcessLookupError, PermissionError):
|
||||
pass
|
||||
|
||||
if alive:
|
||||
session = ProcessSession(
|
||||
id=entry["session_id"],
|
||||
command=entry.get("command", "unknown"),
|
||||
task_id=entry.get("task_id", ""),
|
||||
session_key=entry.get("session_key", ""),
|
||||
pid=pid,
|
||||
cwd=entry.get("cwd"),
|
||||
started_at=entry.get("started_at", time.time()),
|
||||
detached=True, # Can't read output, but can report status + kill
|
||||
watcher_platform=entry.get("watcher_platform", ""),
|
||||
watcher_chat_id=entry.get("watcher_chat_id", ""),
|
||||
watcher_thread_id=entry.get("watcher_thread_id", ""),
|
||||
watcher_interval=entry.get("watcher_interval", 0),
|
||||
)
|
||||
with self._lock:
|
||||
self._running[session.id] = session
|
||||
recovered += 1
|
||||
logger.info("Recovered detached process: %s (pid=%d)", session.command[:60], pid)
|
||||
|
||||
# Re-enqueue watcher so gateway can resume notifications
|
||||
if session.watcher_interval > 0:
|
||||
self.pending_watchers.append({
|
||||
"session_id": session.id,
|
||||
"check_interval": session.watcher_interval,
|
||||
"session_key": session.session_key,
|
||||
"platform": session.watcher_platform,
|
||||
"chat_id": session.watcher_chat_id,
|
||||
"thread_id": session.watcher_thread_id,
|
||||
})
|
||||
|
||||
# Clear the checkpoint (will be rewritten as processes finish)
|
||||
try:
|
||||
from utils import atomic_json_write
|
||||
atomic_json_write(CHECKPOINT_PATH, [])
|
||||
except Exception as e:
|
||||
logger.debug("Could not clear checkpoint file: %s", e, exc_info=True)
|
||||
|
||||
return recovered
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
process_registry = ProcessRegistry()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry -- the "process" tool schema + handler
|
||||
# ---------------------------------------------------------------------------
|
||||
from tools.registry import registry
|
||||
|
||||
PROCESS_SCHEMA = {
|
||||
"name": "process",
|
||||
"description": (
|
||||
"Manage background processes started with terminal(background=true). "
|
||||
"Actions: 'list' (show all), 'poll' (check status + new output), "
|
||||
"'log' (full output with pagination), 'wait' (block until done or timeout), "
|
||||
"'kill' (terminate), 'write' (send raw stdin data without newline), "
|
||||
"'submit' (send data + Enter, for answering prompts)."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["list", "poll", "log", "wait", "kill", "write", "submit"],
|
||||
"description": "Action to perform on background processes"
|
||||
},
|
||||
"session_id": {
|
||||
"type": "string",
|
||||
"description": "Process session ID (from terminal background output). Required for all actions except 'list'."
|
||||
},
|
||||
"data": {
|
||||
"type": "string",
|
||||
"description": "Text to send to process stdin (for 'write' and 'submit' actions)"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Max seconds to block for 'wait' action. Returns partial output on timeout.",
|
||||
"minimum": 1
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line offset for 'log' action (default: last 200 lines)"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max lines to return for 'log' action",
|
||||
"minimum": 1
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _handle_process(args, **kw):
|
||||
import json as _json
|
||||
task_id = kw.get("task_id")
|
||||
action = args.get("action", "")
|
||||
# Coerce to string — some models send session_id as an integer
|
||||
session_id = str(args.get("session_id", "")) if args.get("session_id") is not None else ""
|
||||
|
||||
if action == "list":
|
||||
return _json.dumps({"processes": process_registry.list_sessions(task_id=task_id)}, ensure_ascii=False)
|
||||
elif action in ("poll", "log", "wait", "kill", "write", "submit"):
|
||||
if not session_id:
|
||||
return _json.dumps({"error": f"session_id is required for {action}"}, ensure_ascii=False)
|
||||
if action == "poll":
|
||||
return _json.dumps(process_registry.poll(session_id), ensure_ascii=False)
|
||||
elif action == "log":
|
||||
return _json.dumps(process_registry.read_log(
|
||||
session_id, offset=args.get("offset", 0), limit=args.get("limit", 200)), ensure_ascii=False)
|
||||
elif action == "wait":
|
||||
return _json.dumps(process_registry.wait(session_id, timeout=args.get("timeout")), ensure_ascii=False)
|
||||
elif action == "kill":
|
||||
return _json.dumps(process_registry.kill_process(session_id), ensure_ascii=False)
|
||||
elif action == "write":
|
||||
return _json.dumps(process_registry.write_stdin(session_id, str(args.get("data", ""))), ensure_ascii=False)
|
||||
elif action == "submit":
|
||||
return _json.dumps(process_registry.submit_stdin(session_id, str(args.get("data", ""))), ensure_ascii=False)
|
||||
return _json.dumps({"error": f"Unknown process action: {action}. Use: list, poll, log, wait, kill, write, submit"}, ensure_ascii=False)
|
||||
|
||||
|
||||
registry.register(
|
||||
name="process",
|
||||
toolset="terminal",
|
||||
schema=PROCESS_SCHEMA,
|
||||
handler=_handle_process,
|
||||
emoji="⚙️",
|
||||
)
|
||||
237
hermes_code/tools/registry.py
Normal file
237
hermes_code/tools/registry.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
"""Central registry for all hermes-agent tools.
|
||||
|
||||
Each tool file calls ``registry.register()`` at module level to declare its
|
||||
schema, handler, toolset membership, and availability check. ``model_tools.py``
|
||||
queries the registry instead of maintaining its own parallel data structures.
|
||||
|
||||
Import chain (circular-import safe):
|
||||
tools/registry.py (no imports from model_tools or tool files)
|
||||
^
|
||||
tools/*.py (import from tools.registry at module level)
|
||||
^
|
||||
model_tools.py (imports tools.registry + all tool modules)
|
||||
^
|
||||
run_agent.py, cli.py, batch_runner.py, etc.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolEntry:
|
||||
"""Metadata for a single registered tool."""
|
||||
|
||||
__slots__ = (
|
||||
"name", "toolset", "schema", "handler", "check_fn",
|
||||
"requires_env", "is_async", "description", "emoji",
|
||||
)
|
||||
|
||||
def __init__(self, name, toolset, schema, handler, check_fn,
|
||||
requires_env, is_async, description, emoji):
|
||||
self.name = name
|
||||
self.toolset = toolset
|
||||
self.schema = schema
|
||||
self.handler = handler
|
||||
self.check_fn = check_fn
|
||||
self.requires_env = requires_env
|
||||
self.is_async = is_async
|
||||
self.description = description
|
||||
self.emoji = emoji
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""Singleton registry that collects tool schemas + handlers from tool files."""
|
||||
|
||||
def __init__(self):
|
||||
self._tools: Dict[str, ToolEntry] = {}
|
||||
self._toolset_checks: Dict[str, Callable] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Registration
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
toolset: str,
|
||||
schema: dict,
|
||||
handler: Callable,
|
||||
check_fn: Callable = None,
|
||||
requires_env: list = None,
|
||||
is_async: bool = False,
|
||||
description: str = "",
|
||||
emoji: str = "",
|
||||
):
|
||||
"""Register a tool. Called at module-import time by each tool file."""
|
||||
self._tools[name] = ToolEntry(
|
||||
name=name,
|
||||
toolset=toolset,
|
||||
schema=schema,
|
||||
handler=handler,
|
||||
check_fn=check_fn,
|
||||
requires_env=requires_env or [],
|
||||
is_async=is_async,
|
||||
description=description or schema.get("description", ""),
|
||||
emoji=emoji,
|
||||
)
|
||||
if check_fn and toolset not in self._toolset_checks:
|
||||
self._toolset_checks[toolset] = check_fn
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Schema retrieval
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_definitions(self, tool_names: Set[str], quiet: bool = False) -> List[dict]:
|
||||
"""Return OpenAI-format tool schemas for the requested tool names.
|
||||
|
||||
Only tools whose ``check_fn()`` returns True (or have no check_fn)
|
||||
are included.
|
||||
"""
|
||||
result = []
|
||||
for name in sorted(tool_names):
|
||||
entry = self._tools.get(name)
|
||||
if not entry:
|
||||
continue
|
||||
if entry.check_fn:
|
||||
try:
|
||||
if not entry.check_fn():
|
||||
if not quiet:
|
||||
logger.debug("Tool %s unavailable (check failed)", name)
|
||||
continue
|
||||
except Exception:
|
||||
if not quiet:
|
||||
logger.debug("Tool %s check raised; skipping", name)
|
||||
continue
|
||||
result.append({"type": "function", "function": entry.schema})
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Dispatch
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def dispatch(self, name: str, args: dict, **kwargs) -> str:
|
||||
"""Execute a tool handler by name.
|
||||
|
||||
* Async handlers are bridged automatically via ``_run_async()``.
|
||||
* All exceptions are caught and returned as ``{"error": "..."}``
|
||||
for consistent error format.
|
||||
"""
|
||||
entry = self._tools.get(name)
|
||||
if not entry:
|
||||
return json.dumps({"error": f"Unknown tool: {name}"})
|
||||
try:
|
||||
if entry.is_async:
|
||||
from model_tools import _run_async
|
||||
return _run_async(entry.handler(args, **kwargs))
|
||||
return entry.handler(args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception("Tool %s dispatch error: %s", name, e)
|
||||
return json.dumps({"error": f"Tool execution failed: {type(e).__name__}: {e}"})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Query helpers (replace redundant dicts in model_tools.py)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_all_tool_names(self) -> List[str]:
|
||||
"""Return sorted list of all registered tool names."""
|
||||
return sorted(self._tools.keys())
|
||||
|
||||
def get_toolset_for_tool(self, name: str) -> Optional[str]:
|
||||
"""Return the toolset a tool belongs to, or None."""
|
||||
entry = self._tools.get(name)
|
||||
return entry.toolset if entry else None
|
||||
|
||||
def get_emoji(self, name: str, default: str = "⚡") -> str:
|
||||
"""Return the emoji for a tool, or *default* if unset."""
|
||||
entry = self._tools.get(name)
|
||||
return (entry.emoji if entry and entry.emoji else default)
|
||||
|
||||
def get_tool_to_toolset_map(self) -> Dict[str, str]:
|
||||
"""Return ``{tool_name: toolset_name}`` for every registered tool."""
|
||||
return {name: e.toolset for name, e in self._tools.items()}
|
||||
|
||||
def is_toolset_available(self, toolset: str) -> bool:
|
||||
"""Check if a toolset's requirements are met.
|
||||
|
||||
Returns False (rather than crashing) when the check function raises
|
||||
an unexpected exception (e.g. network error, missing import, bad config).
|
||||
"""
|
||||
check = self._toolset_checks.get(toolset)
|
||||
if not check:
|
||||
return True
|
||||
try:
|
||||
return bool(check())
|
||||
except Exception:
|
||||
logger.debug("Toolset %s check raised; marking unavailable", toolset)
|
||||
return False
|
||||
|
||||
def check_toolset_requirements(self) -> Dict[str, bool]:
|
||||
"""Return ``{toolset: available_bool}`` for every toolset."""
|
||||
toolsets = set(e.toolset for e in self._tools.values())
|
||||
return {ts: self.is_toolset_available(ts) for ts in sorted(toolsets)}
|
||||
|
||||
def get_available_toolsets(self) -> Dict[str, dict]:
|
||||
"""Return toolset metadata for UI display."""
|
||||
toolsets: Dict[str, dict] = {}
|
||||
for entry in self._tools.values():
|
||||
ts = entry.toolset
|
||||
if ts not in toolsets:
|
||||
toolsets[ts] = {
|
||||
"available": self.is_toolset_available(ts),
|
||||
"tools": [],
|
||||
"description": "",
|
||||
"requirements": [],
|
||||
}
|
||||
toolsets[ts]["tools"].append(entry.name)
|
||||
if entry.requires_env:
|
||||
for env in entry.requires_env:
|
||||
if env not in toolsets[ts]["requirements"]:
|
||||
toolsets[ts]["requirements"].append(env)
|
||||
return toolsets
|
||||
|
||||
def get_toolset_requirements(self) -> Dict[str, dict]:
|
||||
"""Build a TOOLSET_REQUIREMENTS-compatible dict for backward compat."""
|
||||
result: Dict[str, dict] = {}
|
||||
for entry in self._tools.values():
|
||||
ts = entry.toolset
|
||||
if ts not in result:
|
||||
result[ts] = {
|
||||
"name": ts,
|
||||
"env_vars": [],
|
||||
"check_fn": self._toolset_checks.get(ts),
|
||||
"setup_url": None,
|
||||
"tools": [],
|
||||
}
|
||||
if entry.name not in result[ts]["tools"]:
|
||||
result[ts]["tools"].append(entry.name)
|
||||
for env in entry.requires_env:
|
||||
if env not in result[ts]["env_vars"]:
|
||||
result[ts]["env_vars"].append(env)
|
||||
return result
|
||||
|
||||
def check_tool_availability(self, quiet: bool = False):
|
||||
"""Return (available_toolsets, unavailable_info) like the old function."""
|
||||
available = []
|
||||
unavailable = []
|
||||
seen = set()
|
||||
for entry in self._tools.values():
|
||||
ts = entry.toolset
|
||||
if ts in seen:
|
||||
continue
|
||||
seen.add(ts)
|
||||
if self.is_toolset_available(ts):
|
||||
available.append(ts)
|
||||
else:
|
||||
unavailable.append({
|
||||
"name": ts,
|
||||
"env_vars": entry.requires_env,
|
||||
"tools": [e.name for e in self._tools.values() if e.toolset == ts],
|
||||
})
|
||||
return available, unavailable
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
registry = ToolRegistry()
|
||||
1400
hermes_code/tools/rl_training_tool.py
Normal file
1400
hermes_code/tools/rl_training_tool.py
Normal file
File diff suppressed because it is too large
Load diff
691
hermes_code/tools/send_message_tool.py
Normal file
691
hermes_code/tools/send_message_tool.py
Normal file
|
|
@ -0,0 +1,691 @@
|
|||
"""Send Message Tool -- cross-channel messaging via platform APIs.
|
||||
|
||||
Sends a message to a user or channel on any connected messaging platform
|
||||
(Telegram, Discord, Slack). Supports listing available targets and resolving
|
||||
human-friendly channel names to IDs. Works in both CLI and gateway contexts.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import ssl
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TELEGRAM_TOPIC_TARGET_RE = re.compile(r"^\s*(-?\d+)(?::(\d+))?\s*$")
|
||||
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
|
||||
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".3gp"}
|
||||
_AUDIO_EXTS = {".ogg", ".opus", ".mp3", ".wav", ".m4a"}
|
||||
_VOICE_EXTS = {".ogg", ".opus"}
|
||||
|
||||
|
||||
SEND_MESSAGE_SCHEMA = {
|
||||
"name": "send_message",
|
||||
"description": (
|
||||
"Send a message to a connected messaging platform, or list available targets.\n\n"
|
||||
"IMPORTANT: When the user asks to send to a specific channel or person "
|
||||
"(not just a bare platform name), call send_message(action='list') FIRST to see "
|
||||
"available targets, then send to the correct one.\n"
|
||||
"If the user just says a platform name like 'send to telegram', send directly "
|
||||
"to the home channel without listing first."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["send", "list"],
|
||||
"description": "Action to perform. 'send' (default) sends a message. 'list' returns all available channels/contacts across connected platforms."
|
||||
},
|
||||
"target": {
|
||||
"type": "string",
|
||||
"description": "Delivery target. Format: 'platform' (uses home channel), 'platform:#channel-name', 'platform:chat_id', or Telegram topic 'telegram:chat_id:thread_id'. Examples: 'telegram', 'telegram:-1001234567890:17585', 'discord:#bot-home', 'slack:#engineering', 'signal:+15551234567'"
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "The message text to send"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def send_message_tool(args, **kw):
|
||||
"""Handle cross-channel send_message tool calls."""
|
||||
action = args.get("action", "send")
|
||||
|
||||
if action == "list":
|
||||
return _handle_list()
|
||||
|
||||
return _handle_send(args)
|
||||
|
||||
|
||||
def _handle_list():
|
||||
"""Return formatted list of available messaging targets."""
|
||||
try:
|
||||
from gateway.channel_directory import format_directory_for_display
|
||||
return json.dumps({"targets": format_directory_for_display()})
|
||||
except Exception as e:
|
||||
return json.dumps({"error": f"Failed to load channel directory: {e}"})
|
||||
|
||||
|
||||
def _handle_send(args):
|
||||
"""Send a message to a platform target."""
|
||||
target = args.get("target", "")
|
||||
message = args.get("message", "")
|
||||
if not target or not message:
|
||||
return json.dumps({"error": "Both 'target' and 'message' are required when action='send'"})
|
||||
|
||||
parts = target.split(":", 1)
|
||||
platform_name = parts[0].strip().lower()
|
||||
target_ref = parts[1].strip() if len(parts) > 1 else None
|
||||
chat_id = None
|
||||
thread_id = None
|
||||
|
||||
if target_ref:
|
||||
chat_id, thread_id, is_explicit = _parse_target_ref(platform_name, target_ref)
|
||||
else:
|
||||
is_explicit = False
|
||||
|
||||
# Resolve human-friendly channel names to numeric IDs
|
||||
if target_ref and not is_explicit:
|
||||
try:
|
||||
from gateway.channel_directory import resolve_channel_name
|
||||
resolved = resolve_channel_name(platform_name, target_ref)
|
||||
if resolved:
|
||||
chat_id, thread_id, _ = _parse_target_ref(platform_name, resolved)
|
||||
else:
|
||||
return json.dumps({
|
||||
"error": f"Could not resolve '{target_ref}' on {platform_name}. "
|
||||
f"Use send_message(action='list') to see available targets."
|
||||
})
|
||||
except Exception:
|
||||
return json.dumps({
|
||||
"error": f"Could not resolve '{target_ref}' on {platform_name}. "
|
||||
f"Try using a numeric channel ID instead."
|
||||
})
|
||||
|
||||
from tools.interrupt import is_interrupted
|
||||
if is_interrupted():
|
||||
return json.dumps({"error": "Interrupted"})
|
||||
|
||||
try:
|
||||
from gateway.config import load_gateway_config, Platform
|
||||
config = load_gateway_config()
|
||||
except Exception as e:
|
||||
return json.dumps({"error": f"Failed to load gateway config: {e}"})
|
||||
|
||||
platform_map = {
|
||||
"telegram": Platform.TELEGRAM,
|
||||
"discord": Platform.DISCORD,
|
||||
"slack": Platform.SLACK,
|
||||
"whatsapp": Platform.WHATSAPP,
|
||||
"signal": Platform.SIGNAL,
|
||||
"matrix": Platform.MATRIX,
|
||||
"mattermost": Platform.MATTERMOST,
|
||||
"homeassistant": Platform.HOMEASSISTANT,
|
||||
"dingtalk": Platform.DINGTALK,
|
||||
"email": Platform.EMAIL,
|
||||
"sms": Platform.SMS,
|
||||
}
|
||||
platform = platform_map.get(platform_name)
|
||||
if not platform:
|
||||
avail = ", ".join(platform_map.keys())
|
||||
return json.dumps({"error": f"Unknown platform: {platform_name}. Available: {avail}"})
|
||||
|
||||
pconfig = config.platforms.get(platform)
|
||||
if not pconfig or not pconfig.enabled:
|
||||
return json.dumps({"error": f"Platform '{platform_name}' is not configured. Set up credentials in ~/.hermes/config.yaml or environment variables."})
|
||||
|
||||
from gateway.platforms.base import BasePlatformAdapter
|
||||
|
||||
media_files, cleaned_message = BasePlatformAdapter.extract_media(message)
|
||||
mirror_text = cleaned_message.strip() or _describe_media_for_mirror(media_files)
|
||||
|
||||
used_home_channel = False
|
||||
if not chat_id:
|
||||
home = config.get_home_channel(platform)
|
||||
if home:
|
||||
chat_id = home.chat_id
|
||||
used_home_channel = True
|
||||
else:
|
||||
return json.dumps({
|
||||
"error": f"No home channel set for {platform_name} to determine where to send the message. "
|
||||
f"Either specify a channel directly with '{platform_name}:CHANNEL_NAME', "
|
||||
f"or set a home channel via: hermes config set {platform_name.upper()}_HOME_CHANNEL <channel_id>"
|
||||
})
|
||||
|
||||
duplicate_skip = _maybe_skip_cron_duplicate_send(platform_name, chat_id, thread_id)
|
||||
if duplicate_skip:
|
||||
return json.dumps(duplicate_skip)
|
||||
|
||||
try:
|
||||
from model_tools import _run_async
|
||||
result = _run_async(
|
||||
_send_to_platform(
|
||||
platform,
|
||||
pconfig,
|
||||
chat_id,
|
||||
cleaned_message,
|
||||
thread_id=thread_id,
|
||||
media_files=media_files,
|
||||
)
|
||||
)
|
||||
if used_home_channel and isinstance(result, dict) and result.get("success"):
|
||||
result["note"] = f"Sent to {platform_name} home channel (chat_id: {chat_id})"
|
||||
|
||||
# Mirror the sent message into the target's gateway session
|
||||
if isinstance(result, dict) and result.get("success") and mirror_text:
|
||||
try:
|
||||
from gateway.mirror import mirror_to_session
|
||||
source_label = os.getenv("HERMES_SESSION_PLATFORM", "cli")
|
||||
if mirror_to_session(platform_name, chat_id, mirror_text, source_label=source_label, thread_id=thread_id):
|
||||
result["mirrored"] = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return json.dumps(result)
|
||||
except Exception as e:
|
||||
return json.dumps({"error": f"Send failed: {e}"})
|
||||
|
||||
|
||||
def _parse_target_ref(platform_name: str, target_ref: str):
|
||||
"""Parse a tool target into chat_id/thread_id and whether it is explicit."""
|
||||
if platform_name == "telegram":
|
||||
match = _TELEGRAM_TOPIC_TARGET_RE.fullmatch(target_ref)
|
||||
if match:
|
||||
return match.group(1), match.group(2), True
|
||||
if target_ref.lstrip("-").isdigit():
|
||||
return target_ref, None, True
|
||||
return None, None, False
|
||||
|
||||
|
||||
def _describe_media_for_mirror(media_files):
|
||||
"""Return a human-readable mirror summary when a message only contains media."""
|
||||
if not media_files:
|
||||
return ""
|
||||
if len(media_files) == 1:
|
||||
media_path, is_voice = media_files[0]
|
||||
ext = os.path.splitext(media_path)[1].lower()
|
||||
if is_voice and ext in _VOICE_EXTS:
|
||||
return "[Sent voice message]"
|
||||
if ext in _IMAGE_EXTS:
|
||||
return "[Sent image attachment]"
|
||||
if ext in _VIDEO_EXTS:
|
||||
return "[Sent video attachment]"
|
||||
if ext in _AUDIO_EXTS:
|
||||
return "[Sent audio attachment]"
|
||||
return "[Sent document attachment]"
|
||||
return f"[Sent {len(media_files)} media attachments]"
|
||||
|
||||
|
||||
def _get_cron_auto_delivery_target():
|
||||
"""Return the cron scheduler's auto-delivery target for the current run, if any."""
|
||||
platform = os.getenv("HERMES_CRON_AUTO_DELIVER_PLATFORM", "").strip().lower()
|
||||
chat_id = os.getenv("HERMES_CRON_AUTO_DELIVER_CHAT_ID", "").strip()
|
||||
if not platform or not chat_id:
|
||||
return None
|
||||
thread_id = os.getenv("HERMES_CRON_AUTO_DELIVER_THREAD_ID", "").strip() or None
|
||||
return {
|
||||
"platform": platform,
|
||||
"chat_id": chat_id,
|
||||
"thread_id": thread_id,
|
||||
}
|
||||
|
||||
|
||||
def _maybe_skip_cron_duplicate_send(platform_name: str, chat_id: str, thread_id: str | None):
|
||||
"""Skip redundant cron send_message calls when the scheduler will auto-deliver there."""
|
||||
auto_target = _get_cron_auto_delivery_target()
|
||||
if not auto_target:
|
||||
return None
|
||||
|
||||
same_target = (
|
||||
auto_target["platform"] == platform_name
|
||||
and str(auto_target["chat_id"]) == str(chat_id)
|
||||
and auto_target.get("thread_id") == thread_id
|
||||
)
|
||||
if not same_target:
|
||||
return None
|
||||
|
||||
target_label = f"{platform_name}:{chat_id}"
|
||||
if thread_id is not None:
|
||||
target_label += f":{thread_id}"
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"skipped": True,
|
||||
"reason": "cron_auto_delivery_duplicate_target",
|
||||
"target": target_label,
|
||||
"note": (
|
||||
f"Skipped send_message to {target_label}. This cron job will already auto-deliver "
|
||||
"its final response to that same target. Put the intended user-facing content in "
|
||||
"your final response instead, or use a different target if you want an additional message."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None, media_files=None):
|
||||
"""Route a message to the appropriate platform sender.
|
||||
|
||||
Long messages are automatically chunked to fit within platform limits
|
||||
using the same smart-splitting algorithm as the gateway adapters
|
||||
(preserves code-block boundaries, adds part indicators).
|
||||
"""
|
||||
from gateway.config import Platform
|
||||
from gateway.platforms.base import BasePlatformAdapter
|
||||
from gateway.platforms.telegram import TelegramAdapter
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from gateway.platforms.slack import SlackAdapter
|
||||
|
||||
media_files = media_files or []
|
||||
|
||||
# Platform message length limits (from adapter class attributes)
|
||||
_MAX_LENGTHS = {
|
||||
Platform.TELEGRAM: TelegramAdapter.MAX_MESSAGE_LENGTH,
|
||||
Platform.DISCORD: DiscordAdapter.MAX_MESSAGE_LENGTH,
|
||||
Platform.SLACK: SlackAdapter.MAX_MESSAGE_LENGTH,
|
||||
}
|
||||
|
||||
# Smart-chunk the message to fit within platform limits.
|
||||
# For short messages or platforms without a known limit this is a no-op.
|
||||
max_len = _MAX_LENGTHS.get(platform)
|
||||
if max_len:
|
||||
chunks = BasePlatformAdapter.truncate_message(message, max_len)
|
||||
else:
|
||||
chunks = [message]
|
||||
|
||||
# --- Telegram: special handling for media attachments ---
|
||||
if platform == Platform.TELEGRAM:
|
||||
last_result = None
|
||||
for i, chunk in enumerate(chunks):
|
||||
is_last = (i == len(chunks) - 1)
|
||||
result = await _send_telegram(
|
||||
pconfig.token,
|
||||
chat_id,
|
||||
chunk,
|
||||
media_files=media_files if is_last else [],
|
||||
thread_id=thread_id,
|
||||
)
|
||||
if isinstance(result, dict) and result.get("error"):
|
||||
return result
|
||||
last_result = result
|
||||
return last_result
|
||||
|
||||
# --- Non-Telegram platforms ---
|
||||
if media_files and not message.strip():
|
||||
return {
|
||||
"error": (
|
||||
f"send_message MEDIA delivery is currently only supported for telegram; "
|
||||
f"target {platform.value} had only media attachments"
|
||||
)
|
||||
}
|
||||
warning = None
|
||||
if media_files:
|
||||
warning = (
|
||||
f"MEDIA attachments were omitted for {platform.value}; "
|
||||
"native send_message media delivery is currently only supported for telegram"
|
||||
)
|
||||
|
||||
last_result = None
|
||||
for chunk in chunks:
|
||||
if platform == Platform.DISCORD:
|
||||
result = await _send_discord(pconfig.token, chat_id, chunk)
|
||||
elif platform == Platform.SLACK:
|
||||
result = await _send_slack(pconfig.token, chat_id, chunk)
|
||||
elif platform == Platform.WHATSAPP:
|
||||
result = await _send_whatsapp(pconfig.extra, chat_id, chunk)
|
||||
elif platform == Platform.SIGNAL:
|
||||
result = await _send_signal(pconfig.extra, chat_id, chunk)
|
||||
elif platform == Platform.EMAIL:
|
||||
result = await _send_email(pconfig.extra, chat_id, chunk)
|
||||
elif platform == Platform.SMS:
|
||||
result = await _send_sms(pconfig.api_key, chat_id, chunk)
|
||||
else:
|
||||
result = {"error": f"Direct sending not yet implemented for {platform.value}"}
|
||||
|
||||
if isinstance(result, dict) and result.get("error"):
|
||||
return result
|
||||
last_result = result
|
||||
|
||||
if warning and isinstance(last_result, dict) and last_result.get("success"):
|
||||
warnings = list(last_result.get("warnings", []))
|
||||
warnings.append(warning)
|
||||
last_result["warnings"] = warnings
|
||||
return last_result
|
||||
|
||||
|
||||
async def _send_telegram(token, chat_id, message, media_files=None, thread_id=None):
|
||||
"""Send via Telegram Bot API (one-shot, no polling needed).
|
||||
|
||||
Applies markdown→MarkdownV2 formatting (same as the gateway adapter)
|
||||
so that bold, links, and headers render correctly. If the message
|
||||
already contains HTML tags, it is sent with ``parse_mode='HTML'``
|
||||
instead, bypassing MarkdownV2 conversion.
|
||||
"""
|
||||
try:
|
||||
from telegram import Bot
|
||||
from telegram.constants import ParseMode
|
||||
|
||||
# Auto-detect HTML tags — if present, skip MarkdownV2 and send as HTML.
|
||||
# Inspired by github.com/ashaney — PR #1568.
|
||||
_has_html = bool(re.search(r'<[a-zA-Z/][^>]*>', message))
|
||||
|
||||
if _has_html:
|
||||
formatted = message
|
||||
send_parse_mode = ParseMode.HTML
|
||||
else:
|
||||
# Reuse the gateway adapter's format_message for markdown→MarkdownV2
|
||||
try:
|
||||
from gateway.platforms.telegram import TelegramAdapter, _escape_mdv2, _strip_mdv2
|
||||
_adapter = TelegramAdapter.__new__(TelegramAdapter)
|
||||
formatted = _adapter.format_message(message)
|
||||
except Exception:
|
||||
# Fallback: send as-is if formatting unavailable
|
||||
formatted = message
|
||||
send_parse_mode = ParseMode.MARKDOWN_V2
|
||||
|
||||
bot = Bot(token=token)
|
||||
int_chat_id = int(chat_id)
|
||||
media_files = media_files or []
|
||||
thread_kwargs = {}
|
||||
if thread_id is not None:
|
||||
thread_kwargs["message_thread_id"] = int(thread_id)
|
||||
|
||||
last_msg = None
|
||||
warnings = []
|
||||
|
||||
if formatted.strip():
|
||||
try:
|
||||
last_msg = await bot.send_message(
|
||||
chat_id=int_chat_id, text=formatted,
|
||||
parse_mode=send_parse_mode, **thread_kwargs
|
||||
)
|
||||
except Exception as md_error:
|
||||
# Parse failed, fall back to plain text
|
||||
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower() or "html" in str(md_error).lower():
|
||||
logger.warning("Parse mode %s failed in _send_telegram, falling back to plain text: %s", send_parse_mode, md_error)
|
||||
if not _has_html:
|
||||
try:
|
||||
from gateway.platforms.telegram import _strip_mdv2
|
||||
plain = _strip_mdv2(formatted)
|
||||
except Exception:
|
||||
plain = message
|
||||
else:
|
||||
plain = message
|
||||
last_msg = await bot.send_message(
|
||||
chat_id=int_chat_id, text=plain,
|
||||
parse_mode=None, **thread_kwargs
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
for media_path, is_voice in media_files:
|
||||
if not os.path.exists(media_path):
|
||||
warning = f"Media file not found, skipping: {media_path}"
|
||||
logger.warning(warning)
|
||||
warnings.append(warning)
|
||||
continue
|
||||
|
||||
ext = os.path.splitext(media_path)[1].lower()
|
||||
try:
|
||||
with open(media_path, "rb") as f:
|
||||
if ext in _IMAGE_EXTS:
|
||||
last_msg = await bot.send_photo(
|
||||
chat_id=int_chat_id, photo=f, **thread_kwargs
|
||||
)
|
||||
elif ext in _VIDEO_EXTS:
|
||||
last_msg = await bot.send_video(
|
||||
chat_id=int_chat_id, video=f, **thread_kwargs
|
||||
)
|
||||
elif ext in _VOICE_EXTS and is_voice:
|
||||
last_msg = await bot.send_voice(
|
||||
chat_id=int_chat_id, voice=f, **thread_kwargs
|
||||
)
|
||||
elif ext in _AUDIO_EXTS:
|
||||
last_msg = await bot.send_audio(
|
||||
chat_id=int_chat_id, audio=f, **thread_kwargs
|
||||
)
|
||||
else:
|
||||
last_msg = await bot.send_document(
|
||||
chat_id=int_chat_id, document=f, **thread_kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
warning = f"Failed to send media {media_path}: {e}"
|
||||
logger.error(warning)
|
||||
warnings.append(warning)
|
||||
|
||||
if last_msg is None:
|
||||
error = "No deliverable text or media remained after processing MEDIA tags"
|
||||
if warnings:
|
||||
return {"error": error, "warnings": warnings}
|
||||
return {"error": error}
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"platform": "telegram",
|
||||
"chat_id": chat_id,
|
||||
"message_id": str(last_msg.message_id),
|
||||
}
|
||||
if warnings:
|
||||
result["warnings"] = warnings
|
||||
return result
|
||||
except ImportError:
|
||||
return {"error": "python-telegram-bot not installed. Run: pip install python-telegram-bot"}
|
||||
except Exception as e:
|
||||
return {"error": f"Telegram send failed: {e}"}
|
||||
|
||||
|
||||
async def _send_discord(token, chat_id, message):
|
||||
"""Send a single message via Discord REST API (no websocket client needed).
|
||||
|
||||
Chunking is handled by _send_to_platform() before this is called.
|
||||
"""
|
||||
try:
|
||||
import aiohttp
|
||||
except ImportError:
|
||||
return {"error": "aiohttp not installed. Run: pip install aiohttp"}
|
||||
try:
|
||||
url = f"https://discord.com/api/v10/channels/{chat_id}/messages"
|
||||
headers = {"Authorization": f"Bot {token}", "Content-Type": "application/json"}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers, json={"content": message}) as resp:
|
||||
if resp.status not in (200, 201):
|
||||
body = await resp.text()
|
||||
return {"error": f"Discord API error ({resp.status}): {body}"}
|
||||
data = await resp.json()
|
||||
return {"success": True, "platform": "discord", "chat_id": chat_id, "message_id": data.get("id")}
|
||||
except Exception as e:
|
||||
return {"error": f"Discord send failed: {e}"}
|
||||
|
||||
|
||||
async def _send_slack(token, chat_id, message):
|
||||
"""Send via Slack Web API."""
|
||||
try:
|
||||
import aiohttp
|
||||
except ImportError:
|
||||
return {"error": "aiohttp not installed. Run: pip install aiohttp"}
|
||||
try:
|
||||
url = "https://slack.com/api/chat.postMessage"
|
||||
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers, json={"channel": chat_id, "text": message}) as resp:
|
||||
data = await resp.json()
|
||||
if data.get("ok"):
|
||||
return {"success": True, "platform": "slack", "chat_id": chat_id, "message_id": data.get("ts")}
|
||||
return {"error": f"Slack API error: {data.get('error', 'unknown')}"}
|
||||
except Exception as e:
|
||||
return {"error": f"Slack send failed: {e}"}
|
||||
|
||||
|
||||
async def _send_whatsapp(extra, chat_id, message):
|
||||
"""Send via the local WhatsApp bridge HTTP API."""
|
||||
try:
|
||||
import aiohttp
|
||||
except ImportError:
|
||||
return {"error": "aiohttp not installed. Run: pip install aiohttp"}
|
||||
try:
|
||||
bridge_port = extra.get("bridge_port", 3000)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"http://localhost:{bridge_port}/send",
|
||||
json={"chatId": chat_id, "message": message},
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return {
|
||||
"success": True,
|
||||
"platform": "whatsapp",
|
||||
"chat_id": chat_id,
|
||||
"message_id": data.get("messageId"),
|
||||
}
|
||||
body = await resp.text()
|
||||
return {"error": f"WhatsApp bridge error ({resp.status}): {body}"}
|
||||
except Exception as e:
|
||||
return {"error": f"WhatsApp send failed: {e}"}
|
||||
|
||||
|
||||
async def _send_signal(extra, chat_id, message):
|
||||
"""Send via signal-cli JSON-RPC API."""
|
||||
try:
|
||||
import httpx
|
||||
except ImportError:
|
||||
return {"error": "httpx not installed"}
|
||||
try:
|
||||
http_url = extra.get("http_url", "http://127.0.0.1:8080").rstrip("/")
|
||||
account = extra.get("account", "")
|
||||
if not account:
|
||||
return {"error": "Signal account not configured"}
|
||||
|
||||
params = {"account": account, "message": message}
|
||||
if chat_id.startswith("group:"):
|
||||
params["groupId"] = chat_id[6:]
|
||||
else:
|
||||
params["recipient"] = [chat_id]
|
||||
|
||||
payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "send",
|
||||
"params": params,
|
||||
"id": f"send_{int(time.time() * 1000)}",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.post(f"{http_url}/api/v1/rpc", json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
if "error" in data:
|
||||
return {"error": f"Signal RPC error: {data['error']}"}
|
||||
return {"success": True, "platform": "signal", "chat_id": chat_id}
|
||||
except Exception as e:
|
||||
return {"error": f"Signal send failed: {e}"}
|
||||
|
||||
|
||||
async def _send_email(extra, chat_id, message):
|
||||
"""Send via SMTP (one-shot, no persistent connection needed)."""
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
|
||||
address = extra.get("address") or os.getenv("EMAIL_ADDRESS", "")
|
||||
password = os.getenv("EMAIL_PASSWORD", "")
|
||||
smtp_host = extra.get("smtp_host") or os.getenv("EMAIL_SMTP_HOST", "")
|
||||
smtp_port = int(os.getenv("EMAIL_SMTP_PORT", "587"))
|
||||
|
||||
if not all([address, password, smtp_host]):
|
||||
return {"error": "Email not configured (EMAIL_ADDRESS, EMAIL_PASSWORD, EMAIL_SMTP_HOST required)"}
|
||||
|
||||
try:
|
||||
msg = MIMEText(message, "plain", "utf-8")
|
||||
msg["From"] = address
|
||||
msg["To"] = chat_id
|
||||
msg["Subject"] = "Hermes Agent"
|
||||
|
||||
server = smtplib.SMTP(smtp_host, smtp_port)
|
||||
server.starttls(context=ssl.create_default_context())
|
||||
server.login(address, password)
|
||||
server.send_message(msg)
|
||||
server.quit()
|
||||
return {"success": True, "platform": "email", "chat_id": chat_id}
|
||||
except Exception as e:
|
||||
return {"error": f"Email send failed: {e}"}
|
||||
|
||||
|
||||
async def _send_sms(auth_token, chat_id, message):
|
||||
"""Send a single SMS via Twilio REST API.
|
||||
|
||||
Uses HTTP Basic auth (Account SID : Auth Token) and form-encoded POST.
|
||||
Chunking is handled by _send_to_platform() before this is called.
|
||||
"""
|
||||
try:
|
||||
import aiohttp
|
||||
except ImportError:
|
||||
return {"error": "aiohttp not installed. Run: pip install aiohttp"}
|
||||
|
||||
import base64
|
||||
|
||||
account_sid = os.getenv("TWILIO_ACCOUNT_SID", "")
|
||||
from_number = os.getenv("TWILIO_PHONE_NUMBER", "")
|
||||
if not account_sid or not auth_token or not from_number:
|
||||
return {"error": "SMS not configured (TWILIO_ACCOUNT_SID, TWILIO_AUTH_TOKEN, TWILIO_PHONE_NUMBER required)"}
|
||||
|
||||
# Strip markdown — SMS renders it as literal characters
|
||||
message = re.sub(r"\*\*(.+?)\*\*", r"\1", message, flags=re.DOTALL)
|
||||
message = re.sub(r"\*(.+?)\*", r"\1", message, flags=re.DOTALL)
|
||||
message = re.sub(r"__(.+?)__", r"\1", message, flags=re.DOTALL)
|
||||
message = re.sub(r"_(.+?)_", r"\1", message, flags=re.DOTALL)
|
||||
message = re.sub(r"```[a-z]*\n?", "", message)
|
||||
message = re.sub(r"`(.+?)`", r"\1", message)
|
||||
message = re.sub(r"^#{1,6}\s+", "", message, flags=re.MULTILINE)
|
||||
message = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", message)
|
||||
message = re.sub(r"\n{3,}", "\n\n", message)
|
||||
message = message.strip()
|
||||
|
||||
try:
|
||||
creds = f"{account_sid}:{auth_token}"
|
||||
encoded = base64.b64encode(creds.encode("ascii")).decode("ascii")
|
||||
url = f"https://api.twilio.com/2010-04-01/Accounts/{account_sid}/Messages.json"
|
||||
headers = {"Authorization": f"Basic {encoded}"}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
form_data = aiohttp.FormData()
|
||||
form_data.add_field("From", from_number)
|
||||
form_data.add_field("To", chat_id)
|
||||
form_data.add_field("Body", message)
|
||||
|
||||
async with session.post(url, data=form_data, headers=headers) as resp:
|
||||
body = await resp.json()
|
||||
if resp.status >= 400:
|
||||
error_msg = body.get("message", str(body))
|
||||
return {"error": f"Twilio API error ({resp.status}): {error_msg}"}
|
||||
msg_sid = body.get("sid", "")
|
||||
return {"success": True, "platform": "sms", "chat_id": chat_id, "message_id": msg_sid}
|
||||
except Exception as e:
|
||||
return {"error": f"SMS send failed: {e}"}
|
||||
|
||||
|
||||
def _check_send_message():
|
||||
"""Gate send_message on gateway running (always available on messaging platforms)."""
|
||||
platform = os.getenv("HERMES_SESSION_PLATFORM", "")
|
||||
if platform and platform != "local":
|
||||
return True
|
||||
try:
|
||||
from gateway.status import is_gateway_running
|
||||
return is_gateway_running()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# --- Registry ---
|
||||
from tools.registry import registry
|
||||
|
||||
registry.register(
|
||||
name="send_message",
|
||||
toolset="messaging",
|
||||
schema=SEND_MESSAGE_SCHEMA,
|
||||
handler=send_message_tool,
|
||||
check_fn=_check_send_message,
|
||||
emoji="📨",
|
||||
)
|
||||
420
hermes_code/tools/session_search_tool.py
Normal file
420
hermes_code/tools/session_search_tool.py
Normal file
|
|
@ -0,0 +1,420 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Session Search Tool - Long-Term Conversation Recall
|
||||
|
||||
Searches past session transcripts in SQLite via FTS5, then summarizes the top
|
||||
matching sessions using a cheap/fast model (same pattern as web_extract).
|
||||
Returns focused summaries of past conversations rather than raw transcripts,
|
||||
keeping the main model's context window clean.
|
||||
|
||||
Flow:
|
||||
1. FTS5 search finds matching messages ranked by relevance
|
||||
2. Groups by session, takes the top N unique sessions (default 3)
|
||||
3. Loads each session's conversation, truncates to ~100k chars centered on matches
|
||||
4. Sends to Gemini Flash with a focused summarization prompt
|
||||
5. Returns per-session summaries with metadata
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
|
||||
from agent.auxiliary_client import async_call_llm
|
||||
MAX_SESSION_CHARS = 100_000
|
||||
MAX_SUMMARY_TOKENS = 10000
|
||||
|
||||
|
||||
def _format_timestamp(ts: Union[int, float, str, None]) -> str:
|
||||
"""Convert a Unix timestamp (float/int) or ISO string to a human-readable date.
|
||||
|
||||
Returns "unknown" for None, str(ts) if conversion fails.
|
||||
"""
|
||||
if ts is None:
|
||||
return "unknown"
|
||||
try:
|
||||
if isinstance(ts, (int, float)):
|
||||
from datetime import datetime
|
||||
dt = datetime.fromtimestamp(ts)
|
||||
return dt.strftime("%B %d, %Y at %I:%M %p")
|
||||
if isinstance(ts, str):
|
||||
if ts.replace(".", "").replace("-", "").isdigit():
|
||||
from datetime import datetime
|
||||
dt = datetime.fromtimestamp(float(ts))
|
||||
return dt.strftime("%B %d, %Y at %I:%M %p")
|
||||
return ts
|
||||
except (ValueError, OSError, OverflowError) as e:
|
||||
# Log specific errors for debugging while gracefully handling edge cases
|
||||
logging.debug("Failed to format timestamp %s: %s", ts, e, exc_info=True)
|
||||
except Exception as e:
|
||||
logging.debug("Unexpected error formatting timestamp %s: %s", ts, e, exc_info=True)
|
||||
return str(ts)
|
||||
|
||||
|
||||
def _format_conversation(messages: List[Dict[str, Any]]) -> str:
|
||||
"""Format session messages into a readable transcript for summarization."""
|
||||
parts = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown").upper()
|
||||
content = msg.get("content") or ""
|
||||
tool_name = msg.get("tool_name")
|
||||
|
||||
if role == "TOOL" and tool_name:
|
||||
# Truncate long tool outputs
|
||||
if len(content) > 500:
|
||||
content = content[:250] + "\n...[truncated]...\n" + content[-250:]
|
||||
parts.append(f"[TOOL:{tool_name}]: {content}")
|
||||
elif role == "ASSISTANT":
|
||||
# Include tool call names if present
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if tool_calls and isinstance(tool_calls, list):
|
||||
tc_names = []
|
||||
for tc in tool_calls:
|
||||
if isinstance(tc, dict):
|
||||
name = tc.get("name") or tc.get("function", {}).get("name", "?")
|
||||
tc_names.append(name)
|
||||
if tc_names:
|
||||
parts.append(f"[ASSISTANT]: [Called: {', '.join(tc_names)}]")
|
||||
if content:
|
||||
parts.append(f"[ASSISTANT]: {content}")
|
||||
else:
|
||||
parts.append(f"[ASSISTANT]: {content}")
|
||||
else:
|
||||
parts.append(f"[{role}]: {content}")
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
def _truncate_around_matches(
|
||||
full_text: str, query: str, max_chars: int = MAX_SESSION_CHARS
|
||||
) -> str:
|
||||
"""
|
||||
Truncate a conversation transcript to max_chars, centered around
|
||||
where the query terms appear. Keeps content near matches, trims the edges.
|
||||
"""
|
||||
if len(full_text) <= max_chars:
|
||||
return full_text
|
||||
|
||||
# Find the first occurrence of any query term
|
||||
query_terms = query.lower().split()
|
||||
text_lower = full_text.lower()
|
||||
first_match = len(full_text)
|
||||
for term in query_terms:
|
||||
pos = text_lower.find(term)
|
||||
if pos != -1 and pos < first_match:
|
||||
first_match = pos
|
||||
|
||||
if first_match == len(full_text):
|
||||
# No match found, take from the start
|
||||
first_match = 0
|
||||
|
||||
# Center the window around the first match
|
||||
half = max_chars // 2
|
||||
start = max(0, first_match - half)
|
||||
end = min(len(full_text), start + max_chars)
|
||||
if end - start < max_chars:
|
||||
start = max(0, end - max_chars)
|
||||
|
||||
truncated = full_text[start:end]
|
||||
prefix = "...[earlier conversation truncated]...\n\n" if start > 0 else ""
|
||||
suffix = "\n\n...[later conversation truncated]..." if end < len(full_text) else ""
|
||||
return prefix + truncated + suffix
|
||||
|
||||
|
||||
async def _summarize_session(
|
||||
conversation_text: str, query: str, session_meta: Dict[str, Any]
|
||||
) -> Optional[str]:
|
||||
"""Summarize a single session conversation focused on the search query."""
|
||||
system_prompt = (
|
||||
"You are reviewing a past conversation transcript to help recall what happened. "
|
||||
"Summarize the conversation with a focus on the search topic. Include:\n"
|
||||
"1. What the user asked about or wanted to accomplish\n"
|
||||
"2. What actions were taken and what the outcomes were\n"
|
||||
"3. Key decisions, solutions found, or conclusions reached\n"
|
||||
"4. Any specific commands, files, URLs, or technical details that were important\n"
|
||||
"5. Anything left unresolved or notable\n\n"
|
||||
"Be thorough but concise. Preserve specific details (commands, paths, error messages) "
|
||||
"that would be useful to recall. Write in past tense as a factual recap."
|
||||
)
|
||||
|
||||
source = session_meta.get("source", "unknown")
|
||||
started = _format_timestamp(session_meta.get("started_at"))
|
||||
|
||||
user_prompt = (
|
||||
f"Search topic: {query}\n"
|
||||
f"Session source: {source}\n"
|
||||
f"Session date: {started}\n\n"
|
||||
f"CONVERSATION TRANSCRIPT:\n{conversation_text}\n\n"
|
||||
f"Summarize this conversation with focus on: {query}"
|
||||
)
|
||||
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
response = await async_call_llm(
|
||||
task="session_search",
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
temperature=0.1,
|
||||
max_tokens=MAX_SUMMARY_TOKENS,
|
||||
)
|
||||
return response.choices[0].message.content.strip()
|
||||
except RuntimeError:
|
||||
logging.warning("No auxiliary model available for session summarization")
|
||||
return None
|
||||
except Exception as e:
|
||||
if attempt < max_retries - 1:
|
||||
await asyncio.sleep(1 * (attempt + 1))
|
||||
else:
|
||||
logging.warning(
|
||||
"Session summarization failed after %d attempts: %s",
|
||||
max_retries,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def session_search(
|
||||
query: str,
|
||||
role_filter: str = None,
|
||||
limit: int = 3,
|
||||
db=None,
|
||||
current_session_id: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
Search past sessions and return focused summaries of matching conversations.
|
||||
|
||||
Uses FTS5 to find matches, then summarizes the top sessions with Gemini Flash.
|
||||
The current session is excluded from results since the agent already has that context.
|
||||
"""
|
||||
if db is None:
|
||||
return json.dumps({"success": False, "error": "Session database not available."}, ensure_ascii=False)
|
||||
|
||||
if not query or not query.strip():
|
||||
return json.dumps({"success": False, "error": "Query cannot be empty."}, ensure_ascii=False)
|
||||
|
||||
query = query.strip()
|
||||
limit = min(limit, 5) # Cap at 5 sessions to avoid excessive LLM calls
|
||||
|
||||
try:
|
||||
# Parse role filter
|
||||
role_list = None
|
||||
if role_filter and role_filter.strip():
|
||||
role_list = [r.strip() for r in role_filter.split(",") if r.strip()]
|
||||
|
||||
# FTS5 search -- get matches ranked by relevance
|
||||
raw_results = db.search_messages(
|
||||
query=query,
|
||||
role_filter=role_list,
|
||||
limit=50, # Get more matches to find unique sessions
|
||||
offset=0,
|
||||
)
|
||||
|
||||
if not raw_results:
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"query": query,
|
||||
"results": [],
|
||||
"count": 0,
|
||||
"message": "No matching sessions found.",
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# Resolve child sessions to their parent — delegation stores detailed
|
||||
# content in child sessions, but the user's conversation is the parent.
|
||||
def _resolve_to_parent(session_id: str) -> str:
|
||||
"""Walk delegation chain to find the root parent session ID."""
|
||||
visited = set()
|
||||
sid = session_id
|
||||
while sid and sid not in visited:
|
||||
visited.add(sid)
|
||||
try:
|
||||
session = db.get_session(sid)
|
||||
if not session:
|
||||
break
|
||||
parent = session.get("parent_session_id")
|
||||
if parent:
|
||||
sid = parent
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
logging.debug(
|
||||
"Error resolving parent for session %s: %s",
|
||||
sid,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
break
|
||||
return sid
|
||||
|
||||
current_lineage_root = (
|
||||
_resolve_to_parent(current_session_id) if current_session_id else None
|
||||
)
|
||||
|
||||
# Group by resolved (parent) session_id, dedup, skip the current
|
||||
# session lineage. Compression and delegation create child sessions
|
||||
# that still belong to the same active conversation.
|
||||
seen_sessions = {}
|
||||
for result in raw_results:
|
||||
raw_sid = result["session_id"]
|
||||
resolved_sid = _resolve_to_parent(raw_sid)
|
||||
# Skip the current session lineage — the agent already has that
|
||||
# context, even if older turns live in parent fragments.
|
||||
if current_lineage_root and resolved_sid == current_lineage_root:
|
||||
continue
|
||||
if current_session_id and raw_sid == current_session_id:
|
||||
continue
|
||||
if resolved_sid not in seen_sessions:
|
||||
result = dict(result)
|
||||
result["session_id"] = resolved_sid
|
||||
seen_sessions[resolved_sid] = result
|
||||
if len(seen_sessions) >= limit:
|
||||
break
|
||||
|
||||
# Prepare all sessions for parallel summarization
|
||||
tasks = []
|
||||
for session_id, match_info in seen_sessions.items():
|
||||
try:
|
||||
messages = db.get_messages_as_conversation(session_id)
|
||||
if not messages:
|
||||
continue
|
||||
session_meta = db.get_session(session_id) or {}
|
||||
conversation_text = _format_conversation(messages)
|
||||
conversation_text = _truncate_around_matches(conversation_text, query)
|
||||
tasks.append((session_id, match_info, conversation_text, session_meta))
|
||||
except Exception as e:
|
||||
logging.warning(
|
||||
"Failed to prepare session %s: %s",
|
||||
session_id,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Summarize all sessions in parallel
|
||||
async def _summarize_all() -> List[Union[str, Exception]]:
|
||||
"""Summarize all sessions in parallel."""
|
||||
coros = [
|
||||
_summarize_session(text, query, meta)
|
||||
for _, _, text, meta in tasks
|
||||
]
|
||||
return await asyncio.gather(*coros, return_exceptions=True)
|
||||
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
results = pool.submit(lambda: asyncio.run(_summarize_all())).result(timeout=60)
|
||||
except RuntimeError:
|
||||
# No event loop running, create a new one
|
||||
results = asyncio.run(_summarize_all())
|
||||
except concurrent.futures.TimeoutError:
|
||||
logging.warning(
|
||||
"Session summarization timed out after 60 seconds",
|
||||
exc_info=True,
|
||||
)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": "Session summarization timed out. Try a more specific query or reduce the limit.",
|
||||
}, ensure_ascii=False)
|
||||
|
||||
summaries = []
|
||||
for (session_id, match_info, _, _), result in zip(tasks, results):
|
||||
if isinstance(result, Exception):
|
||||
logging.warning(
|
||||
"Failed to summarize session %s: %s",
|
||||
session_id,
|
||||
result,
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
if result:
|
||||
summaries.append({
|
||||
"session_id": session_id,
|
||||
"when": _format_timestamp(match_info.get("session_started")),
|
||||
"source": match_info.get("source", "unknown"),
|
||||
"model": match_info.get("model"),
|
||||
"summary": result,
|
||||
})
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"query": query,
|
||||
"results": summaries,
|
||||
"count": len(summaries),
|
||||
"sessions_searched": len(seen_sessions),
|
||||
}, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
logging.error("Session search failed: %s", e, exc_info=True)
|
||||
return json.dumps({"success": False, "error": f"Search failed: {str(e)}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def check_session_search_requirements() -> bool:
|
||||
"""Requires SQLite state database and an auxiliary text model."""
|
||||
try:
|
||||
from hermes_state import DEFAULT_DB_PATH
|
||||
return DEFAULT_DB_PATH.parent.exists()
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
SESSION_SEARCH_SCHEMA = {
|
||||
"name": "session_search",
|
||||
"description": (
|
||||
"Search your long-term memory of past conversations. This is your recall -- "
|
||||
"every past session is searchable, and this tool summarizes what happened.\n\n"
|
||||
"USE THIS PROACTIVELY when:\n"
|
||||
"- The user says 'we did this before', 'remember when', 'last time', 'as I mentioned'\n"
|
||||
"- The user asks about a topic you worked on before but don't have in current context\n"
|
||||
"- The user references a project, person, or concept that seems familiar but isn't in memory\n"
|
||||
"- You want to check if you've solved a similar problem before\n"
|
||||
"- The user asks 'what did we do about X?' or 'how did we fix Y?'\n\n"
|
||||
"Don't hesitate to search when it is actually cross-session -- it's fast and cheap. "
|
||||
"Better to search and confirm than to guess or ask the user to repeat themselves.\n\n"
|
||||
"Search syntax: keywords joined with OR for broad recall (elevenlabs OR baseten OR funding), "
|
||||
"phrases for exact match (\"docker networking\"), boolean (python NOT java), prefix (deploy*). "
|
||||
"IMPORTANT: Use OR between keywords for best results — FTS5 defaults to AND which misses "
|
||||
"sessions that only mention some terms. If a broad OR query returns nothing, try individual "
|
||||
"keyword searches in parallel. Returns summaries of the top matching sessions."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query — keywords, phrases, or boolean expressions to find in past sessions.",
|
||||
},
|
||||
"role_filter": {
|
||||
"type": "string",
|
||||
"description": "Optional: only search messages from specific roles (comma-separated). E.g. 'user,assistant' to skip tool outputs.",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max sessions to summarize (default: 3, max: 5).",
|
||||
"default": 3,
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# --- Registry ---
|
||||
from tools.registry import registry
|
||||
|
||||
registry.register(
|
||||
name="session_search",
|
||||
toolset="session_search",
|
||||
schema=SESSION_SEARCH_SCHEMA,
|
||||
handler=lambda args, **kw: session_search(
|
||||
query=args.get("query", ""),
|
||||
role_filter=args.get("role_filter"),
|
||||
limit=args.get("limit", 3),
|
||||
db=kw.get("db"),
|
||||
current_session_id=kw.get("current_session_id")),
|
||||
check_fn=check_session_search_requirements,
|
||||
emoji="🔍",
|
||||
)
|
||||
664
hermes_code/tools/skill_manager_tool.py
Normal file
664
hermes_code/tools/skill_manager_tool.py
Normal file
|
|
@ -0,0 +1,664 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Skill Manager Tool -- Agent-Managed Skill Creation & Editing
|
||||
|
||||
Allows the agent to create, update, and delete skills, turning successful
|
||||
approaches into reusable procedural knowledge. New skills are created in
|
||||
~/.hermes/skills/. Existing skills (bundled, hub-installed, or user-created)
|
||||
can be modified or deleted wherever they live.
|
||||
|
||||
Skills are the agent's procedural memory: they capture *how to do a specific
|
||||
type of task* based on proven experience. General memory (MEMORY.md, USER.md) is
|
||||
broad and declarative. Skills are narrow and actionable.
|
||||
|
||||
Actions:
|
||||
create -- Create a new skill (SKILL.md + directory structure)
|
||||
edit -- Replace the SKILL.md content of a user skill (full rewrite)
|
||||
patch -- Targeted find-and-replace within SKILL.md or any supporting file
|
||||
delete -- Remove a user skill entirely
|
||||
write_file -- Add/overwrite a supporting file (reference, template, script, asset)
|
||||
remove_file-- Remove a supporting file from a user skill
|
||||
|
||||
Directory layout for user skills:
|
||||
~/.hermes/skills/
|
||||
├── my-skill/
|
||||
│ ├── SKILL.md
|
||||
│ ├── references/
|
||||
│ ├── templates/
|
||||
│ ├── scripts/
|
||||
│ └── assets/
|
||||
└── category-name/
|
||||
└── another-skill/
|
||||
└── SKILL.md
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import security scanner — agent-created skills get the same scrutiny as
|
||||
# community hub installs.
|
||||
try:
|
||||
from tools.skills_guard import scan_skill, should_allow_install, format_scan_report
|
||||
_GUARD_AVAILABLE = True
|
||||
except ImportError:
|
||||
_GUARD_AVAILABLE = False
|
||||
|
||||
|
||||
def _security_scan_skill(skill_dir: Path) -> Optional[str]:
|
||||
"""Scan a skill directory after write. Returns error string if blocked, else None."""
|
||||
if not _GUARD_AVAILABLE:
|
||||
return None
|
||||
try:
|
||||
result = scan_skill(skill_dir, source="agent-created")
|
||||
allowed, reason = should_allow_install(result)
|
||||
if allowed is False:
|
||||
report = format_scan_report(result)
|
||||
return f"Security scan blocked this skill ({reason}):\n{report}"
|
||||
if allowed is None:
|
||||
# "ask" — allow but include the warning so the user sees the findings
|
||||
report = format_scan_report(result)
|
||||
logger.warning("Agent-created skill has security findings: %s", reason)
|
||||
# Don't block — return None to allow, but log the warning
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning("Security scan failed for %s: %s", skill_dir, e, exc_info=True)
|
||||
return None
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
# All skills live in ~/.hermes/skills/ (single source of truth)
|
||||
HERMES_HOME = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
SKILLS_DIR = HERMES_HOME / "skills"
|
||||
|
||||
MAX_NAME_LENGTH = 64
|
||||
MAX_DESCRIPTION_LENGTH = 1024
|
||||
|
||||
# Characters allowed in skill names (filesystem-safe, URL-friendly)
|
||||
VALID_NAME_RE = re.compile(r'^[a-z0-9][a-z0-9._-]*$')
|
||||
|
||||
# Subdirectories allowed for write_file/remove_file
|
||||
ALLOWED_SUBDIRS = {"references", "templates", "scripts", "assets"}
|
||||
|
||||
|
||||
def check_skill_manage_requirements() -> bool:
|
||||
"""Skill management has no external requirements -- always available."""
|
||||
return True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Validation helpers
|
||||
# =============================================================================
|
||||
|
||||
def _validate_name(name: str) -> Optional[str]:
|
||||
"""Validate a skill name. Returns error message or None if valid."""
|
||||
if not name:
|
||||
return "Skill name is required."
|
||||
if len(name) > MAX_NAME_LENGTH:
|
||||
return f"Skill name exceeds {MAX_NAME_LENGTH} characters."
|
||||
if not VALID_NAME_RE.match(name):
|
||||
return (
|
||||
f"Invalid skill name '{name}'. Use lowercase letters, numbers, "
|
||||
f"hyphens, dots, and underscores. Must start with a letter or digit."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _validate_frontmatter(content: str) -> Optional[str]:
|
||||
"""
|
||||
Validate that SKILL.md content has proper frontmatter with required fields.
|
||||
Returns error message or None if valid.
|
||||
"""
|
||||
if not content.strip():
|
||||
return "Content cannot be empty."
|
||||
|
||||
if not content.startswith("---"):
|
||||
return "SKILL.md must start with YAML frontmatter (---). See existing skills for format."
|
||||
|
||||
end_match = re.search(r'\n---\s*\n', content[3:])
|
||||
if not end_match:
|
||||
return "SKILL.md frontmatter is not closed. Ensure you have a closing '---' line."
|
||||
|
||||
yaml_content = content[3:end_match.start() + 3]
|
||||
|
||||
try:
|
||||
parsed = yaml.safe_load(yaml_content)
|
||||
except yaml.YAMLError as e:
|
||||
return f"YAML frontmatter parse error: {e}"
|
||||
|
||||
if not isinstance(parsed, dict):
|
||||
return "Frontmatter must be a YAML mapping (key: value pairs)."
|
||||
|
||||
if "name" not in parsed:
|
||||
return "Frontmatter must include 'name' field."
|
||||
if "description" not in parsed:
|
||||
return "Frontmatter must include 'description' field."
|
||||
if len(str(parsed["description"])) > MAX_DESCRIPTION_LENGTH:
|
||||
return f"Description exceeds {MAX_DESCRIPTION_LENGTH} characters."
|
||||
|
||||
body = content[end_match.end() + 3:].strip()
|
||||
if not body:
|
||||
return "SKILL.md must have content after the frontmatter (instructions, procedures, etc.)."
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_skill_dir(name: str, category: str = None) -> Path:
|
||||
"""Build the directory path for a new skill, optionally under a category."""
|
||||
if category:
|
||||
return SKILLS_DIR / category / name
|
||||
return SKILLS_DIR / name
|
||||
|
||||
|
||||
def _find_skill(name: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Find a skill by name in ~/.hermes/skills/.
|
||||
Returns {"path": Path} or None.
|
||||
"""
|
||||
if not SKILLS_DIR.exists():
|
||||
return None
|
||||
for skill_md in SKILLS_DIR.rglob("SKILL.md"):
|
||||
if skill_md.parent.name == name:
|
||||
return {"path": skill_md.parent}
|
||||
return None
|
||||
|
||||
|
||||
def _validate_file_path(file_path: str) -> Optional[str]:
|
||||
"""
|
||||
Validate a file path for write_file/remove_file.
|
||||
Must be under an allowed subdirectory and not escape the skill dir.
|
||||
"""
|
||||
if not file_path:
|
||||
return "file_path is required."
|
||||
|
||||
normalized = Path(file_path)
|
||||
|
||||
# Prevent path traversal
|
||||
if ".." in normalized.parts:
|
||||
return "Path traversal ('..') is not allowed."
|
||||
|
||||
# Must be under an allowed subdirectory
|
||||
if not normalized.parts or normalized.parts[0] not in ALLOWED_SUBDIRS:
|
||||
allowed = ", ".join(sorted(ALLOWED_SUBDIRS))
|
||||
return f"File must be under one of: {allowed}. Got: '{file_path}'"
|
||||
|
||||
# Must have a filename (not just a directory)
|
||||
if len(normalized.parts) < 2:
|
||||
return f"Provide a file path, not just a directory. Example: '{normalized.parts[0]}/myfile.md'"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _atomic_write_text(file_path: Path, content: str, encoding: str = "utf-8") -> None:
|
||||
"""
|
||||
Atomically write text content to a file.
|
||||
|
||||
Uses a temporary file in the same directory and os.replace() to ensure
|
||||
the target file is never left in a partially-written state if the process
|
||||
crashes or is interrupted.
|
||||
|
||||
Args:
|
||||
file_path: Target file path
|
||||
content: Content to write
|
||||
encoding: Text encoding (default: utf-8)
|
||||
"""
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fd, temp_path = tempfile.mkstemp(
|
||||
dir=str(file_path.parent),
|
||||
prefix=f".{file_path.name}.tmp.",
|
||||
suffix="",
|
||||
)
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding=encoding) as f:
|
||||
f.write(content)
|
||||
os.replace(temp_path, file_path)
|
||||
except Exception:
|
||||
# Clean up temp file on error
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except OSError:
|
||||
logger.error("Failed to remove temporary file %s during atomic write", temp_path, exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Core actions
|
||||
# =============================================================================
|
||||
|
||||
def _create_skill(name: str, content: str, category: str = None) -> Dict[str, Any]:
|
||||
"""Create a new user skill with SKILL.md content."""
|
||||
# Validate name
|
||||
err = _validate_name(name)
|
||||
if err:
|
||||
return {"success": False, "error": err}
|
||||
|
||||
# Validate content
|
||||
err = _validate_frontmatter(content)
|
||||
if err:
|
||||
return {"success": False, "error": err}
|
||||
|
||||
# Check for name collisions across all directories
|
||||
existing = _find_skill(name)
|
||||
if existing:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"A skill named '{name}' already exists at {existing['path']}."
|
||||
}
|
||||
|
||||
# Create the skill directory
|
||||
skill_dir = _resolve_skill_dir(name, category)
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Write SKILL.md atomically
|
||||
skill_md = skill_dir / "SKILL.md"
|
||||
_atomic_write_text(skill_md, content)
|
||||
|
||||
# Security scan — roll back on block
|
||||
scan_error = _security_scan_skill(skill_dir)
|
||||
if scan_error:
|
||||
shutil.rmtree(skill_dir, ignore_errors=True)
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"Skill '{name}' created.",
|
||||
"path": str(skill_dir.relative_to(SKILLS_DIR)),
|
||||
"skill_md": str(skill_md),
|
||||
}
|
||||
if category:
|
||||
result["category"] = category
|
||||
result["hint"] = (
|
||||
"To add reference files, templates, or scripts, use "
|
||||
"skill_manage(action='write_file', name='{}', file_path='references/example.md', file_content='...')".format(name)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _edit_skill(name: str, content: str) -> Dict[str, Any]:
|
||||
"""Replace the SKILL.md of any existing skill (full rewrite)."""
|
||||
err = _validate_frontmatter(content)
|
||||
if err:
|
||||
return {"success": False, "error": err}
|
||||
|
||||
existing = _find_skill(name)
|
||||
if not existing:
|
||||
return {"success": False, "error": f"Skill '{name}' not found. Use skills_list() to see available skills."}
|
||||
|
||||
skill_md = existing["path"] / "SKILL.md"
|
||||
# Back up original content for rollback
|
||||
original_content = skill_md.read_text(encoding="utf-8") if skill_md.exists() else None
|
||||
_atomic_write_text(skill_md, content)
|
||||
|
||||
# Security scan — roll back on block
|
||||
scan_error = _security_scan_skill(existing["path"])
|
||||
if scan_error:
|
||||
if original_content is not None:
|
||||
_atomic_write_text(skill_md, original_content)
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Skill '{name}' updated.",
|
||||
"path": str(existing["path"]),
|
||||
}
|
||||
|
||||
|
||||
def _patch_skill(
|
||||
name: str,
|
||||
old_string: str,
|
||||
new_string: str,
|
||||
file_path: str = None,
|
||||
replace_all: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Targeted find-and-replace within a skill file.
|
||||
|
||||
Defaults to SKILL.md. Use file_path to patch a supporting file instead.
|
||||
Requires a unique match unless replace_all is True.
|
||||
"""
|
||||
if not old_string:
|
||||
return {"success": False, "error": "old_string is required for 'patch'."}
|
||||
if new_string is None:
|
||||
return {"success": False, "error": "new_string is required for 'patch'. Use an empty string to delete matched text."}
|
||||
|
||||
existing = _find_skill(name)
|
||||
if not existing:
|
||||
return {"success": False, "error": f"Skill '{name}' not found."}
|
||||
|
||||
skill_dir = existing["path"]
|
||||
|
||||
if file_path:
|
||||
# Patching a supporting file
|
||||
err = _validate_file_path(file_path)
|
||||
if err:
|
||||
return {"success": False, "error": err}
|
||||
target = skill_dir / file_path
|
||||
else:
|
||||
# Patching SKILL.md
|
||||
target = skill_dir / "SKILL.md"
|
||||
|
||||
if not target.exists():
|
||||
return {"success": False, "error": f"File not found: {target.relative_to(skill_dir)}"}
|
||||
|
||||
content = target.read_text(encoding="utf-8")
|
||||
|
||||
count = content.count(old_string)
|
||||
if count == 0:
|
||||
# Show a short preview of the file so the model can self-correct
|
||||
preview = content[:500] + ("..." if len(content) > 500 else "")
|
||||
return {
|
||||
"success": False,
|
||||
"error": "old_string not found in the file.",
|
||||
"file_preview": preview,
|
||||
}
|
||||
|
||||
if count > 1 and not replace_all:
|
||||
return {
|
||||
"success": False,
|
||||
"error": (
|
||||
f"old_string matched {count} times. Provide more surrounding context "
|
||||
f"to make the match unique, or set replace_all=true to replace all occurrences."
|
||||
),
|
||||
"match_count": count,
|
||||
}
|
||||
|
||||
new_content = content.replace(old_string, new_string) if replace_all else content.replace(old_string, new_string, 1)
|
||||
|
||||
# If patching SKILL.md, validate frontmatter is still intact
|
||||
if not file_path:
|
||||
err = _validate_frontmatter(new_content)
|
||||
if err:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Patch would break SKILL.md structure: {err}",
|
||||
}
|
||||
|
||||
original_content = content # for rollback
|
||||
_atomic_write_text(target, new_content)
|
||||
|
||||
# Security scan — roll back on block
|
||||
scan_error = _security_scan_skill(skill_dir)
|
||||
if scan_error:
|
||||
_atomic_write_text(target, original_content)
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
replacements = count if replace_all else 1
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Patched {'SKILL.md' if not file_path else file_path} in skill '{name}' ({replacements} replacement{'s' if replacements > 1 else ''}).",
|
||||
}
|
||||
|
||||
|
||||
def _delete_skill(name: str) -> Dict[str, Any]:
|
||||
"""Delete a skill."""
|
||||
existing = _find_skill(name)
|
||||
if not existing:
|
||||
return {"success": False, "error": f"Skill '{name}' not found."}
|
||||
|
||||
skill_dir = existing["path"]
|
||||
shutil.rmtree(skill_dir)
|
||||
|
||||
# Clean up empty category directories (don't remove SKILLS_DIR itself)
|
||||
parent = skill_dir.parent
|
||||
if parent != SKILLS_DIR and parent.exists() and not any(parent.iterdir()):
|
||||
parent.rmdir()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Skill '{name}' deleted.",
|
||||
}
|
||||
|
||||
|
||||
def _write_file(name: str, file_path: str, file_content: str) -> Dict[str, Any]:
|
||||
"""Add or overwrite a supporting file within any skill directory."""
|
||||
err = _validate_file_path(file_path)
|
||||
if err:
|
||||
return {"success": False, "error": err}
|
||||
|
||||
if not file_content and file_content != "":
|
||||
return {"success": False, "error": "file_content is required."}
|
||||
|
||||
existing = _find_skill(name)
|
||||
if not existing:
|
||||
return {"success": False, "error": f"Skill '{name}' not found. Create it first with action='create'."}
|
||||
|
||||
target = existing["path"] / file_path
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Back up for rollback
|
||||
original_content = target.read_text(encoding="utf-8") if target.exists() else None
|
||||
_atomic_write_text(target, file_content)
|
||||
|
||||
# Security scan — roll back on block
|
||||
scan_error = _security_scan_skill(existing["path"])
|
||||
if scan_error:
|
||||
if original_content is not None:
|
||||
_atomic_write_text(target, original_content)
|
||||
else:
|
||||
target.unlink(missing_ok=True)
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"File '{file_path}' written to skill '{name}'.",
|
||||
"path": str(target),
|
||||
}
|
||||
|
||||
|
||||
def _remove_file(name: str, file_path: str) -> Dict[str, Any]:
|
||||
"""Remove a supporting file from any skill directory."""
|
||||
err = _validate_file_path(file_path)
|
||||
if err:
|
||||
return {"success": False, "error": err}
|
||||
|
||||
existing = _find_skill(name)
|
||||
if not existing:
|
||||
return {"success": False, "error": f"Skill '{name}' not found."}
|
||||
skill_dir = existing["path"]
|
||||
|
||||
target = skill_dir / file_path
|
||||
if not target.exists():
|
||||
# List what's actually there for the model to see
|
||||
available = []
|
||||
for subdir in ALLOWED_SUBDIRS:
|
||||
d = skill_dir / subdir
|
||||
if d.exists():
|
||||
for f in d.rglob("*"):
|
||||
if f.is_file():
|
||||
available.append(str(f.relative_to(skill_dir)))
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"File '{file_path}' not found in skill '{name}'.",
|
||||
"available_files": available if available else None,
|
||||
}
|
||||
|
||||
target.unlink()
|
||||
|
||||
# Clean up empty subdirectories
|
||||
parent = target.parent
|
||||
if parent != skill_dir and parent.exists() and not any(parent.iterdir()):
|
||||
parent.rmdir()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"File '{file_path}' removed from skill '{name}'.",
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main entry point
|
||||
# =============================================================================
|
||||
|
||||
def skill_manage(
|
||||
action: str,
|
||||
name: str,
|
||||
content: str = None,
|
||||
category: str = None,
|
||||
file_path: str = None,
|
||||
file_content: str = None,
|
||||
old_string: str = None,
|
||||
new_string: str = None,
|
||||
replace_all: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Manage user-created skills. Dispatches to the appropriate action handler.
|
||||
|
||||
Returns JSON string with results.
|
||||
"""
|
||||
if action == "create":
|
||||
if not content:
|
||||
return json.dumps({"success": False, "error": "content is required for 'create'. Provide the full SKILL.md text (frontmatter + body)."}, ensure_ascii=False)
|
||||
result = _create_skill(name, content, category)
|
||||
|
||||
elif action == "edit":
|
||||
if not content:
|
||||
return json.dumps({"success": False, "error": "content is required for 'edit'. Provide the full updated SKILL.md text."}, ensure_ascii=False)
|
||||
result = _edit_skill(name, content)
|
||||
|
||||
elif action == "patch":
|
||||
if not old_string:
|
||||
return json.dumps({"success": False, "error": "old_string is required for 'patch'. Provide the text to find."}, ensure_ascii=False)
|
||||
if new_string is None:
|
||||
return json.dumps({"success": False, "error": "new_string is required for 'patch'. Use empty string to delete matched text."}, ensure_ascii=False)
|
||||
result = _patch_skill(name, old_string, new_string, file_path, replace_all)
|
||||
|
||||
elif action == "delete":
|
||||
result = _delete_skill(name)
|
||||
|
||||
elif action == "write_file":
|
||||
if not file_path:
|
||||
return json.dumps({"success": False, "error": "file_path is required for 'write_file'. Example: 'references/api-guide.md'"}, ensure_ascii=False)
|
||||
if file_content is None:
|
||||
return json.dumps({"success": False, "error": "file_content is required for 'write_file'."}, ensure_ascii=False)
|
||||
result = _write_file(name, file_path, file_content)
|
||||
|
||||
elif action == "remove_file":
|
||||
if not file_path:
|
||||
return json.dumps({"success": False, "error": "file_path is required for 'remove_file'."}, ensure_ascii=False)
|
||||
result = _remove_file(name, file_path)
|
||||
|
||||
else:
|
||||
result = {"success": False, "error": f"Unknown action '{action}'. Use: create, edit, patch, delete, write_file, remove_file"}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OpenAI Function-Calling Schema
|
||||
# =============================================================================
|
||||
|
||||
SKILL_MANAGE_SCHEMA = {
|
||||
"name": "skill_manage",
|
||||
"description": (
|
||||
"Manage skills (create, update, delete). Skills are your procedural "
|
||||
"memory — reusable approaches for recurring task types. "
|
||||
"New skills go to ~/.hermes/skills/; existing skills can be modified wherever they live.\n\n"
|
||||
"Actions: create (full SKILL.md + optional category), "
|
||||
"patch (old_string/new_string — preferred for fixes), "
|
||||
"edit (full SKILL.md rewrite — major overhauls only), "
|
||||
"delete, write_file, remove_file.\n\n"
|
||||
"Create when: complex task succeeded (5+ calls), errors overcome, "
|
||||
"user-corrected approach worked, non-trivial workflow discovered, "
|
||||
"or user asks you to remember a procedure.\n"
|
||||
"Update when: instructions stale/wrong, OS-specific failures, "
|
||||
"missing steps or pitfalls found during use. "
|
||||
"If you used a skill and hit issues not covered by it, patch it immediately.\n\n"
|
||||
"After difficult/iterative tasks, offer to save as a skill. "
|
||||
"Skip for simple one-offs. Confirm with user before creating/deleting.\n\n"
|
||||
"Good skills: trigger conditions, numbered steps with exact commands, "
|
||||
"pitfalls section, verification steps. Use skill_view() to see format examples."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["create", "patch", "edit", "delete", "write_file", "remove_file"],
|
||||
"description": "The action to perform."
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Skill name (lowercase, hyphens/underscores, max 64 chars). "
|
||||
"Must match an existing skill for patch/edit/delete/write_file/remove_file."
|
||||
)
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Full SKILL.md content (YAML frontmatter + markdown body). "
|
||||
"Required for 'create' and 'edit'. For 'edit', read the skill "
|
||||
"first with skill_view() and provide the complete updated text."
|
||||
)
|
||||
},
|
||||
"old_string": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Text to find in the file (required for 'patch'). Must be unique "
|
||||
"unless replace_all=true. Include enough surrounding context to "
|
||||
"ensure uniqueness."
|
||||
)
|
||||
},
|
||||
"new_string": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Replacement text (required for 'patch'). Can be empty string "
|
||||
"to delete the matched text."
|
||||
)
|
||||
},
|
||||
"replace_all": {
|
||||
"type": "boolean",
|
||||
"description": "For 'patch': replace all occurrences instead of requiring a unique match (default: false)."
|
||||
},
|
||||
"category": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional category/domain for organizing the skill (e.g., 'devops', "
|
||||
"'data-science', 'mlops'). Creates a subdirectory grouping. "
|
||||
"Only used with 'create'."
|
||||
)
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Path to a supporting file within the skill directory. "
|
||||
"For 'write_file'/'remove_file': required, must be under references/, "
|
||||
"templates/, scripts/, or assets/. "
|
||||
"For 'patch': optional, defaults to SKILL.md if omitted."
|
||||
)
|
||||
},
|
||||
"file_content": {
|
||||
"type": "string",
|
||||
"description": "Content for the file. Required for 'write_file'."
|
||||
},
|
||||
},
|
||||
"required": ["action", "name"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# --- Registry ---
|
||||
from tools.registry import registry
|
||||
|
||||
registry.register(
|
||||
name="skill_manage",
|
||||
toolset="skills",
|
||||
schema=SKILL_MANAGE_SCHEMA,
|
||||
handler=lambda args, **kw: skill_manage(
|
||||
action=args.get("action", ""),
|
||||
name=args.get("name", ""),
|
||||
content=args.get("content"),
|
||||
category=args.get("category"),
|
||||
file_path=args.get("file_path"),
|
||||
file_content=args.get("file_content"),
|
||||
old_string=args.get("old_string"),
|
||||
new_string=args.get("new_string"),
|
||||
replace_all=args.get("replace_all", False)),
|
||||
emoji="📝",
|
||||
)
|
||||
1084
hermes_code/tools/skills_guard.py
Normal file
1084
hermes_code/tools/skills_guard.py
Normal file
File diff suppressed because it is too large
Load diff
2488
hermes_code/tools/skills_hub.py
Normal file
2488
hermes_code/tools/skills_hub.py
Normal file
File diff suppressed because it is too large
Load diff
287
hermes_code/tools/skills_sync.py
Normal file
287
hermes_code/tools/skills_sync.py
Normal file
|
|
@ -0,0 +1,287 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Skills Sync -- Manifest-based seeding and updating of bundled skills.
|
||||
|
||||
Copies bundled skills from the repo's skills/ directory into ~/.hermes/skills/
|
||||
and uses a manifest to track which skills have been synced and their origin hash.
|
||||
|
||||
Manifest format (v2): each line is "skill_name:origin_hash" where origin_hash
|
||||
is the MD5 of the bundled skill at the time it was last synced to the user dir.
|
||||
Old v1 manifests (plain names without hashes) are auto-migrated.
|
||||
|
||||
Update logic:
|
||||
- NEW skills (not in manifest): copied to user dir, origin hash recorded.
|
||||
- EXISTING skills (in manifest, present in user dir):
|
||||
* If user copy matches origin hash: user hasn't modified it → safe to
|
||||
update from bundled if bundled changed. New origin hash recorded.
|
||||
* If user copy differs from origin hash: user customized it → SKIP.
|
||||
- DELETED by user (in manifest, absent from user dir): respected, not re-added.
|
||||
- REMOVED from bundled (in manifest, gone from repo): cleaned from manifest.
|
||||
|
||||
The manifest lives at ~/.hermes/skills/.bundled_manifest.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
HERMES_HOME = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
SKILLS_DIR = HERMES_HOME / "skills"
|
||||
MANIFEST_FILE = SKILLS_DIR / ".bundled_manifest"
|
||||
|
||||
|
||||
def _get_bundled_dir() -> Path:
|
||||
"""Locate the bundled skills/ directory in the repo."""
|
||||
return Path(__file__).parent.parent / "skills"
|
||||
|
||||
|
||||
def _read_manifest() -> Dict[str, str]:
|
||||
"""
|
||||
Read the manifest as a dict of {skill_name: origin_hash}.
|
||||
|
||||
Handles both v1 (plain names) and v2 (name:hash) formats.
|
||||
v1 entries get an empty hash string which triggers migration on next sync.
|
||||
"""
|
||||
if not MANIFEST_FILE.exists():
|
||||
return {}
|
||||
try:
|
||||
result = {}
|
||||
for line in MANIFEST_FILE.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
if ":" in line:
|
||||
# v2 format: name:hash
|
||||
name, _, hash_val = line.partition(":")
|
||||
result[name.strip()] = hash_val.strip()
|
||||
else:
|
||||
# v1 format: plain name — empty hash triggers migration
|
||||
result[line] = ""
|
||||
return result
|
||||
except (OSError, IOError):
|
||||
return {}
|
||||
|
||||
|
||||
def _write_manifest(entries: Dict[str, str]):
|
||||
"""Write the manifest file atomically in v2 format (name:hash).
|
||||
|
||||
Uses a temp file + os.replace() to avoid corruption if the process
|
||||
crashes or is interrupted mid-write.
|
||||
"""
|
||||
import tempfile
|
||||
|
||||
MANIFEST_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
data = "\n".join(f"{name}:{hash_val}" for name, hash_val in sorted(entries.items())) + "\n"
|
||||
|
||||
try:
|
||||
fd, tmp_path = tempfile.mkstemp(
|
||||
dir=str(MANIFEST_FILE.parent),
|
||||
prefix=".bundled_manifest_",
|
||||
suffix=".tmp",
|
||||
)
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as f:
|
||||
f.write(data)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
os.replace(tmp_path, MANIFEST_FILE)
|
||||
except BaseException:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.debug("Failed to write skills manifest %s: %s", MANIFEST_FILE, e, exc_info=True)
|
||||
|
||||
|
||||
def _discover_bundled_skills(bundled_dir: Path) -> List[Tuple[str, Path]]:
|
||||
"""
|
||||
Find all SKILL.md files in the bundled directory.
|
||||
Returns list of (skill_name, skill_directory_path) tuples.
|
||||
"""
|
||||
skills = []
|
||||
if not bundled_dir.exists():
|
||||
return skills
|
||||
|
||||
for skill_md in bundled_dir.rglob("SKILL.md"):
|
||||
path_str = str(skill_md)
|
||||
if "/.git/" in path_str or "/.github/" in path_str or "/.hub/" in path_str:
|
||||
continue
|
||||
skill_dir = skill_md.parent
|
||||
skill_name = skill_dir.name
|
||||
skills.append((skill_name, skill_dir))
|
||||
|
||||
return skills
|
||||
|
||||
|
||||
def _compute_relative_dest(skill_dir: Path, bundled_dir: Path) -> Path:
|
||||
"""
|
||||
Compute the destination path in SKILLS_DIR preserving the category structure.
|
||||
e.g., bundled/skills/mlops/axolotl -> ~/.hermes/skills/mlops/axolotl
|
||||
"""
|
||||
rel = skill_dir.relative_to(bundled_dir)
|
||||
return SKILLS_DIR / rel
|
||||
|
||||
|
||||
def _dir_hash(directory: Path) -> str:
|
||||
"""Compute a hash of all file contents in a directory for change detection."""
|
||||
hasher = hashlib.md5()
|
||||
try:
|
||||
for fpath in sorted(directory.rglob("*")):
|
||||
if fpath.is_file():
|
||||
rel = fpath.relative_to(directory)
|
||||
hasher.update(str(rel).encode("utf-8"))
|
||||
hasher.update(fpath.read_bytes())
|
||||
except (OSError, IOError):
|
||||
pass
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
||||
def sync_skills(quiet: bool = False) -> dict:
|
||||
"""
|
||||
Sync bundled skills into ~/.hermes/skills/ using the manifest.
|
||||
|
||||
Returns:
|
||||
dict with keys: copied (list), updated (list), skipped (int),
|
||||
user_modified (list), cleaned (list), total_bundled (int)
|
||||
"""
|
||||
bundled_dir = _get_bundled_dir()
|
||||
if not bundled_dir.exists():
|
||||
return {
|
||||
"copied": [], "updated": [], "skipped": 0,
|
||||
"user_modified": [], "cleaned": [], "total_bundled": 0,
|
||||
}
|
||||
|
||||
SKILLS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
manifest = _read_manifest()
|
||||
bundled_skills = _discover_bundled_skills(bundled_dir)
|
||||
bundled_names = {name for name, _ in bundled_skills}
|
||||
|
||||
copied = []
|
||||
updated = []
|
||||
user_modified = []
|
||||
skipped = 0
|
||||
|
||||
for skill_name, skill_src in bundled_skills:
|
||||
dest = _compute_relative_dest(skill_src, bundled_dir)
|
||||
bundled_hash = _dir_hash(skill_src)
|
||||
|
||||
if skill_name not in manifest:
|
||||
# ── New skill — never offered before ──
|
||||
try:
|
||||
if dest.exists():
|
||||
# User already has a skill with the same name — don't overwrite
|
||||
skipped += 1
|
||||
manifest[skill_name] = bundled_hash
|
||||
else:
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copytree(skill_src, dest)
|
||||
copied.append(skill_name)
|
||||
manifest[skill_name] = bundled_hash
|
||||
if not quiet:
|
||||
print(f" + {skill_name}")
|
||||
except (OSError, IOError) as e:
|
||||
if not quiet:
|
||||
print(f" ! Failed to copy {skill_name}: {e}")
|
||||
# Do NOT add to manifest — next sync should retry
|
||||
|
||||
elif dest.exists():
|
||||
# ── Existing skill — in manifest AND on disk ──
|
||||
origin_hash = manifest.get(skill_name, "")
|
||||
user_hash = _dir_hash(dest)
|
||||
|
||||
if not origin_hash:
|
||||
# v1 migration: no origin hash recorded. Set baseline from
|
||||
# user's current copy so future syncs can detect modifications.
|
||||
manifest[skill_name] = user_hash
|
||||
if user_hash == bundled_hash:
|
||||
skipped += 1 # already in sync
|
||||
else:
|
||||
# Can't tell if user modified or bundled changed — be safe
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
if user_hash != origin_hash:
|
||||
# User modified this skill — don't overwrite their changes
|
||||
user_modified.append(skill_name)
|
||||
if not quiet:
|
||||
print(f" ~ {skill_name} (user-modified, skipping)")
|
||||
continue
|
||||
|
||||
# User copy matches origin — check if bundled has a newer version
|
||||
if bundled_hash != origin_hash:
|
||||
try:
|
||||
# Move old copy to a backup so we can restore on failure
|
||||
backup = dest.with_suffix(".bak")
|
||||
shutil.move(str(dest), str(backup))
|
||||
try:
|
||||
shutil.copytree(skill_src, dest)
|
||||
manifest[skill_name] = bundled_hash
|
||||
updated.append(skill_name)
|
||||
if not quiet:
|
||||
print(f" ↑ {skill_name} (updated)")
|
||||
# Remove backup after successful copy
|
||||
shutil.rmtree(backup, ignore_errors=True)
|
||||
except (OSError, IOError):
|
||||
# Restore from backup
|
||||
if backup.exists() and not dest.exists():
|
||||
shutil.move(str(backup), str(dest))
|
||||
raise
|
||||
except (OSError, IOError) as e:
|
||||
if not quiet:
|
||||
print(f" ! Failed to update {skill_name}: {e}")
|
||||
else:
|
||||
skipped += 1 # bundled unchanged, user unchanged
|
||||
|
||||
else:
|
||||
# ── In manifest but not on disk — user deleted it ──
|
||||
skipped += 1
|
||||
|
||||
# Clean stale manifest entries (skills removed from bundled dir)
|
||||
cleaned = sorted(set(manifest.keys()) - bundled_names)
|
||||
for name in cleaned:
|
||||
del manifest[name]
|
||||
|
||||
# Also copy DESCRIPTION.md files for categories (if not already present)
|
||||
for desc_md in bundled_dir.rglob("DESCRIPTION.md"):
|
||||
rel = desc_md.relative_to(bundled_dir)
|
||||
dest_desc = SKILLS_DIR / rel
|
||||
if not dest_desc.exists():
|
||||
try:
|
||||
dest_desc.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(desc_md, dest_desc)
|
||||
except (OSError, IOError) as e:
|
||||
logger.debug("Could not copy %s: %s", desc_md, e)
|
||||
|
||||
_write_manifest(manifest)
|
||||
|
||||
return {
|
||||
"copied": copied,
|
||||
"updated": updated,
|
||||
"skipped": skipped,
|
||||
"user_modified": user_modified,
|
||||
"cleaned": cleaned,
|
||||
"total_bundled": len(bundled_skills),
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Syncing bundled skills into ~/.hermes/skills/ ...")
|
||||
result = sync_skills(quiet=False)
|
||||
parts = [
|
||||
f"{len(result['copied'])} new",
|
||||
f"{len(result['updated'])} updated",
|
||||
f"{result['skipped']} unchanged",
|
||||
]
|
||||
if result["user_modified"]:
|
||||
parts.append(f"{len(result['user_modified'])} user-modified (kept)")
|
||||
if result["cleaned"]:
|
||||
parts.append(f"{len(result['cleaned'])} cleaned from manifest")
|
||||
print(f"\nDone: {', '.join(parts)}. {result['total_bundled']} total bundled.")
|
||||
1340
hermes_code/tools/skills_tool.py
Normal file
1340
hermes_code/tools/skills_tool.py
Normal file
File diff suppressed because it is too large
Load diff
1356
hermes_code/tools/terminal_tool.py
Normal file
1356
hermes_code/tools/terminal_tool.py
Normal file
File diff suppressed because it is too large
Load diff
674
hermes_code/tools/tirith_security.py
Normal file
674
hermes_code/tools/tirith_security.py
Normal file
|
|
@ -0,0 +1,674 @@
|
|||
"""Tirith pre-exec security scanning wrapper.
|
||||
|
||||
Runs the tirith binary as a subprocess to scan commands for content-level
|
||||
threats (homograph URLs, pipe-to-interpreter, terminal injection, etc.).
|
||||
|
||||
Exit code is the verdict source of truth:
|
||||
0 = allow, 1 = block, 2 = warn
|
||||
|
||||
JSON stdout enriches findings/summary but never overrides the verdict.
|
||||
Operational failures (spawn error, timeout, unknown exit code) respect
|
||||
the fail_open config setting. Programming errors propagate.
|
||||
|
||||
Auto-install: if tirith is not found on PATH or at the configured path,
|
||||
it is automatically downloaded from GitHub releases to $HERMES_HOME/bin/tirith.
|
||||
The download always verifies SHA-256 checksums. When cosign is available on
|
||||
PATH, provenance verification (GitHub Actions workflow signature) is also
|
||||
performed. If cosign is not installed, the download proceeds with SHA-256
|
||||
verification only — still secure via HTTPS + checksum, just without supply
|
||||
chain provenance proof. Installation runs in a background thread so startup
|
||||
never blocks.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import stat
|
||||
import subprocess
|
||||
import tarfile
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import urllib.request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_REPO = "sheeki03/tirith"
|
||||
|
||||
# Cosign provenance verification — pinned to the specific release workflow
|
||||
_COSIGN_IDENTITY_REGEXP = f"^https://github.com/{_REPO}/\\.github/workflows/release\\.yml@refs/tags/v"
|
||||
_COSIGN_ISSUER = "https://token.actions.githubusercontent.com"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _env_bool(key: str, default: bool) -> bool:
|
||||
val = os.getenv(key)
|
||||
if val is None:
|
||||
return default
|
||||
return val.lower() in ("1", "true", "yes")
|
||||
|
||||
|
||||
def _env_int(key: str, default: int) -> int:
|
||||
val = os.getenv(key)
|
||||
if val is None:
|
||||
return default
|
||||
try:
|
||||
return int(val)
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
|
||||
def _load_security_config() -> dict:
|
||||
"""Load security settings from config.yaml, with env var overrides."""
|
||||
defaults = {
|
||||
"tirith_enabled": True,
|
||||
"tirith_path": "tirith",
|
||||
"tirith_timeout": 5,
|
||||
"tirith_fail_open": True,
|
||||
}
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
cfg = load_config().get("security", {}) or {}
|
||||
except Exception:
|
||||
cfg = {}
|
||||
|
||||
return {
|
||||
"tirith_enabled": _env_bool("TIRITH_ENABLED", cfg.get("tirith_enabled", defaults["tirith_enabled"])),
|
||||
"tirith_path": os.getenv("TIRITH_BIN", cfg.get("tirith_path", defaults["tirith_path"])),
|
||||
"tirith_timeout": _env_int("TIRITH_TIMEOUT", cfg.get("tirith_timeout", defaults["tirith_timeout"])),
|
||||
"tirith_fail_open": _env_bool("TIRITH_FAIL_OPEN", cfg.get("tirith_fail_open", defaults["tirith_fail_open"])),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auto-install
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Cached path after first resolution (avoids repeated shutil.which per command).
|
||||
# _INSTALL_FAILED means "we tried and failed" — prevents retry on every command.
|
||||
_resolved_path: str | None | bool = None
|
||||
_INSTALL_FAILED = False # sentinel: distinct from "not yet tried"
|
||||
_install_failure_reason: str = "" # reason tag when _resolved_path is _INSTALL_FAILED
|
||||
|
||||
# Background install thread coordination
|
||||
_install_lock = threading.Lock()
|
||||
_install_thread: threading.Thread | None = None
|
||||
|
||||
# Disk-persistent failure marker — avoids retry across process restarts
|
||||
_MARKER_TTL = 86400 # 24 hours
|
||||
|
||||
|
||||
def _get_hermes_home() -> str:
|
||||
"""Return the Hermes home directory, respecting HERMES_HOME env var.
|
||||
|
||||
Matches the convention used throughout the codebase (hermes_cli.config,
|
||||
cli.py, gateway/run.py, etc.) so tirith state stays inside the active
|
||||
profile and tests get automatic isolation via conftest's HERMES_HOME
|
||||
monkeypatch.
|
||||
"""
|
||||
return os.getenv("HERMES_HOME") or os.path.join(os.path.expanduser("~"), ".hermes")
|
||||
|
||||
|
||||
def _failure_marker_path() -> str:
|
||||
"""Return the path to the install-failure marker file."""
|
||||
return os.path.join(_get_hermes_home(), ".tirith-install-failed")
|
||||
|
||||
|
||||
def _read_failure_reason() -> str | None:
|
||||
"""Read the failure reason from the disk marker.
|
||||
|
||||
Returns the reason string, or None if the marker doesn't exist or is
|
||||
older than _MARKER_TTL.
|
||||
"""
|
||||
try:
|
||||
p = _failure_marker_path()
|
||||
mtime = os.path.getmtime(p)
|
||||
if (time.time() - mtime) >= _MARKER_TTL:
|
||||
return None
|
||||
with open(p, "r") as f:
|
||||
return f.read().strip()
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
|
||||
def _is_install_failed_on_disk() -> bool:
|
||||
"""Check if a recent install failure was persisted to disk.
|
||||
|
||||
Returns False (allowing retry) when:
|
||||
- No marker exists
|
||||
- Marker is older than _MARKER_TTL (24h)
|
||||
- Marker reason is 'cosign_missing' and cosign is now on PATH
|
||||
"""
|
||||
reason = _read_failure_reason()
|
||||
if reason is None:
|
||||
return False
|
||||
if reason == "cosign_missing" and shutil.which("cosign"):
|
||||
_clear_install_failed()
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _mark_install_failed(reason: str = ""):
|
||||
"""Persist install failure to disk to avoid retry on next process.
|
||||
|
||||
Args:
|
||||
reason: Short tag identifying the failure cause. Use "cosign_missing"
|
||||
when cosign is not on PATH so the marker can be auto-cleared
|
||||
once cosign becomes available.
|
||||
"""
|
||||
try:
|
||||
p = _failure_marker_path()
|
||||
os.makedirs(os.path.dirname(p), exist_ok=True)
|
||||
with open(p, "w") as f:
|
||||
f.write(reason)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def _clear_install_failed():
|
||||
"""Remove the failure marker after successful install."""
|
||||
try:
|
||||
os.unlink(_failure_marker_path())
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def _hermes_bin_dir() -> str:
|
||||
"""Return $HERMES_HOME/bin, creating it if needed."""
|
||||
d = os.path.join(_get_hermes_home(), "bin")
|
||||
os.makedirs(d, exist_ok=True)
|
||||
return d
|
||||
|
||||
|
||||
def _detect_target() -> str | None:
|
||||
"""Return the Rust target triple for the current platform, or None."""
|
||||
system = platform.system()
|
||||
machine = platform.machine().lower()
|
||||
|
||||
if system == "Darwin":
|
||||
plat = "apple-darwin"
|
||||
elif system == "Linux":
|
||||
plat = "unknown-linux-gnu"
|
||||
else:
|
||||
return None
|
||||
|
||||
if machine in ("x86_64", "amd64"):
|
||||
arch = "x86_64"
|
||||
elif machine in ("aarch64", "arm64"):
|
||||
arch = "aarch64"
|
||||
else:
|
||||
return None
|
||||
|
||||
return f"{arch}-{plat}"
|
||||
|
||||
|
||||
def _download_file(url: str, dest: str, timeout: int = 10):
|
||||
"""Download a URL to a local file."""
|
||||
req = urllib.request.Request(url)
|
||||
token = os.getenv("GITHUB_TOKEN")
|
||||
if token:
|
||||
req.add_header("Authorization", f"token {token}")
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp, open(dest, "wb") as f:
|
||||
shutil.copyfileobj(resp, f)
|
||||
|
||||
|
||||
def _verify_cosign(checksums_path: str, sig_path: str, cert_path: str) -> bool | None:
|
||||
"""Verify cosign provenance signature on checksums.txt.
|
||||
|
||||
Returns:
|
||||
True — cosign verified successfully
|
||||
False — cosign found but verification failed
|
||||
None — cosign not available (not on PATH, or execution failed)
|
||||
|
||||
The caller treats both False and None as "abort auto-install" — only
|
||||
True allows the install to proceed.
|
||||
"""
|
||||
cosign = shutil.which("cosign")
|
||||
if not cosign:
|
||||
logger.info("cosign not found on PATH")
|
||||
return None
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[cosign, "verify-blob",
|
||||
"--certificate", cert_path,
|
||||
"--signature", sig_path,
|
||||
"--certificate-identity-regexp", _COSIGN_IDENTITY_REGEXP,
|
||||
"--certificate-oidc-issuer", _COSIGN_ISSUER,
|
||||
checksums_path],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=15,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
logger.info("cosign provenance verification passed")
|
||||
return True
|
||||
else:
|
||||
logger.warning("cosign verification failed (exit %d): %s",
|
||||
result.returncode, result.stderr.strip())
|
||||
return False
|
||||
except (OSError, subprocess.TimeoutExpired) as exc:
|
||||
logger.warning("cosign execution failed: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
def _verify_checksum(archive_path: str, checksums_path: str, archive_name: str) -> bool:
|
||||
"""Verify SHA-256 of the archive against checksums.txt."""
|
||||
expected = None
|
||||
with open(checksums_path) as f:
|
||||
for line in f:
|
||||
# Format: "<hash> <filename>"
|
||||
parts = line.strip().split(" ", 1)
|
||||
if len(parts) == 2 and parts[1] == archive_name:
|
||||
expected = parts[0]
|
||||
break
|
||||
if not expected:
|
||||
logger.warning("No checksum entry for %s", archive_name)
|
||||
return False
|
||||
|
||||
sha = hashlib.sha256()
|
||||
with open(archive_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(8192), b""):
|
||||
sha.update(chunk)
|
||||
actual = sha.hexdigest()
|
||||
if actual != expected:
|
||||
logger.warning("Checksum mismatch: expected %s, got %s", expected, actual)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _install_tirith(*, log_failures: bool = True) -> tuple[str | None, str]:
|
||||
"""Download and install tirith to $HERMES_HOME/bin/tirith.
|
||||
|
||||
Verifies provenance via cosign and SHA-256 checksum.
|
||||
Returns (installed_path, failure_reason). On success failure_reason is "".
|
||||
failure_reason is a short tag used by the disk marker to decide if the
|
||||
failure is retryable (e.g. "cosign_missing" clears when cosign appears).
|
||||
"""
|
||||
log = logger.warning if log_failures else logger.debug
|
||||
|
||||
target = _detect_target()
|
||||
if not target:
|
||||
logger.info("tirith auto-install: unsupported platform %s/%s",
|
||||
platform.system(), platform.machine())
|
||||
return None, "unsupported_platform"
|
||||
|
||||
archive_name = f"tirith-{target}.tar.gz"
|
||||
base_url = f"https://github.com/{_REPO}/releases/latest/download"
|
||||
|
||||
tmpdir = tempfile.mkdtemp(prefix="tirith-install-")
|
||||
try:
|
||||
archive_path = os.path.join(tmpdir, archive_name)
|
||||
checksums_path = os.path.join(tmpdir, "checksums.txt")
|
||||
sig_path = os.path.join(tmpdir, "checksums.txt.sig")
|
||||
cert_path = os.path.join(tmpdir, "checksums.txt.pem")
|
||||
|
||||
logger.info("tirith not found — downloading latest release for %s...", target)
|
||||
|
||||
try:
|
||||
_download_file(f"{base_url}/{archive_name}", archive_path)
|
||||
_download_file(f"{base_url}/checksums.txt", checksums_path)
|
||||
except Exception as exc:
|
||||
log("tirith download failed: %s", exc)
|
||||
return None, "download_failed"
|
||||
|
||||
# Cosign provenance verification — preferred but not mandatory.
|
||||
# When cosign is available, we verify that the release was produced
|
||||
# by the expected GitHub Actions workflow (full supply chain proof).
|
||||
# Without cosign, SHA-256 checksum + HTTPS still provides integrity
|
||||
# and transport-level authenticity.
|
||||
cosign_verified = False
|
||||
if shutil.which("cosign"):
|
||||
try:
|
||||
_download_file(f"{base_url}/checksums.txt.sig", sig_path)
|
||||
_download_file(f"{base_url}/checksums.txt.pem", cert_path)
|
||||
except Exception as exc:
|
||||
logger.info("cosign artifacts unavailable (%s), proceeding with SHA-256 only", exc)
|
||||
else:
|
||||
cosign_result = _verify_cosign(checksums_path, sig_path, cert_path)
|
||||
if cosign_result is True:
|
||||
cosign_verified = True
|
||||
elif cosign_result is False:
|
||||
# Verification explicitly rejected — abort, the release
|
||||
# may have been tampered with.
|
||||
log("tirith install aborted: cosign provenance verification failed")
|
||||
return None, "cosign_verification_failed"
|
||||
else:
|
||||
# None = execution failure (timeout/OSError) — proceed
|
||||
# with SHA-256 only since cosign itself is broken.
|
||||
logger.info("cosign execution failed, proceeding with SHA-256 only")
|
||||
else:
|
||||
logger.info("cosign not on PATH — installing tirith with SHA-256 verification only "
|
||||
"(install cosign for full supply chain verification)")
|
||||
|
||||
if not _verify_checksum(archive_path, checksums_path, archive_name):
|
||||
return None, "checksum_failed"
|
||||
|
||||
with tarfile.open(archive_path, "r:gz") as tar:
|
||||
# Extract only the tirith binary (safety: reject paths with ..)
|
||||
for member in tar.getmembers():
|
||||
if member.name == "tirith" or member.name.endswith("/tirith"):
|
||||
if ".." in member.name:
|
||||
continue
|
||||
member.name = "tirith"
|
||||
tar.extract(member, tmpdir)
|
||||
break
|
||||
else:
|
||||
log("tirith binary not found in archive")
|
||||
return None, "binary_not_in_archive"
|
||||
|
||||
src = os.path.join(tmpdir, "tirith")
|
||||
dest = os.path.join(_hermes_bin_dir(), "tirith")
|
||||
shutil.move(src, dest)
|
||||
os.chmod(dest, os.stat(dest).st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
|
||||
|
||||
verification = "cosign + SHA-256" if cosign_verified else "SHA-256 only"
|
||||
logger.info("tirith installed to %s (%s)", dest, verification)
|
||||
return dest, ""
|
||||
|
||||
finally:
|
||||
shutil.rmtree(tmpdir, ignore_errors=True)
|
||||
|
||||
|
||||
def _is_explicit_path(configured_path: str) -> bool:
|
||||
"""Return True if the user explicitly configured a non-default tirith path."""
|
||||
return configured_path != "tirith"
|
||||
|
||||
|
||||
def _resolve_tirith_path(configured_path: str) -> str:
|
||||
"""Resolve the tirith binary path, auto-installing if necessary.
|
||||
|
||||
If the user explicitly set a path (anything other than the bare "tirith"
|
||||
default), that path is authoritative — we never fall through to
|
||||
auto-download a different binary.
|
||||
|
||||
For the default "tirith":
|
||||
1. PATH lookup via shutil.which
|
||||
2. $HERMES_HOME/bin/tirith (previously auto-installed)
|
||||
3. Auto-install from GitHub releases → $HERMES_HOME/bin/tirith
|
||||
|
||||
Failed installs are cached for the process lifetime (and persisted to
|
||||
disk for 24h) to avoid repeated network attempts.
|
||||
"""
|
||||
global _resolved_path, _install_failure_reason
|
||||
|
||||
# Fast path: successfully resolved on a previous call.
|
||||
if _resolved_path is not None and _resolved_path is not _INSTALL_FAILED:
|
||||
return _resolved_path
|
||||
|
||||
expanded = os.path.expanduser(configured_path)
|
||||
explicit = _is_explicit_path(configured_path)
|
||||
install_failed = _resolved_path is _INSTALL_FAILED
|
||||
|
||||
# Explicit path: check it and stop. Never auto-download a replacement.
|
||||
if explicit:
|
||||
if os.path.isfile(expanded) and os.access(expanded, os.X_OK):
|
||||
_resolved_path = expanded
|
||||
return expanded
|
||||
# Also try shutil.which in case it's a bare name on PATH
|
||||
found = shutil.which(expanded)
|
||||
if found:
|
||||
_resolved_path = found
|
||||
return found
|
||||
logger.warning("Configured tirith path %r not found; scanning disabled", configured_path)
|
||||
_resolved_path = _INSTALL_FAILED
|
||||
_install_failure_reason = "explicit_path_missing"
|
||||
return expanded
|
||||
|
||||
# Default "tirith" — always re-run cheap local checks so a manual
|
||||
# install is picked up even after a previous network failure (P2 fix:
|
||||
# long-lived gateway/CLI recovers without restart).
|
||||
found = shutil.which("tirith")
|
||||
if found:
|
||||
_resolved_path = found
|
||||
_install_failure_reason = ""
|
||||
_clear_install_failed()
|
||||
return found
|
||||
|
||||
hermes_bin = os.path.join(_hermes_bin_dir(), "tirith")
|
||||
if os.path.isfile(hermes_bin) and os.access(hermes_bin, os.X_OK):
|
||||
_resolved_path = hermes_bin
|
||||
_install_failure_reason = ""
|
||||
_clear_install_failed()
|
||||
return hermes_bin
|
||||
|
||||
# Local checks failed. If a previous install attempt already failed,
|
||||
# skip the network retry — UNLESS the failure was "cosign_missing" and
|
||||
# cosign is now available (retryable cause resolved in-process).
|
||||
if install_failed:
|
||||
if _install_failure_reason == "cosign_missing" and shutil.which("cosign"):
|
||||
# Retryable cause resolved — clear sentinel and fall through to retry
|
||||
_resolved_path = None
|
||||
_install_failure_reason = ""
|
||||
_clear_install_failed()
|
||||
install_failed = False
|
||||
else:
|
||||
return expanded
|
||||
|
||||
# If a background install thread is running, don't start a parallel one —
|
||||
# return the configured path; the OSError handler in check_command_security
|
||||
# will apply fail_open until the thread finishes.
|
||||
if _install_thread is not None and _install_thread.is_alive():
|
||||
return expanded
|
||||
|
||||
# Check disk failure marker before attempting network download.
|
||||
# Preserve the marker's real reason so in-memory retry logic can
|
||||
# detect retryable causes (e.g. cosign_missing) without restart.
|
||||
disk_reason = _read_failure_reason()
|
||||
if disk_reason is not None and _is_install_failed_on_disk():
|
||||
_resolved_path = _INSTALL_FAILED
|
||||
_install_failure_reason = disk_reason
|
||||
return expanded
|
||||
|
||||
installed, reason = _install_tirith()
|
||||
if installed:
|
||||
_resolved_path = installed
|
||||
_install_failure_reason = ""
|
||||
_clear_install_failed()
|
||||
return installed
|
||||
|
||||
# Install failed — cache the miss and persist reason to disk
|
||||
_resolved_path = _INSTALL_FAILED
|
||||
_install_failure_reason = reason
|
||||
_mark_install_failed(reason)
|
||||
return expanded
|
||||
|
||||
|
||||
def _background_install(*, log_failures: bool = True):
|
||||
"""Background thread target: download and install tirith."""
|
||||
global _resolved_path, _install_failure_reason
|
||||
with _install_lock:
|
||||
# Double-check after acquiring lock (another thread may have resolved)
|
||||
if _resolved_path is not None:
|
||||
return
|
||||
|
||||
# Re-check local paths (may have been installed by another process)
|
||||
found = shutil.which("tirith")
|
||||
if found:
|
||||
_resolved_path = found
|
||||
_install_failure_reason = ""
|
||||
return
|
||||
|
||||
hermes_bin = os.path.join(_hermes_bin_dir(), "tirith")
|
||||
if os.path.isfile(hermes_bin) and os.access(hermes_bin, os.X_OK):
|
||||
_resolved_path = hermes_bin
|
||||
_install_failure_reason = ""
|
||||
return
|
||||
|
||||
installed, reason = _install_tirith(log_failures=log_failures)
|
||||
if installed:
|
||||
_resolved_path = installed
|
||||
_install_failure_reason = ""
|
||||
_clear_install_failed()
|
||||
else:
|
||||
_resolved_path = _INSTALL_FAILED
|
||||
_install_failure_reason = reason
|
||||
_mark_install_failed(reason)
|
||||
|
||||
|
||||
def ensure_installed(*, log_failures: bool = True):
|
||||
"""Ensure tirith is available, downloading in background if needed.
|
||||
|
||||
Quick PATH/local checks are synchronous; network download runs in a
|
||||
daemon thread so startup never blocks. Safe to call multiple times.
|
||||
Returns the resolved path immediately if available, or None.
|
||||
"""
|
||||
global _resolved_path, _install_thread, _install_failure_reason
|
||||
|
||||
cfg = _load_security_config()
|
||||
if not cfg["tirith_enabled"]:
|
||||
return None
|
||||
|
||||
# Already resolved from a previous call
|
||||
if _resolved_path is not None and _resolved_path is not _INSTALL_FAILED:
|
||||
path = _resolved_path
|
||||
if os.path.isfile(path) and os.access(path, os.X_OK):
|
||||
return path
|
||||
return None
|
||||
|
||||
configured_path = cfg["tirith_path"]
|
||||
explicit = _is_explicit_path(configured_path)
|
||||
expanded = os.path.expanduser(configured_path)
|
||||
|
||||
# Explicit path: synchronous check only, no download
|
||||
if explicit:
|
||||
if os.path.isfile(expanded) and os.access(expanded, os.X_OK):
|
||||
_resolved_path = expanded
|
||||
return expanded
|
||||
found = shutil.which(expanded)
|
||||
if found:
|
||||
_resolved_path = found
|
||||
return found
|
||||
_resolved_path = _INSTALL_FAILED
|
||||
_install_failure_reason = "explicit_path_missing"
|
||||
return None
|
||||
|
||||
# Default "tirith" — quick local checks first (no network)
|
||||
found = shutil.which("tirith")
|
||||
if found:
|
||||
_resolved_path = found
|
||||
_install_failure_reason = ""
|
||||
_clear_install_failed()
|
||||
return found
|
||||
|
||||
hermes_bin = os.path.join(_hermes_bin_dir(), "tirith")
|
||||
if os.path.isfile(hermes_bin) and os.access(hermes_bin, os.X_OK):
|
||||
_resolved_path = hermes_bin
|
||||
_install_failure_reason = ""
|
||||
_clear_install_failed()
|
||||
return hermes_bin
|
||||
|
||||
# If previously failed in-memory, check if the cause is now resolved
|
||||
if _resolved_path is _INSTALL_FAILED:
|
||||
if _install_failure_reason == "cosign_missing" and shutil.which("cosign"):
|
||||
_resolved_path = None
|
||||
_install_failure_reason = ""
|
||||
_clear_install_failed()
|
||||
else:
|
||||
return None
|
||||
|
||||
# Check disk failure marker (skip network attempt for 24h, unless
|
||||
# the cosign_missing reason was resolved — handled by _is_install_failed_on_disk).
|
||||
# Preserve the marker's real reason for in-memory retry logic.
|
||||
disk_reason = _read_failure_reason()
|
||||
if disk_reason is not None and _is_install_failed_on_disk():
|
||||
_resolved_path = _INSTALL_FAILED
|
||||
_install_failure_reason = disk_reason
|
||||
return None
|
||||
|
||||
# Need to download — launch background thread so startup doesn't block
|
||||
if _install_thread is None or not _install_thread.is_alive():
|
||||
_install_thread = threading.Thread(
|
||||
target=_background_install,
|
||||
kwargs={"log_failures": log_failures},
|
||||
daemon=True,
|
||||
)
|
||||
_install_thread.start()
|
||||
|
||||
return None # Not available yet; commands will fail-open until ready
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_MAX_FINDINGS = 50
|
||||
_MAX_SUMMARY_LEN = 500
|
||||
|
||||
|
||||
def check_command_security(command: str) -> dict:
|
||||
"""Run tirith security scan on a command.
|
||||
|
||||
Exit code determines action (0=allow, 1=block, 2=warn). JSON enriches
|
||||
findings/summary. Spawn failures and timeouts respect fail_open config.
|
||||
Programming errors propagate.
|
||||
|
||||
Returns:
|
||||
{"action": "allow"|"warn"|"block", "findings": [...], "summary": str}
|
||||
"""
|
||||
cfg = _load_security_config()
|
||||
|
||||
if not cfg["tirith_enabled"]:
|
||||
return {"action": "allow", "findings": [], "summary": ""}
|
||||
|
||||
tirith_path = _resolve_tirith_path(cfg["tirith_path"])
|
||||
timeout = cfg["tirith_timeout"]
|
||||
fail_open = cfg["tirith_fail_open"]
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[tirith_path, "check", "--json", "--non-interactive",
|
||||
"--shell", "posix", "--", command],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
except OSError as exc:
|
||||
# Covers FileNotFoundError, PermissionError, exec format error
|
||||
logger.warning("tirith spawn failed: %s", exc)
|
||||
if fail_open:
|
||||
return {"action": "allow", "findings": [], "summary": f"tirith unavailable: {exc}"}
|
||||
return {"action": "block", "findings": [], "summary": f"tirith spawn failed (fail-closed): {exc}"}
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("tirith timed out after %ds", timeout)
|
||||
if fail_open:
|
||||
return {"action": "allow", "findings": [], "summary": f"tirith timed out ({timeout}s)"}
|
||||
return {"action": "block", "findings": [], "summary": f"tirith timed out (fail-closed)"}
|
||||
|
||||
# Map exit code to action
|
||||
exit_code = result.returncode
|
||||
if exit_code == 0:
|
||||
action = "allow"
|
||||
elif exit_code == 1:
|
||||
action = "block"
|
||||
elif exit_code == 2:
|
||||
action = "warn"
|
||||
else:
|
||||
# Unknown exit code — respect fail_open
|
||||
logger.warning("tirith returned unexpected exit code %d", exit_code)
|
||||
if fail_open:
|
||||
return {"action": "allow", "findings": [], "summary": f"tirith exit code {exit_code} (fail-open)"}
|
||||
return {"action": "block", "findings": [], "summary": f"tirith exit code {exit_code} (fail-closed)"}
|
||||
|
||||
# Parse JSON for enrichment (never overrides the exit code verdict)
|
||||
findings = []
|
||||
summary = ""
|
||||
try:
|
||||
data = json.loads(result.stdout) if result.stdout.strip() else {}
|
||||
raw_findings = data.get("findings", [])
|
||||
findings = raw_findings[:_MAX_FINDINGS]
|
||||
summary = (data.get("summary", "") or "")[:_MAX_SUMMARY_LEN]
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
# JSON parse failure degrades findings/summary, not the verdict
|
||||
logger.debug("tirith JSON parse failed, using exit code only")
|
||||
if action == "block":
|
||||
summary = "security issue detected (details unavailable)"
|
||||
elif action == "warn":
|
||||
summary = "security warning detected (details unavailable)"
|
||||
|
||||
return {"action": action, "findings": findings, "summary": summary}
|
||||
268
hermes_code/tools/todo_tool.py
Normal file
268
hermes_code/tools/todo_tool.py
Normal file
|
|
@ -0,0 +1,268 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Todo Tool Module - Planning & Task Management
|
||||
|
||||
Provides an in-memory task list the agent uses to decompose complex tasks,
|
||||
track progress, and maintain focus across long conversations. The state
|
||||
lives on the AIAgent instance (one per session) and is re-injected into
|
||||
the conversation after context compression events.
|
||||
|
||||
Design:
|
||||
- Single `todo` tool: provide `todos` param to write, omit to read
|
||||
- Every call returns the full current list
|
||||
- No system prompt mutation, no tool response modification
|
||||
- Behavioral guidance lives entirely in the tool schema description
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
|
||||
# Valid status values for todo items
|
||||
VALID_STATUSES = {"pending", "in_progress", "completed", "cancelled"}
|
||||
|
||||
|
||||
class TodoStore:
|
||||
"""
|
||||
In-memory todo list. One instance per AIAgent (one per session).
|
||||
|
||||
Items are ordered -- list position is priority. Each item has:
|
||||
- id: unique string identifier (agent-chosen)
|
||||
- content: task description
|
||||
- status: pending | in_progress | completed | cancelled
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._items: List[Dict[str, str]] = []
|
||||
|
||||
def write(self, todos: List[Dict[str, Any]], merge: bool = False) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Write todos. Returns the full current list after writing.
|
||||
|
||||
Args:
|
||||
todos: list of {id, content, status} dicts
|
||||
merge: if False, replace the entire list. If True, update
|
||||
existing items by id and append new ones.
|
||||
"""
|
||||
if not merge:
|
||||
# Replace mode: new list entirely
|
||||
self._items = [self._validate(t) for t in todos]
|
||||
else:
|
||||
# Merge mode: update existing items by id, append new ones
|
||||
existing = {item["id"]: item for item in self._items}
|
||||
for t in todos:
|
||||
item_id = str(t.get("id", "")).strip()
|
||||
if not item_id:
|
||||
continue # Can't merge without an id
|
||||
|
||||
if item_id in existing:
|
||||
# Update only the fields the LLM actually provided
|
||||
if "content" in t and t["content"]:
|
||||
existing[item_id]["content"] = str(t["content"]).strip()
|
||||
if "status" in t and t["status"]:
|
||||
status = str(t["status"]).strip().lower()
|
||||
if status in VALID_STATUSES:
|
||||
existing[item_id]["status"] = status
|
||||
else:
|
||||
# New item -- validate fully and append to end
|
||||
validated = self._validate(t)
|
||||
existing[validated["id"]] = validated
|
||||
self._items.append(validated)
|
||||
# Rebuild _items preserving order for existing items
|
||||
seen = set()
|
||||
rebuilt = []
|
||||
for item in self._items:
|
||||
current = existing.get(item["id"], item)
|
||||
if current["id"] not in seen:
|
||||
rebuilt.append(current)
|
||||
seen.add(current["id"])
|
||||
self._items = rebuilt
|
||||
return self.read()
|
||||
|
||||
def read(self) -> List[Dict[str, str]]:
|
||||
"""Return a copy of the current list."""
|
||||
return [item.copy() for item in self._items]
|
||||
|
||||
def has_items(self) -> bool:
|
||||
"""Check if there are any items in the list."""
|
||||
return len(self._items) > 0
|
||||
|
||||
def format_for_injection(self) -> Optional[str]:
|
||||
"""
|
||||
Render the todo list for post-compression injection.
|
||||
|
||||
Returns a human-readable string to append to the compressed
|
||||
message history, or None if the list is empty.
|
||||
"""
|
||||
if not self._items:
|
||||
return None
|
||||
|
||||
# Status markers for compact display
|
||||
markers = {
|
||||
"completed": "[x]",
|
||||
"in_progress": "[>]",
|
||||
"pending": "[ ]",
|
||||
"cancelled": "[~]",
|
||||
}
|
||||
|
||||
# Only inject pending/in_progress items — completed/cancelled ones
|
||||
# cause the model to re-do finished work after compression.
|
||||
active_items = [
|
||||
item for item in self._items
|
||||
if item["status"] in ("pending", "in_progress")
|
||||
]
|
||||
if not active_items:
|
||||
return None
|
||||
|
||||
lines = ["[Your active task list was preserved across context compression]"]
|
||||
for item in active_items:
|
||||
marker = markers.get(item["status"], "[?]")
|
||||
lines.append(f"- {marker} {item['id']}. {item['content']} ({item['status']})")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def _validate(item: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""
|
||||
Validate and normalize a todo item.
|
||||
|
||||
Ensures required fields exist and status is valid.
|
||||
Returns a clean dict with only {id, content, status}.
|
||||
"""
|
||||
item_id = str(item.get("id", "")).strip()
|
||||
if not item_id:
|
||||
item_id = "?"
|
||||
|
||||
content = str(item.get("content", "")).strip()
|
||||
if not content:
|
||||
content = "(no description)"
|
||||
|
||||
status = str(item.get("status", "pending")).strip().lower()
|
||||
if status not in VALID_STATUSES:
|
||||
status = "pending"
|
||||
|
||||
return {"id": item_id, "content": content, "status": status}
|
||||
|
||||
|
||||
def todo_tool(
|
||||
todos: Optional[List[Dict[str, Any]]] = None,
|
||||
merge: bool = False,
|
||||
store: Optional[TodoStore] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Single entry point for the todo tool. Reads or writes depending on params.
|
||||
|
||||
Args:
|
||||
todos: if provided, write these items. If None, read current list.
|
||||
merge: if True, update by id. If False (default), replace entire list.
|
||||
store: the TodoStore instance from the AIAgent.
|
||||
|
||||
Returns:
|
||||
JSON string with the full current list and summary metadata.
|
||||
"""
|
||||
if store is None:
|
||||
return json.dumps({"error": "TodoStore not initialized"}, ensure_ascii=False)
|
||||
|
||||
if todos is not None:
|
||||
items = store.write(todos, merge)
|
||||
else:
|
||||
items = store.read()
|
||||
|
||||
# Build summary counts
|
||||
pending = sum(1 for i in items if i["status"] == "pending")
|
||||
in_progress = sum(1 for i in items if i["status"] == "in_progress")
|
||||
completed = sum(1 for i in items if i["status"] == "completed")
|
||||
cancelled = sum(1 for i in items if i["status"] == "cancelled")
|
||||
|
||||
return json.dumps({
|
||||
"todos": items,
|
||||
"summary": {
|
||||
"total": len(items),
|
||||
"pending": pending,
|
||||
"in_progress": in_progress,
|
||||
"completed": completed,
|
||||
"cancelled": cancelled,
|
||||
},
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
def check_todo_requirements() -> bool:
|
||||
"""Todo tool has no external requirements -- always available."""
|
||||
return True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OpenAI Function-Calling Schema
|
||||
# =============================================================================
|
||||
# Behavioral guidance is baked into the description so it's part of the
|
||||
# static tool schema (cached, never changes mid-conversation).
|
||||
|
||||
TODO_SCHEMA = {
|
||||
"name": "todo",
|
||||
"description": (
|
||||
"Manage your task list for the current session. Use for complex tasks "
|
||||
"with 3+ steps or when the user provides multiple tasks. "
|
||||
"Call with no parameters to read the current list.\n\n"
|
||||
"Writing:\n"
|
||||
"- Provide 'todos' array to create/update items\n"
|
||||
"- merge=false (default): replace the entire list with a fresh plan\n"
|
||||
"- merge=true: update existing items by id, add any new ones\n\n"
|
||||
"Each item: {id: string, content: string, "
|
||||
"status: pending|in_progress|completed|cancelled}\n"
|
||||
"List order is priority. Only ONE item in_progress at a time.\n"
|
||||
"Mark items completed immediately when done. If something fails, "
|
||||
"cancel it and add a revised item.\n\n"
|
||||
"Always returns the full current list."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"todos": {
|
||||
"type": "array",
|
||||
"description": "Task items to write. Omit to read current list.",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"description": "Unique item identifier"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Task description"
|
||||
},
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["pending", "in_progress", "completed", "cancelled"],
|
||||
"description": "Current status"
|
||||
}
|
||||
},
|
||||
"required": ["id", "content", "status"]
|
||||
}
|
||||
},
|
||||
"merge": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"true: update existing items by id, add new ones. "
|
||||
"false (default): replace the entire list."
|
||||
),
|
||||
"default": False
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# --- Registry ---
|
||||
from tools.registry import registry
|
||||
|
||||
registry.register(
|
||||
name="todo",
|
||||
toolset="todo",
|
||||
schema=TODO_SCHEMA,
|
||||
handler=lambda args, **kw: todo_tool(
|
||||
todos=args.get("todos"), merge=args.get("merge", False), store=kw.get("store")),
|
||||
check_fn=check_todo_requirements,
|
||||
emoji="📋",
|
||||
)
|
||||
554
hermes_code/tools/transcription_tools.py
Normal file
554
hermes_code/tools/transcription_tools.py
Normal file
|
|
@ -0,0 +1,554 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Transcription Tools Module
|
||||
|
||||
Provides speech-to-text transcription with three providers:
|
||||
|
||||
- **local** (default, free) — faster-whisper running locally, no API key needed.
|
||||
Auto-downloads the model (~150 MB for ``base``) on first use.
|
||||
- **groq** (free tier) — Groq Whisper API, requires ``GROQ_API_KEY``.
|
||||
- **openai** (paid) — OpenAI Whisper API, requires ``VOICE_TOOLS_OPENAI_KEY``.
|
||||
|
||||
Used by the messaging gateway to automatically transcribe voice messages
|
||||
sent by users on Telegram, Discord, WhatsApp, Slack, and Signal.
|
||||
|
||||
Supported input formats: mp3, mp4, mpeg, mpga, m4a, wav, webm, ogg
|
||||
|
||||
Usage::
|
||||
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
|
||||
result = transcribe_audio("/path/to/audio.ogg")
|
||||
if result["success"]:
|
||||
print(result["transcript"])
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Optional imports — graceful degradation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import importlib.util as _ilu
|
||||
_HAS_FASTER_WHISPER = _ilu.find_spec("faster_whisper") is not None
|
||||
_HAS_OPENAI = _ilu.find_spec("openai") is not None
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEFAULT_PROVIDER = "local"
|
||||
DEFAULT_LOCAL_MODEL = "base"
|
||||
DEFAULT_LOCAL_STT_LANGUAGE = "en"
|
||||
DEFAULT_STT_MODEL = os.getenv("STT_OPENAI_MODEL", "whisper-1")
|
||||
DEFAULT_GROQ_STT_MODEL = os.getenv("STT_GROQ_MODEL", "whisper-large-v3-turbo")
|
||||
LOCAL_STT_COMMAND_ENV = "HERMES_LOCAL_STT_COMMAND"
|
||||
LOCAL_STT_LANGUAGE_ENV = "HERMES_LOCAL_STT_LANGUAGE"
|
||||
COMMON_LOCAL_BIN_DIRS = ("/opt/homebrew/bin", "/usr/local/bin")
|
||||
|
||||
GROQ_BASE_URL = os.getenv("GROQ_BASE_URL", "https://api.groq.com/openai/v1")
|
||||
OPENAI_BASE_URL = os.getenv("STT_OPENAI_BASE_URL", "https://api.openai.com/v1")
|
||||
|
||||
SUPPORTED_FORMATS = {".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm", ".ogg"}
|
||||
LOCAL_NATIVE_AUDIO_FORMATS = {".wav", ".aiff", ".aif"}
|
||||
MAX_FILE_SIZE = 25 * 1024 * 1024 # 25 MB
|
||||
|
||||
# Known model sets for auto-correction
|
||||
OPENAI_MODELS = {"whisper-1", "gpt-4o-mini-transcribe", "gpt-4o-transcribe"}
|
||||
GROQ_MODELS = {"whisper-large-v3", "whisper-large-v3-turbo", "distil-whisper-large-v3-en"}
|
||||
|
||||
# Singleton for the local model — loaded once, reused across calls
|
||||
_local_model: Optional[object] = None
|
||||
_local_model_name: Optional[str] = None
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_stt_model_from_config() -> Optional[str]:
|
||||
"""Read the STT model name from ~/.hermes/config.yaml.
|
||||
|
||||
Returns the value of ``stt.model`` if present, otherwise ``None``.
|
||||
Silently returns ``None`` on any error (missing file, bad YAML, etc.).
|
||||
"""
|
||||
try:
|
||||
import yaml
|
||||
cfg_path = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "config.yaml"
|
||||
if cfg_path.exists():
|
||||
with open(cfg_path) as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
return data.get("stt", {}).get("model")
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _load_stt_config() -> dict:
|
||||
"""Load the ``stt`` section from user config, falling back to defaults."""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
return load_config().get("stt", {})
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def is_stt_enabled(stt_config: Optional[dict] = None) -> bool:
|
||||
"""Return whether STT is enabled in config."""
|
||||
if stt_config is None:
|
||||
stt_config = _load_stt_config()
|
||||
enabled = stt_config.get("enabled", True)
|
||||
if isinstance(enabled, str):
|
||||
return enabled.strip().lower() in ("true", "1", "yes", "on")
|
||||
if enabled is None:
|
||||
return True
|
||||
return bool(enabled)
|
||||
|
||||
|
||||
def _resolve_openai_api_key() -> str:
|
||||
"""Prefer the voice-tools key, but fall back to the normal OpenAI key."""
|
||||
return os.getenv("VOICE_TOOLS_OPENAI_KEY", "") or os.getenv("OPENAI_API_KEY", "")
|
||||
|
||||
|
||||
def _find_binary(binary_name: str) -> Optional[str]:
|
||||
"""Find a local binary, checking common Homebrew/local prefixes as well as PATH."""
|
||||
for directory in COMMON_LOCAL_BIN_DIRS:
|
||||
candidate = Path(directory) / binary_name
|
||||
if candidate.exists() and os.access(candidate, os.X_OK):
|
||||
return str(candidate)
|
||||
return shutil.which(binary_name)
|
||||
|
||||
|
||||
def _find_ffmpeg_binary() -> Optional[str]:
|
||||
return _find_binary("ffmpeg")
|
||||
|
||||
|
||||
def _find_whisper_binary() -> Optional[str]:
|
||||
return _find_binary("whisper")
|
||||
|
||||
|
||||
def _get_local_command_template() -> Optional[str]:
|
||||
configured = os.getenv(LOCAL_STT_COMMAND_ENV, "").strip()
|
||||
if configured:
|
||||
return configured
|
||||
|
||||
whisper_binary = _find_whisper_binary()
|
||||
if whisper_binary:
|
||||
quoted_binary = shlex.quote(whisper_binary)
|
||||
return (
|
||||
f"{quoted_binary} {{input_path}} --model {{model}} --output_format txt "
|
||||
"--output_dir {output_dir} --language {language}"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _has_local_command() -> bool:
|
||||
return _get_local_command_template() is not None
|
||||
|
||||
|
||||
def _normalize_local_command_model(model_name: Optional[str]) -> str:
|
||||
if not model_name or model_name in OPENAI_MODELS or model_name in GROQ_MODELS:
|
||||
return DEFAULT_LOCAL_MODEL
|
||||
return model_name
|
||||
|
||||
|
||||
def _get_provider(stt_config: dict) -> str:
|
||||
"""Determine which STT provider to use.
|
||||
|
||||
When ``stt.provider`` is explicitly set in config, that choice is
|
||||
honoured — no silent cloud fallback. When no provider is configured,
|
||||
auto-detect tries: local > groq (free) > openai (paid).
|
||||
"""
|
||||
if not is_stt_enabled(stt_config):
|
||||
return "none"
|
||||
|
||||
explicit = "provider" in stt_config
|
||||
provider = stt_config.get("provider", DEFAULT_PROVIDER)
|
||||
|
||||
# --- Explicit provider: respect the user's choice ----------------------
|
||||
|
||||
if explicit:
|
||||
if provider == "local":
|
||||
if _HAS_FASTER_WHISPER:
|
||||
return "local"
|
||||
if _has_local_command():
|
||||
return "local_command"
|
||||
logger.warning(
|
||||
"STT provider 'local' configured but unavailable "
|
||||
"(install faster-whisper or set HERMES_LOCAL_STT_COMMAND)"
|
||||
)
|
||||
return "none"
|
||||
|
||||
if provider == "local_command":
|
||||
if _has_local_command():
|
||||
return "local_command"
|
||||
if _HAS_FASTER_WHISPER:
|
||||
logger.info("Local STT command unavailable, using local faster-whisper")
|
||||
return "local"
|
||||
logger.warning(
|
||||
"STT provider 'local_command' configured but unavailable"
|
||||
)
|
||||
return "none"
|
||||
|
||||
if provider == "groq":
|
||||
if _HAS_OPENAI and os.getenv("GROQ_API_KEY"):
|
||||
return "groq"
|
||||
logger.warning(
|
||||
"STT provider 'groq' configured but GROQ_API_KEY not set"
|
||||
)
|
||||
return "none"
|
||||
|
||||
if provider == "openai":
|
||||
if _HAS_OPENAI and _resolve_openai_api_key():
|
||||
return "openai"
|
||||
logger.warning(
|
||||
"STT provider 'openai' configured but no API key available"
|
||||
)
|
||||
return "none"
|
||||
|
||||
return provider # Unknown — let it fail downstream
|
||||
|
||||
# --- Auto-detect (no explicit provider): local > groq > openai ---------
|
||||
|
||||
if _HAS_FASTER_WHISPER:
|
||||
return "local"
|
||||
if _has_local_command():
|
||||
return "local_command"
|
||||
if _HAS_OPENAI and os.getenv("GROQ_API_KEY"):
|
||||
logger.info("No local STT available, using Groq Whisper API")
|
||||
return "groq"
|
||||
if _HAS_OPENAI and _resolve_openai_api_key():
|
||||
logger.info("No local STT available, using OpenAI Whisper API")
|
||||
return "openai"
|
||||
return "none"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _validate_audio_file(file_path: str) -> Optional[Dict[str, Any]]:
|
||||
"""Validate the audio file. Returns an error dict or None if OK."""
|
||||
audio_path = Path(file_path)
|
||||
|
||||
if not audio_path.exists():
|
||||
return {"success": False, "transcript": "", "error": f"Audio file not found: {file_path}"}
|
||||
if not audio_path.is_file():
|
||||
return {"success": False, "transcript": "", "error": f"Path is not a file: {file_path}"}
|
||||
if audio_path.suffix.lower() not in SUPPORTED_FORMATS:
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"Unsupported format: {audio_path.suffix}. Supported: {', '.join(sorted(SUPPORTED_FORMATS))}",
|
||||
}
|
||||
try:
|
||||
file_size = audio_path.stat().st_size
|
||||
if file_size > MAX_FILE_SIZE:
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"File too large: {file_size / (1024*1024):.1f}MB (max {MAX_FILE_SIZE / (1024*1024):.0f}MB)",
|
||||
}
|
||||
except OSError as e:
|
||||
return {"success": False, "transcript": "", "error": f"Failed to access file: {e}"}
|
||||
|
||||
return None
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider: local (faster-whisper)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _transcribe_local(file_path: str, model_name: str) -> Dict[str, Any]:
|
||||
"""Transcribe using faster-whisper (local, free)."""
|
||||
global _local_model, _local_model_name
|
||||
|
||||
if not _HAS_FASTER_WHISPER:
|
||||
return {"success": False, "transcript": "", "error": "faster-whisper not installed"}
|
||||
|
||||
try:
|
||||
from faster_whisper import WhisperModel
|
||||
# Lazy-load the model (downloads on first use, ~150 MB for 'base')
|
||||
if _local_model is None or _local_model_name != model_name:
|
||||
logger.info("Loading faster-whisper model '%s' (first load downloads the model)...", model_name)
|
||||
_local_model = WhisperModel(model_name, device="auto", compute_type="auto")
|
||||
_local_model_name = model_name
|
||||
|
||||
segments, info = _local_model.transcribe(file_path, beam_size=5)
|
||||
transcript = " ".join(segment.text.strip() for segment in segments)
|
||||
|
||||
logger.info(
|
||||
"Transcribed %s via local whisper (%s, lang=%s, %.1fs audio)",
|
||||
Path(file_path).name, model_name, info.language, info.duration,
|
||||
)
|
||||
|
||||
return {"success": True, "transcript": transcript, "provider": "local"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Local transcription failed: %s", e, exc_info=True)
|
||||
return {"success": False, "transcript": "", "error": f"Local transcription failed: {e}"}
|
||||
|
||||
|
||||
def _prepare_local_audio(file_path: str, work_dir: str) -> tuple[Optional[str], Optional[str]]:
|
||||
"""Normalize audio for local CLI STT when needed."""
|
||||
audio_path = Path(file_path)
|
||||
if audio_path.suffix.lower() in LOCAL_NATIVE_AUDIO_FORMATS:
|
||||
return file_path, None
|
||||
|
||||
ffmpeg = _find_ffmpeg_binary()
|
||||
if not ffmpeg:
|
||||
return None, "Local STT fallback requires ffmpeg for non-WAV inputs, but ffmpeg was not found"
|
||||
|
||||
converted_path = os.path.join(work_dir, f"{audio_path.stem}.wav")
|
||||
command = [ffmpeg, "-y", "-i", file_path, converted_path]
|
||||
|
||||
try:
|
||||
subprocess.run(command, check=True, capture_output=True, text=True)
|
||||
return converted_path, None
|
||||
except subprocess.CalledProcessError as e:
|
||||
details = e.stderr.strip() or e.stdout.strip() or str(e)
|
||||
logger.error("ffmpeg conversion failed for %s: %s", file_path, details)
|
||||
return None, f"Failed to convert audio for local STT: {details}"
|
||||
|
||||
|
||||
def _transcribe_local_command(file_path: str, model_name: str) -> Dict[str, Any]:
|
||||
"""Run the configured local STT command template and read back a .txt transcript."""
|
||||
command_template = _get_local_command_template()
|
||||
if not command_template:
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": (
|
||||
f"{LOCAL_STT_COMMAND_ENV} not configured and no local whisper binary was found"
|
||||
),
|
||||
}
|
||||
|
||||
language = os.getenv(LOCAL_STT_LANGUAGE_ENV, DEFAULT_LOCAL_STT_LANGUAGE)
|
||||
normalized_model = _normalize_local_command_model(model_name)
|
||||
|
||||
try:
|
||||
with tempfile.TemporaryDirectory(prefix="hermes-local-stt-") as output_dir:
|
||||
prepared_input, prep_error = _prepare_local_audio(file_path, output_dir)
|
||||
if prep_error:
|
||||
return {"success": False, "transcript": "", "error": prep_error}
|
||||
|
||||
command = command_template.format(
|
||||
input_path=shlex.quote(prepared_input),
|
||||
output_dir=shlex.quote(output_dir),
|
||||
language=shlex.quote(language),
|
||||
model=shlex.quote(normalized_model),
|
||||
)
|
||||
subprocess.run(command, shell=True, check=True, capture_output=True, text=True)
|
||||
|
||||
txt_files = sorted(Path(output_dir).glob("*.txt"))
|
||||
if not txt_files:
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": "Local STT command completed but did not produce a .txt transcript",
|
||||
}
|
||||
|
||||
transcript_text = txt_files[0].read_text(encoding="utf-8").strip()
|
||||
logger.info(
|
||||
"Transcribed %s via local STT command (%s, %d chars)",
|
||||
Path(file_path).name,
|
||||
normalized_model,
|
||||
len(transcript_text),
|
||||
)
|
||||
return {"success": True, "transcript": transcript_text, "provider": "local_command"}
|
||||
|
||||
except KeyError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"Invalid {LOCAL_STT_COMMAND_ENV} template, missing placeholder: {e}",
|
||||
}
|
||||
except subprocess.CalledProcessError as e:
|
||||
details = e.stderr.strip() or e.stdout.strip() or str(e)
|
||||
logger.error("Local STT command failed for %s: %s", file_path, details)
|
||||
return {"success": False, "transcript": "", "error": f"Local STT failed: {details}"}
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error during local command transcription: %s", e, exc_info=True)
|
||||
return {"success": False, "transcript": "", "error": f"Local transcription failed: {e}"}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider: groq (Whisper API — free tier)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _transcribe_groq(file_path: str, model_name: str) -> Dict[str, Any]:
|
||||
"""Transcribe using Groq Whisper API (free tier available)."""
|
||||
api_key = os.getenv("GROQ_API_KEY")
|
||||
if not api_key:
|
||||
return {"success": False, "transcript": "", "error": "GROQ_API_KEY not set"}
|
||||
|
||||
if not _HAS_OPENAI:
|
||||
return {"success": False, "transcript": "", "error": "openai package not installed"}
|
||||
|
||||
# Auto-correct model if caller passed an OpenAI-only model
|
||||
if model_name in OPENAI_MODELS:
|
||||
logger.info("Model %s not available on Groq, using %s", model_name, DEFAULT_GROQ_STT_MODEL)
|
||||
model_name = DEFAULT_GROQ_STT_MODEL
|
||||
|
||||
try:
|
||||
from openai import OpenAI, APIError, APIConnectionError, APITimeoutError
|
||||
client = OpenAI(api_key=api_key, base_url=GROQ_BASE_URL, timeout=30, max_retries=0)
|
||||
|
||||
with open(file_path, "rb") as audio_file:
|
||||
transcription = client.audio.transcriptions.create(
|
||||
model=model_name,
|
||||
file=audio_file,
|
||||
response_format="text",
|
||||
)
|
||||
|
||||
transcript_text = str(transcription).strip()
|
||||
logger.info("Transcribed %s via Groq API (%s, %d chars)",
|
||||
Path(file_path).name, model_name, len(transcript_text))
|
||||
|
||||
return {"success": True, "transcript": transcript_text, "provider": "groq"}
|
||||
|
||||
except PermissionError:
|
||||
return {"success": False, "transcript": "", "error": f"Permission denied: {file_path}"}
|
||||
except APIConnectionError as e:
|
||||
return {"success": False, "transcript": "", "error": f"Connection error: {e}"}
|
||||
except APITimeoutError as e:
|
||||
return {"success": False, "transcript": "", "error": f"Request timeout: {e}"}
|
||||
except APIError as e:
|
||||
return {"success": False, "transcript": "", "error": f"API error: {e}"}
|
||||
except Exception as e:
|
||||
logger.error("Groq transcription failed: %s", e, exc_info=True)
|
||||
return {"success": False, "transcript": "", "error": f"Transcription failed: {e}"}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider: openai (Whisper API)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _transcribe_openai(file_path: str, model_name: str) -> Dict[str, Any]:
|
||||
"""Transcribe using OpenAI Whisper API (paid)."""
|
||||
api_key = _resolve_openai_api_key()
|
||||
if not api_key:
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": "Neither VOICE_TOOLS_OPENAI_KEY nor OPENAI_API_KEY is set",
|
||||
}
|
||||
|
||||
if not _HAS_OPENAI:
|
||||
return {"success": False, "transcript": "", "error": "openai package not installed"}
|
||||
|
||||
# Auto-correct model if caller passed a Groq-only model
|
||||
if model_name in GROQ_MODELS:
|
||||
logger.info("Model %s not available on OpenAI, using %s", model_name, DEFAULT_STT_MODEL)
|
||||
model_name = DEFAULT_STT_MODEL
|
||||
|
||||
try:
|
||||
from openai import OpenAI, APIError, APIConnectionError, APITimeoutError
|
||||
client = OpenAI(api_key=api_key, base_url=OPENAI_BASE_URL, timeout=30, max_retries=0)
|
||||
|
||||
with open(file_path, "rb") as audio_file:
|
||||
transcription = client.audio.transcriptions.create(
|
||||
model=model_name,
|
||||
file=audio_file,
|
||||
response_format="text",
|
||||
)
|
||||
|
||||
transcript_text = str(transcription).strip()
|
||||
logger.info("Transcribed %s via OpenAI API (%s, %d chars)",
|
||||
Path(file_path).name, model_name, len(transcript_text))
|
||||
|
||||
return {"success": True, "transcript": transcript_text, "provider": "openai"}
|
||||
|
||||
except PermissionError:
|
||||
return {"success": False, "transcript": "", "error": f"Permission denied: {file_path}"}
|
||||
except APIConnectionError as e:
|
||||
return {"success": False, "transcript": "", "error": f"Connection error: {e}"}
|
||||
except APITimeoutError as e:
|
||||
return {"success": False, "transcript": "", "error": f"Request timeout: {e}"}
|
||||
except APIError as e:
|
||||
return {"success": False, "transcript": "", "error": f"API error: {e}"}
|
||||
except Exception as e:
|
||||
logger.error("OpenAI transcription failed: %s", e, exc_info=True)
|
||||
return {"success": False, "transcript": "", "error": f"Transcription failed: {e}"}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Transcribe an audio file using the configured STT provider.
|
||||
|
||||
Provider priority:
|
||||
1. User config (``stt.provider`` in config.yaml)
|
||||
2. Auto-detect: local faster-whisper (free) > Groq (free tier) > OpenAI (paid)
|
||||
|
||||
Args:
|
||||
file_path: Absolute path to the audio file to transcribe.
|
||||
model: Override the model. If None, uses config or provider default.
|
||||
|
||||
Returns:
|
||||
dict with keys:
|
||||
- "success" (bool): Whether transcription succeeded
|
||||
- "transcript" (str): The transcribed text (empty on failure)
|
||||
- "error" (str, optional): Error message if success is False
|
||||
- "provider" (str, optional): Which provider was used
|
||||
"""
|
||||
# Validate input
|
||||
error = _validate_audio_file(file_path)
|
||||
if error:
|
||||
return error
|
||||
|
||||
# Load config and determine provider
|
||||
stt_config = _load_stt_config()
|
||||
if not is_stt_enabled(stt_config):
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": "STT is disabled in config.yaml (stt.enabled: false).",
|
||||
}
|
||||
|
||||
provider = _get_provider(stt_config)
|
||||
|
||||
if provider == "local":
|
||||
local_cfg = stt_config.get("local", {})
|
||||
model_name = model or local_cfg.get("model", DEFAULT_LOCAL_MODEL)
|
||||
return _transcribe_local(file_path, model_name)
|
||||
|
||||
if provider == "local_command":
|
||||
local_cfg = stt_config.get("local", {})
|
||||
model_name = _normalize_local_command_model(
|
||||
model or local_cfg.get("model", DEFAULT_LOCAL_MODEL)
|
||||
)
|
||||
return _transcribe_local_command(file_path, model_name)
|
||||
|
||||
if provider == "groq":
|
||||
model_name = model or DEFAULT_GROQ_STT_MODEL
|
||||
return _transcribe_groq(file_path, model_name)
|
||||
|
||||
if provider == "openai":
|
||||
openai_cfg = stt_config.get("openai", {})
|
||||
model_name = model or openai_cfg.get("model", DEFAULT_STT_MODEL)
|
||||
return _transcribe_openai(file_path, model_name)
|
||||
|
||||
# No provider available
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": (
|
||||
"No STT provider available. Install faster-whisper for free local "
|
||||
f"transcription, configure {LOCAL_STT_COMMAND_ENV} or install a local whisper CLI, "
|
||||
"set GROQ_API_KEY for free Groq Whisper, or set VOICE_TOOLS_OPENAI_KEY "
|
||||
"or OPENAI_API_KEY for the OpenAI Whisper API."
|
||||
),
|
||||
}
|
||||
847
hermes_code/tools/tts_tool.py
Normal file
847
hermes_code/tools/tts_tool.py
Normal file
|
|
@ -0,0 +1,847 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Text-to-Speech Tool Module
|
||||
|
||||
Supports four TTS providers:
|
||||
- Edge TTS (default, free, no API key): Microsoft Edge neural voices
|
||||
- ElevenLabs (premium): High-quality voices, needs ELEVENLABS_API_KEY
|
||||
- OpenAI TTS: Good quality, needs OPENAI_API_KEY
|
||||
- NeuTTS (local, free, no API key): On-device TTS via neutts_cli, needs neutts installed
|
||||
|
||||
Output formats:
|
||||
- Opus (.ogg) for Telegram voice bubbles (requires ffmpeg for Edge TTS)
|
||||
- MP3 (.mp3) for everything else (CLI, Discord, WhatsApp)
|
||||
|
||||
Configuration is loaded from ~/.hermes/config.yaml under the 'tts:' key.
|
||||
The user chooses the provider and voice; the model just sends text.
|
||||
|
||||
Usage:
|
||||
from tools.tts_tool import text_to_speech_tool, check_tts_requirements
|
||||
|
||||
result = text_to_speech_tool(text="Hello world")
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lazy imports -- providers are imported only when actually used to avoid
|
||||
# crashing in headless environments (SSH, Docker, WSL, no PortAudio).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _import_edge_tts():
|
||||
"""Lazy import edge_tts. Returns the module or raises ImportError."""
|
||||
import edge_tts
|
||||
return edge_tts
|
||||
|
||||
def _import_elevenlabs():
|
||||
"""Lazy import ElevenLabs client. Returns the class or raises ImportError."""
|
||||
from elevenlabs.client import ElevenLabs
|
||||
return ElevenLabs
|
||||
|
||||
def _import_openai_client():
|
||||
"""Lazy import OpenAI client. Returns the class or raises ImportError."""
|
||||
from openai import OpenAI as OpenAIClient
|
||||
return OpenAIClient
|
||||
|
||||
def _import_sounddevice():
|
||||
"""Lazy import sounddevice. Returns the module or raises ImportError/OSError."""
|
||||
import sounddevice as sd
|
||||
return sd
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Defaults
|
||||
# ===========================================================================
|
||||
DEFAULT_PROVIDER = "edge"
|
||||
DEFAULT_EDGE_VOICE = "en-US-AriaNeural"
|
||||
DEFAULT_ELEVENLABS_VOICE_ID = "pNInz6obpgDQGcFmaJgB" # Adam
|
||||
DEFAULT_ELEVENLABS_MODEL_ID = "eleven_multilingual_v2"
|
||||
DEFAULT_ELEVENLABS_STREAMING_MODEL_ID = "eleven_flash_v2_5"
|
||||
DEFAULT_OPENAI_MODEL = "gpt-4o-mini-tts"
|
||||
DEFAULT_OPENAI_VOICE = "alloy"
|
||||
DEFAULT_OUTPUT_DIR = str(Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "audio_cache")
|
||||
MAX_TEXT_LENGTH = 4000
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Config loader -- reads tts: section from ~/.hermes/config.yaml
|
||||
# ===========================================================================
|
||||
def _load_tts_config() -> Dict[str, Any]:
|
||||
"""
|
||||
Load TTS configuration from ~/.hermes/config.yaml.
|
||||
|
||||
Returns a dict with provider settings. Falls back to defaults
|
||||
for any missing fields.
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
config = load_config()
|
||||
return config.get("tts", {})
|
||||
except ImportError:
|
||||
logger.debug("hermes_cli.config not available, using default TTS config")
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load TTS config: %s", e, exc_info=True)
|
||||
return {}
|
||||
|
||||
|
||||
def _get_provider(tts_config: Dict[str, Any]) -> str:
|
||||
"""Get the configured TTS provider name."""
|
||||
return tts_config.get("provider", DEFAULT_PROVIDER).lower().strip()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# ffmpeg Opus conversion (Edge TTS MP3 -> OGG Opus for Telegram)
|
||||
# ===========================================================================
|
||||
def _has_ffmpeg() -> bool:
|
||||
"""Check if ffmpeg is available on the system."""
|
||||
return shutil.which("ffmpeg") is not None
|
||||
|
||||
|
||||
def _convert_to_opus(mp3_path: str) -> Optional[str]:
|
||||
"""
|
||||
Convert an MP3 file to OGG Opus format for Telegram voice bubbles.
|
||||
|
||||
Args:
|
||||
mp3_path: Path to the input MP3 file.
|
||||
|
||||
Returns:
|
||||
Path to the .ogg file, or None if conversion fails.
|
||||
"""
|
||||
if not _has_ffmpeg():
|
||||
return None
|
||||
|
||||
ogg_path = mp3_path.rsplit(".", 1)[0] + ".ogg"
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["ffmpeg", "-i", mp3_path, "-acodec", "libopus",
|
||||
"-ac", "1", "-b:a", "64k", "-vbr", "off", ogg_path, "-y"],
|
||||
capture_output=True, timeout=30,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.warning("ffmpeg conversion failed with return code %d: %s",
|
||||
result.returncode, result.stderr.decode('utf-8', errors='ignore')[:200])
|
||||
return None
|
||||
if os.path.exists(ogg_path) and os.path.getsize(ogg_path) > 0:
|
||||
return ogg_path
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("ffmpeg OGG conversion timed out after 30s")
|
||||
except FileNotFoundError:
|
||||
logger.warning("ffmpeg not found in PATH")
|
||||
except Exception as e:
|
||||
logger.warning("ffmpeg OGG conversion failed: %s", e, exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Provider: Edge TTS (free)
|
||||
# ===========================================================================
|
||||
async def _generate_edge_tts(text: str, output_path: str, tts_config: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Generate audio using Edge TTS.
|
||||
|
||||
Args:
|
||||
text: Text to convert.
|
||||
output_path: Where to save the MP3 file.
|
||||
tts_config: TTS config dict.
|
||||
|
||||
Returns:
|
||||
Path to the saved audio file.
|
||||
"""
|
||||
_edge_tts = _import_edge_tts()
|
||||
edge_config = tts_config.get("edge", {})
|
||||
voice = edge_config.get("voice", DEFAULT_EDGE_VOICE)
|
||||
|
||||
communicate = _edge_tts.Communicate(text, voice)
|
||||
await communicate.save(output_path)
|
||||
return output_path
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Provider: ElevenLabs (premium)
|
||||
# ===========================================================================
|
||||
def _generate_elevenlabs(text: str, output_path: str, tts_config: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Generate audio using ElevenLabs.
|
||||
|
||||
Args:
|
||||
text: Text to convert.
|
||||
output_path: Where to save the audio file.
|
||||
tts_config: TTS config dict.
|
||||
|
||||
Returns:
|
||||
Path to the saved audio file.
|
||||
"""
|
||||
api_key = os.getenv("ELEVENLABS_API_KEY", "")
|
||||
if not api_key:
|
||||
raise ValueError("ELEVENLABS_API_KEY not set. Get one at https://elevenlabs.io/")
|
||||
|
||||
el_config = tts_config.get("elevenlabs", {})
|
||||
voice_id = el_config.get("voice_id", DEFAULT_ELEVENLABS_VOICE_ID)
|
||||
model_id = el_config.get("model_id", DEFAULT_ELEVENLABS_MODEL_ID)
|
||||
|
||||
# Determine output format based on file extension
|
||||
if output_path.endswith(".ogg"):
|
||||
output_format = "opus_48000_64"
|
||||
else:
|
||||
output_format = "mp3_44100_128"
|
||||
|
||||
ElevenLabs = _import_elevenlabs()
|
||||
client = ElevenLabs(api_key=api_key)
|
||||
audio_generator = client.text_to_speech.convert(
|
||||
text=text,
|
||||
voice_id=voice_id,
|
||||
model_id=model_id,
|
||||
output_format=output_format,
|
||||
)
|
||||
|
||||
# audio_generator yields chunks -- write them all
|
||||
with open(output_path, "wb") as f:
|
||||
for chunk in audio_generator:
|
||||
f.write(chunk)
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Provider: OpenAI TTS
|
||||
# ===========================================================================
|
||||
def _generate_openai_tts(text: str, output_path: str, tts_config: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Generate audio using OpenAI TTS.
|
||||
|
||||
Args:
|
||||
text: Text to convert.
|
||||
output_path: Where to save the audio file.
|
||||
tts_config: TTS config dict.
|
||||
|
||||
Returns:
|
||||
Path to the saved audio file.
|
||||
"""
|
||||
api_key = os.getenv("VOICE_TOOLS_OPENAI_KEY", "")
|
||||
if not api_key:
|
||||
raise ValueError("VOICE_TOOLS_OPENAI_KEY not set. Get one at https://platform.openai.com/api-keys")
|
||||
|
||||
oai_config = tts_config.get("openai", {})
|
||||
model = oai_config.get("model", DEFAULT_OPENAI_MODEL)
|
||||
voice = oai_config.get("voice", DEFAULT_OPENAI_VOICE)
|
||||
base_url = oai_config.get("base_url", "https://api.openai.com/v1")
|
||||
|
||||
# Determine response format from extension
|
||||
if output_path.endswith(".ogg"):
|
||||
response_format = "opus"
|
||||
else:
|
||||
response_format = "mp3"
|
||||
|
||||
OpenAIClient = _import_openai_client()
|
||||
client = OpenAIClient(api_key=api_key, base_url=base_url)
|
||||
response = client.audio.speech.create(
|
||||
model=model,
|
||||
voice=voice,
|
||||
input=text,
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
response.stream_to_file(output_path)
|
||||
return output_path
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# NeuTTS (local, on-device TTS via neutts_cli)
|
||||
# ===========================================================================
|
||||
|
||||
def _check_neutts_available() -> bool:
|
||||
"""Check if the neutts engine is importable (installed locally)."""
|
||||
try:
|
||||
import importlib.util
|
||||
return importlib.util.find_spec("neutts") is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _default_neutts_ref_audio() -> str:
|
||||
"""Return path to the bundled default voice reference audio."""
|
||||
return str(Path(__file__).parent / "neutts_samples" / "jo.wav")
|
||||
|
||||
|
||||
def _default_neutts_ref_text() -> str:
|
||||
"""Return path to the bundled default voice reference transcript."""
|
||||
return str(Path(__file__).parent / "neutts_samples" / "jo.txt")
|
||||
|
||||
|
||||
def _generate_neutts(text: str, output_path: str, tts_config: Dict[str, Any]) -> str:
|
||||
"""Generate speech using the local NeuTTS engine.
|
||||
|
||||
Runs synthesis in a subprocess via tools/neutts_synth.py to keep the
|
||||
~500MB model in a separate process that exits after synthesis.
|
||||
Outputs WAV; the caller handles conversion for Telegram if needed.
|
||||
"""
|
||||
import sys
|
||||
|
||||
neutts_config = tts_config.get("neutts", {})
|
||||
ref_audio = neutts_config.get("ref_audio", "") or _default_neutts_ref_audio()
|
||||
ref_text = neutts_config.get("ref_text", "") or _default_neutts_ref_text()
|
||||
model = neutts_config.get("model", "neuphonic/neutts-air-q4-gguf")
|
||||
device = neutts_config.get("device", "cpu")
|
||||
|
||||
# NeuTTS outputs WAV natively — use a .wav path for generation,
|
||||
# let the caller convert to the final format afterward.
|
||||
wav_path = output_path
|
||||
if not output_path.endswith(".wav"):
|
||||
wav_path = output_path.rsplit(".", 1)[0] + ".wav"
|
||||
|
||||
synth_script = str(Path(__file__).parent / "neutts_synth.py")
|
||||
cmd = [
|
||||
sys.executable, synth_script,
|
||||
"--text", text,
|
||||
"--out", wav_path,
|
||||
"--ref-audio", ref_audio,
|
||||
"--ref-text", ref_text,
|
||||
"--model", model,
|
||||
"--device", device,
|
||||
]
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
|
||||
if result.returncode != 0:
|
||||
stderr = result.stderr.strip()
|
||||
# Filter out the "OK:" line from stderr
|
||||
error_lines = [l for l in stderr.splitlines() if not l.startswith("OK:")]
|
||||
raise RuntimeError(f"NeuTTS synthesis failed: {chr(10).join(error_lines) or 'unknown error'}")
|
||||
|
||||
# If the caller wanted .mp3 or .ogg, convert from WAV
|
||||
if wav_path != output_path:
|
||||
ffmpeg = shutil.which("ffmpeg")
|
||||
if ffmpeg:
|
||||
conv_cmd = [ffmpeg, "-i", wav_path, "-y", "-loglevel", "error", output_path]
|
||||
subprocess.run(conv_cmd, check=True, timeout=30)
|
||||
os.remove(wav_path)
|
||||
else:
|
||||
# No ffmpeg — just rename the WAV to the expected path
|
||||
os.rename(wav_path, output_path)
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Main tool function
|
||||
# ===========================================================================
|
||||
def text_to_speech_tool(
|
||||
text: str,
|
||||
output_path: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Convert text to speech audio.
|
||||
|
||||
Reads provider/voice config from ~/.hermes/config.yaml (tts: section).
|
||||
The model sends text; the user configures voice and provider.
|
||||
|
||||
On messaging platforms, the returned MEDIA:<path> tag is intercepted
|
||||
by the send pipeline and delivered as a native voice message.
|
||||
In CLI mode, the file is saved to ~/voice-memos/.
|
||||
|
||||
Args:
|
||||
text: The text to convert to speech.
|
||||
output_path: Optional custom save path. Defaults to ~/voice-memos/<timestamp>.mp3
|
||||
|
||||
Returns:
|
||||
str: JSON result with success, file_path, and optionally MEDIA tag.
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return json.dumps({"success": False, "error": "Text is required"}, ensure_ascii=False)
|
||||
|
||||
# Truncate very long text with a warning
|
||||
if len(text) > MAX_TEXT_LENGTH:
|
||||
logger.warning("TTS text too long (%d chars), truncating to %d", len(text), MAX_TEXT_LENGTH)
|
||||
text = text[:MAX_TEXT_LENGTH]
|
||||
|
||||
tts_config = _load_tts_config()
|
||||
provider = _get_provider(tts_config)
|
||||
|
||||
# Detect platform from gateway env var to choose the best output format.
|
||||
# Telegram voice bubbles require Opus (.ogg); OpenAI and ElevenLabs can
|
||||
# produce Opus natively (no ffmpeg needed). Edge TTS always outputs MP3
|
||||
# and needs ffmpeg for conversion.
|
||||
platform = os.getenv("HERMES_SESSION_PLATFORM", "").lower()
|
||||
want_opus = (platform == "telegram")
|
||||
|
||||
# Determine output path
|
||||
if output_path:
|
||||
file_path = Path(output_path).expanduser()
|
||||
else:
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
out_dir = Path(DEFAULT_OUTPUT_DIR)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
# Use .ogg for Telegram with providers that support native Opus output,
|
||||
# otherwise fall back to .mp3 (Edge TTS will attempt ffmpeg conversion later).
|
||||
if want_opus and provider in ("openai", "elevenlabs"):
|
||||
file_path = out_dir / f"tts_{timestamp}.ogg"
|
||||
else:
|
||||
file_path = out_dir / f"tts_{timestamp}.mp3"
|
||||
|
||||
# Ensure parent directory exists
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_str = str(file_path)
|
||||
|
||||
try:
|
||||
# Generate audio with the configured provider
|
||||
if provider == "elevenlabs":
|
||||
try:
|
||||
_import_elevenlabs()
|
||||
except ImportError:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": "ElevenLabs provider selected but 'elevenlabs' package not installed. Run: pip install elevenlabs"
|
||||
}, ensure_ascii=False)
|
||||
logger.info("Generating speech with ElevenLabs...")
|
||||
_generate_elevenlabs(text, file_str, tts_config)
|
||||
|
||||
elif provider == "openai":
|
||||
try:
|
||||
_import_openai_client()
|
||||
except ImportError:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": "OpenAI provider selected but 'openai' package not installed."
|
||||
}, ensure_ascii=False)
|
||||
logger.info("Generating speech with OpenAI TTS...")
|
||||
_generate_openai_tts(text, file_str, tts_config)
|
||||
|
||||
elif provider == "neutts":
|
||||
if not _check_neutts_available():
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": "NeuTTS provider selected but neutts is not installed. "
|
||||
"Run hermes setup and choose NeuTTS, or install espeak-ng and run python -m pip install -U neutts[all]."
|
||||
}, ensure_ascii=False)
|
||||
logger.info("Generating speech with NeuTTS (local)...")
|
||||
_generate_neutts(text, file_str, tts_config)
|
||||
|
||||
else:
|
||||
# Default: Edge TTS (free), with NeuTTS as local fallback
|
||||
edge_available = True
|
||||
try:
|
||||
_import_edge_tts()
|
||||
except ImportError:
|
||||
edge_available = False
|
||||
|
||||
if edge_available:
|
||||
logger.info("Generating speech with Edge TTS...")
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
pool.submit(
|
||||
lambda: asyncio.run(_generate_edge_tts(text, file_str, tts_config))
|
||||
).result(timeout=60)
|
||||
except RuntimeError:
|
||||
asyncio.run(_generate_edge_tts(text, file_str, tts_config))
|
||||
elif _check_neutts_available():
|
||||
logger.info("Edge TTS not available, falling back to NeuTTS (local)...")
|
||||
provider = "neutts"
|
||||
_generate_neutts(text, file_str, tts_config)
|
||||
else:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": "No TTS provider available. Install edge-tts (pip install edge-tts) "
|
||||
"or set up NeuTTS for local synthesis."
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# Check the file was actually created
|
||||
if not os.path.exists(file_str) or os.path.getsize(file_str) == 0:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": f"TTS generation produced no output (provider: {provider})"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# Try Opus conversion for Telegram compatibility
|
||||
# Edge TTS outputs MP3, NeuTTS outputs WAV — both need ffmpeg conversion
|
||||
voice_compatible = False
|
||||
if provider in ("edge", "neutts") and not file_str.endswith(".ogg"):
|
||||
opus_path = _convert_to_opus(file_str)
|
||||
if opus_path:
|
||||
file_str = opus_path
|
||||
voice_compatible = True
|
||||
elif provider in ("elevenlabs", "openai"):
|
||||
# These providers can output Opus natively if the path ends in .ogg
|
||||
voice_compatible = file_str.endswith(".ogg")
|
||||
|
||||
file_size = os.path.getsize(file_str)
|
||||
logger.info("TTS audio saved: %s (%s bytes, provider: %s)", file_str, f"{file_size:,}", provider)
|
||||
|
||||
# Build response with MEDIA tag for platform delivery
|
||||
media_tag = f"MEDIA:{file_str}"
|
||||
if voice_compatible:
|
||||
media_tag = f"[[audio_as_voice]]\n{media_tag}"
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"file_path": file_str,
|
||||
"media_tag": media_tag,
|
||||
"provider": provider,
|
||||
"voice_compatible": voice_compatible,
|
||||
}, ensure_ascii=False)
|
||||
|
||||
except ValueError as e:
|
||||
# Configuration errors (missing API keys, etc.)
|
||||
error_msg = f"TTS configuration error ({provider}): {e}"
|
||||
logger.error("%s", error_msg)
|
||||
return json.dumps({"success": False, "error": error_msg}, ensure_ascii=False)
|
||||
except FileNotFoundError as e:
|
||||
# Missing dependencies or files
|
||||
error_msg = f"TTS dependency missing ({provider}): {e}"
|
||||
logger.error("%s", error_msg, exc_info=True)
|
||||
return json.dumps({"success": False, "error": error_msg}, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
# Unexpected errors
|
||||
error_msg = f"TTS generation failed ({provider}): {e}"
|
||||
logger.error("%s", error_msg, exc_info=True)
|
||||
return json.dumps({"success": False, "error": error_msg}, ensure_ascii=False)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Requirements check
|
||||
# ===========================================================================
|
||||
def check_tts_requirements() -> bool:
|
||||
"""
|
||||
Check if at least one TTS provider is available.
|
||||
|
||||
Edge TTS needs no API key and is the default, so if the package
|
||||
is installed, TTS is available.
|
||||
|
||||
Returns:
|
||||
bool: True if at least one provider can work.
|
||||
"""
|
||||
try:
|
||||
_import_edge_tts()
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
_import_elevenlabs()
|
||||
if os.getenv("ELEVENLABS_API_KEY"):
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
_import_openai_client()
|
||||
if os.getenv("VOICE_TOOLS_OPENAI_KEY"):
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
if _check_neutts_available():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Streaming TTS: sentence-by-sentence pipeline for ElevenLabs
|
||||
# ===========================================================================
|
||||
# Sentence boundary pattern: punctuation followed by space or newline
|
||||
_SENTENCE_BOUNDARY_RE = re.compile(r'(?<=[.!?])(?:\s|\n)|(?:\n\n)')
|
||||
|
||||
# Markdown stripping patterns (same as cli.py _voice_speak_response)
|
||||
_MD_CODE_BLOCK = re.compile(r'```[\s\S]*?```')
|
||||
_MD_LINK = re.compile(r'\[([^\]]+)\]\([^)]+\)')
|
||||
_MD_URL = re.compile(r'https?://\S+')
|
||||
_MD_BOLD = re.compile(r'\*\*(.+?)\*\*')
|
||||
_MD_ITALIC = re.compile(r'\*(.+?)\*')
|
||||
_MD_INLINE_CODE = re.compile(r'`(.+?)`')
|
||||
_MD_HEADER = re.compile(r'^#+\s*', flags=re.MULTILINE)
|
||||
_MD_LIST_ITEM = re.compile(r'^\s*[-*]\s+', flags=re.MULTILINE)
|
||||
_MD_HR = re.compile(r'---+')
|
||||
_MD_EXCESS_NL = re.compile(r'\n{3,}')
|
||||
|
||||
|
||||
def _strip_markdown_for_tts(text: str) -> str:
|
||||
"""Remove markdown formatting that shouldn't be spoken aloud."""
|
||||
text = _MD_CODE_BLOCK.sub(' ', text)
|
||||
text = _MD_LINK.sub(r'\1', text)
|
||||
text = _MD_URL.sub('', text)
|
||||
text = _MD_BOLD.sub(r'\1', text)
|
||||
text = _MD_ITALIC.sub(r'\1', text)
|
||||
text = _MD_INLINE_CODE.sub(r'\1', text)
|
||||
text = _MD_HEADER.sub('', text)
|
||||
text = _MD_LIST_ITEM.sub('', text)
|
||||
text = _MD_HR.sub('', text)
|
||||
text = _MD_EXCESS_NL.sub('\n\n', text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def stream_tts_to_speaker(
|
||||
text_queue: queue.Queue,
|
||||
stop_event: threading.Event,
|
||||
tts_done_event: threading.Event,
|
||||
display_callback: Optional[Callable[[str], None]] = None,
|
||||
):
|
||||
"""Consume text deltas from *text_queue*, buffer them into sentences,
|
||||
and stream each sentence through ElevenLabs TTS to the speaker in
|
||||
real-time.
|
||||
|
||||
Protocol:
|
||||
* The producer puts ``str`` deltas onto *text_queue*.
|
||||
* A ``None`` sentinel signals end-of-text (flush remaining buffer).
|
||||
* *stop_event* can be set to abort early (e.g. user interrupt).
|
||||
* *tts_done_event* is **set** in the ``finally`` block so callers
|
||||
waiting on it (continuous voice mode) know playback is finished.
|
||||
"""
|
||||
tts_done_event.clear()
|
||||
|
||||
try:
|
||||
# --- TTS client setup (optional -- display_callback works without it) ---
|
||||
client = None
|
||||
output_stream = None
|
||||
voice_id = DEFAULT_ELEVENLABS_VOICE_ID
|
||||
model_id = DEFAULT_ELEVENLABS_STREAMING_MODEL_ID
|
||||
|
||||
tts_config = _load_tts_config()
|
||||
el_config = tts_config.get("elevenlabs", {})
|
||||
voice_id = el_config.get("voice_id", voice_id)
|
||||
model_id = el_config.get("streaming_model_id",
|
||||
el_config.get("model_id", model_id))
|
||||
|
||||
api_key = os.getenv("ELEVENLABS_API_KEY", "")
|
||||
if not api_key:
|
||||
logger.warning("ELEVENLABS_API_KEY not set; streaming TTS audio disabled")
|
||||
else:
|
||||
try:
|
||||
ElevenLabs = _import_elevenlabs()
|
||||
client = ElevenLabs(api_key=api_key)
|
||||
except ImportError:
|
||||
logger.warning("elevenlabs package not installed; streaming TTS disabled")
|
||||
|
||||
# Open a single sounddevice output stream for the lifetime of
|
||||
# this function. ElevenLabs pcm_24000 produces signed 16-bit
|
||||
# little-endian mono PCM at 24 kHz.
|
||||
if client is not None:
|
||||
try:
|
||||
sd = _import_sounddevice()
|
||||
import numpy as _np
|
||||
output_stream = sd.OutputStream(
|
||||
samplerate=24000, channels=1, dtype="int16",
|
||||
)
|
||||
output_stream.start()
|
||||
except (ImportError, OSError) as exc:
|
||||
logger.debug("sounddevice not available: %s", exc)
|
||||
output_stream = None
|
||||
except Exception as exc:
|
||||
logger.warning("sounddevice OutputStream failed: %s", exc)
|
||||
output_stream = None
|
||||
|
||||
sentence_buf = ""
|
||||
min_sentence_len = 20
|
||||
long_flush_len = 100
|
||||
queue_timeout = 0.5
|
||||
_spoken_sentences: list[str] = [] # track spoken sentences to skip duplicates
|
||||
# Regex to strip complete <think>...</think> blocks from buffer
|
||||
_think_block_re = re.compile(r'<think[\s>].*?</think>', flags=re.DOTALL)
|
||||
|
||||
def _speak_sentence(sentence: str):
|
||||
"""Display sentence and optionally generate + play audio."""
|
||||
if stop_event.is_set():
|
||||
return
|
||||
cleaned = _strip_markdown_for_tts(sentence).strip()
|
||||
if not cleaned:
|
||||
return
|
||||
# Skip duplicate/near-duplicate sentences (LLM repetition)
|
||||
cleaned_lower = cleaned.lower().rstrip(".!,")
|
||||
for prev in _spoken_sentences:
|
||||
if prev.lower().rstrip(".!,") == cleaned_lower:
|
||||
return
|
||||
_spoken_sentences.append(cleaned)
|
||||
# Display raw sentence on screen before TTS processing
|
||||
if display_callback is not None:
|
||||
display_callback(sentence)
|
||||
# Skip audio generation if no TTS client available
|
||||
if client is None:
|
||||
return
|
||||
# Truncate very long sentences
|
||||
if len(cleaned) > MAX_TEXT_LENGTH:
|
||||
cleaned = cleaned[:MAX_TEXT_LENGTH]
|
||||
try:
|
||||
audio_iter = client.text_to_speech.convert(
|
||||
text=cleaned,
|
||||
voice_id=voice_id,
|
||||
model_id=model_id,
|
||||
output_format="pcm_24000",
|
||||
)
|
||||
if output_stream is not None:
|
||||
for chunk in audio_iter:
|
||||
if stop_event.is_set():
|
||||
break
|
||||
import numpy as _np
|
||||
audio_array = _np.frombuffer(chunk, dtype=_np.int16)
|
||||
output_stream.write(audio_array.reshape(-1, 1))
|
||||
else:
|
||||
# Fallback: write chunks to temp file and play via system player
|
||||
_play_via_tempfile(audio_iter, stop_event)
|
||||
except Exception as exc:
|
||||
logger.warning("Streaming TTS sentence failed: %s", exc)
|
||||
|
||||
def _play_via_tempfile(audio_iter, stop_evt):
|
||||
"""Write PCM chunks to a temp WAV file and play it."""
|
||||
tmp_path = None
|
||||
try:
|
||||
import wave
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||
tmp_path = tmp.name
|
||||
with wave.open(tmp, "wb") as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2) # 16-bit
|
||||
wf.setframerate(24000)
|
||||
for chunk in audio_iter:
|
||||
if stop_evt.is_set():
|
||||
break
|
||||
wf.writeframes(chunk)
|
||||
from tools.voice_mode import play_audio_file
|
||||
play_audio_file(tmp_path)
|
||||
except Exception as exc:
|
||||
logger.warning("Temp-file TTS fallback failed: %s", exc)
|
||||
finally:
|
||||
if tmp_path:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
while not stop_event.is_set():
|
||||
# Read next delta from queue
|
||||
try:
|
||||
delta = text_queue.get(timeout=queue_timeout)
|
||||
except queue.Empty:
|
||||
# Timeout: if we have accumulated a long buffer, flush it
|
||||
if len(sentence_buf) > long_flush_len:
|
||||
_speak_sentence(sentence_buf)
|
||||
sentence_buf = ""
|
||||
continue
|
||||
|
||||
if delta is None:
|
||||
# End-of-text sentinel: strip any remaining think blocks, flush
|
||||
sentence_buf = _think_block_re.sub('', sentence_buf)
|
||||
if sentence_buf.strip():
|
||||
_speak_sentence(sentence_buf)
|
||||
break
|
||||
|
||||
sentence_buf += delta
|
||||
|
||||
# --- Think block filtering ---
|
||||
# Strip complete <think>...</think> blocks from buffer.
|
||||
# Works correctly even when tags span multiple deltas.
|
||||
sentence_buf = _think_block_re.sub('', sentence_buf)
|
||||
|
||||
# If an incomplete <think tag is at the end, wait for more data
|
||||
# before extracting sentences (the closing tag may arrive next).
|
||||
if '<think' in sentence_buf and '</think>' not in sentence_buf:
|
||||
continue
|
||||
|
||||
# Check for sentence boundaries
|
||||
while True:
|
||||
m = _SENTENCE_BOUNDARY_RE.search(sentence_buf)
|
||||
if m is None:
|
||||
break
|
||||
end_pos = m.end()
|
||||
sentence = sentence_buf[:end_pos]
|
||||
sentence_buf = sentence_buf[end_pos:]
|
||||
# Merge short fragments into the next sentence
|
||||
if len(sentence.strip()) < min_sentence_len:
|
||||
sentence_buf = sentence + sentence_buf
|
||||
break
|
||||
_speak_sentence(sentence)
|
||||
|
||||
# Drain any remaining items from the queue
|
||||
while True:
|
||||
try:
|
||||
text_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
# output_stream is closed in the finally block below
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("Streaming TTS pipeline error: %s", exc)
|
||||
finally:
|
||||
# Always close the audio output stream to avoid locking the device
|
||||
if output_stream is not None:
|
||||
try:
|
||||
output_stream.stop()
|
||||
output_stream.close()
|
||||
except Exception:
|
||||
pass
|
||||
tts_done_event.set()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Main -- quick diagnostics
|
||||
# ===========================================================================
|
||||
if __name__ == "__main__":
|
||||
print("🔊 Text-to-Speech Tool Module")
|
||||
print("=" * 50)
|
||||
|
||||
def _check(importer, label):
|
||||
try:
|
||||
importer()
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
print(f"\nProvider availability:")
|
||||
print(f" Edge TTS: {'installed' if _check(_import_edge_tts, 'edge') else 'not installed (pip install edge-tts)'}")
|
||||
print(f" ElevenLabs: {'installed' if _check(_import_elevenlabs, 'el') else 'not installed (pip install elevenlabs)'}")
|
||||
print(f" API Key: {'set' if os.getenv('ELEVENLABS_API_KEY') else 'not set'}")
|
||||
print(f" OpenAI: {'installed' if _check(_import_openai_client, 'oai') else 'not installed'}")
|
||||
print(f" API Key: {'set' if os.getenv('VOICE_TOOLS_OPENAI_KEY') else 'not set (VOICE_TOOLS_OPENAI_KEY)'}")
|
||||
print(f" ffmpeg: {'✅ found' if _has_ffmpeg() else '❌ not found (needed for Telegram Opus)'}")
|
||||
print(f"\n Output dir: {DEFAULT_OUTPUT_DIR}")
|
||||
|
||||
config = _load_tts_config()
|
||||
provider = _get_provider(config)
|
||||
print(f" Configured provider: {provider}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
from tools.registry import registry
|
||||
|
||||
TTS_SCHEMA = {
|
||||
"name": "text_to_speech",
|
||||
"description": "Convert text to speech audio. Returns a MEDIA: path that the platform delivers as a voice message. On Telegram it plays as a voice bubble, on Discord/WhatsApp as an audio attachment. In CLI mode, saves to ~/voice-memos/. Voice and provider are user-configured, not model-selected.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The text to convert to speech. Keep under 4000 characters."
|
||||
},
|
||||
"output_path": {
|
||||
"type": "string",
|
||||
"description": "Optional custom file path to save the audio. Defaults to ~/.hermes/audio_cache/<timestamp>.mp3"
|
||||
}
|
||||
},
|
||||
"required": ["text"]
|
||||
}
|
||||
}
|
||||
|
||||
registry.register(
|
||||
name="text_to_speech",
|
||||
toolset="tts",
|
||||
schema=TTS_SCHEMA,
|
||||
handler=lambda args, **kw: text_to_speech_tool(
|
||||
text=args.get("text", ""),
|
||||
output_path=args.get("output_path")),
|
||||
check_fn=check_tts_requirements,
|
||||
emoji="🔊",
|
||||
)
|
||||
96
hermes_code/tools/url_safety.py
Normal file
96
hermes_code/tools/url_safety.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
"""URL safety checks — blocks requests to private/internal network addresses.
|
||||
|
||||
Prevents SSRF (Server-Side Request Forgery) where a malicious prompt or
|
||||
skill could trick the agent into fetching internal resources like cloud
|
||||
metadata endpoints (169.254.169.254), localhost services, or private
|
||||
network hosts.
|
||||
|
||||
Limitations (documented, not fixable at pre-flight level):
|
||||
- DNS rebinding (TOCTOU): an attacker-controlled DNS server with TTL=0
|
||||
can return a public IP for the check, then a private IP for the actual
|
||||
connection. Fixing this requires connection-level validation (e.g.
|
||||
Python's Champion library or an egress proxy like Stripe's Smokescreen).
|
||||
- Redirect-based bypass in vision_tools is mitigated by an httpx event
|
||||
hook that re-validates each redirect target. Web tools use third-party
|
||||
SDKs (Firecrawl/Tavily) where redirect handling is on their servers.
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Hostnames that should always be blocked regardless of IP resolution
|
||||
_BLOCKED_HOSTNAMES = frozenset({
|
||||
"metadata.google.internal",
|
||||
"metadata.goog",
|
||||
})
|
||||
|
||||
# 100.64.0.0/10 (CGNAT / Shared Address Space, RFC 6598) is NOT covered by
|
||||
# ipaddress.is_private — it returns False for both is_private and is_global.
|
||||
# Must be blocked explicitly. Used by carrier-grade NAT, Tailscale/WireGuard
|
||||
# VPNs, and some cloud internal networks.
|
||||
_CGNAT_NETWORK = ipaddress.ip_network("100.64.0.0/10")
|
||||
|
||||
|
||||
def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
|
||||
"""Return True if the IP should be blocked for SSRF protection."""
|
||||
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
|
||||
return True
|
||||
if ip.is_multicast or ip.is_unspecified:
|
||||
return True
|
||||
# CGNAT range not covered by is_private
|
||||
if ip in _CGNAT_NETWORK:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_safe_url(url: str) -> bool:
|
||||
"""Return True if the URL target is not a private/internal address.
|
||||
|
||||
Resolves the hostname to an IP and checks against private ranges.
|
||||
Fails closed: DNS errors and unexpected exceptions block the request.
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
hostname = (parsed.hostname or "").strip().lower()
|
||||
if not hostname:
|
||||
return False
|
||||
|
||||
# Block known internal hostnames
|
||||
if hostname in _BLOCKED_HOSTNAMES:
|
||||
logger.warning("Blocked request to internal hostname: %s", hostname)
|
||||
return False
|
||||
|
||||
# Try to resolve and check IP
|
||||
try:
|
||||
addr_info = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
|
||||
except socket.gaierror:
|
||||
# DNS resolution failed — fail closed. If DNS can't resolve it,
|
||||
# the HTTP client will also fail, so blocking loses nothing.
|
||||
logger.warning("Blocked request — DNS resolution failed for: %s", hostname)
|
||||
return False
|
||||
|
||||
for family, _, _, _, sockaddr in addr_info:
|
||||
ip_str = sockaddr[0]
|
||||
try:
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if _is_blocked_ip(ip):
|
||||
logger.warning(
|
||||
"Blocked request to private/internal address: %s -> %s",
|
||||
hostname, ip_str,
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as exc:
|
||||
# Fail closed on unexpected errors — don't let parsing edge cases
|
||||
# become SSRF bypass vectors
|
||||
logger.warning("Blocked request — URL safety check error for %s: %s", url, exc)
|
||||
return False
|
||||
541
hermes_code/tools/vision_tools.py
Normal file
541
hermes_code/tools/vision_tools.py
Normal file
|
|
@ -0,0 +1,541 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Vision Tools Module
|
||||
|
||||
This module provides vision analysis tools that work with image URLs.
|
||||
Uses the centralized auxiliary vision router, which can select OpenRouter,
|
||||
Nous, Codex, native Anthropic, or a custom OpenAI-compatible endpoint.
|
||||
|
||||
Available tools:
|
||||
- vision_analyze_tool: Analyze images from URLs with custom prompts
|
||||
|
||||
Features:
|
||||
- Downloads images from URLs and converts to base64 for API compatibility
|
||||
- Comprehensive image description
|
||||
- Context-aware analysis based on user queries
|
||||
- Automatic temporary file cleanup
|
||||
- Proper error handling and validation
|
||||
- Debug logging support
|
||||
|
||||
Usage:
|
||||
from vision_tools import vision_analyze_tool
|
||||
import asyncio
|
||||
|
||||
# Analyze an image
|
||||
result = await vision_analyze_tool(
|
||||
image_url="https://example.com/image.jpg",
|
||||
user_prompt="What architectural style is this building?"
|
||||
)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Awaitable, Dict, Optional
|
||||
from urllib.parse import urlparse
|
||||
import httpx
|
||||
from agent.auxiliary_client import async_call_llm
|
||||
from tools.debug_helpers import DebugSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_debug = DebugSession("vision_tools", env_var="VISION_TOOLS_DEBUG")
|
||||
|
||||
|
||||
def _validate_image_url(url: str) -> bool:
|
||||
"""
|
||||
Basic validation of image URL format.
|
||||
|
||||
Args:
|
||||
url (str): The URL to validate
|
||||
|
||||
Returns:
|
||||
bool: True if URL appears to be valid, False otherwise
|
||||
"""
|
||||
if not url or not isinstance(url, str):
|
||||
return False
|
||||
|
||||
# Basic HTTP/HTTPS URL check
|
||||
if not (url.startswith("http://") or url.startswith("https://")):
|
||||
return False
|
||||
|
||||
# Parse to ensure we at least have a network location; still allow URLs
|
||||
# without file extensions (e.g. CDN endpoints that redirect to images).
|
||||
parsed = urlparse(url)
|
||||
if not parsed.netloc:
|
||||
return False
|
||||
|
||||
# Block private/internal addresses to prevent SSRF
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(url):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def _download_image(image_url: str, destination: Path, max_retries: int = 3) -> Path:
|
||||
"""
|
||||
Download an image from a URL to a local destination (async) with retry logic.
|
||||
|
||||
Args:
|
||||
image_url (str): The URL of the image to download
|
||||
destination (Path): The path where the image should be saved
|
||||
max_retries (int): Maximum number of retry attempts (default: 3)
|
||||
|
||||
Returns:
|
||||
Path: The path to the downloaded image
|
||||
|
||||
Raises:
|
||||
Exception: If download fails after all retries
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# Create parent directories if they don't exist
|
||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def _ssrf_redirect_guard(response):
|
||||
"""Re-validate each redirect target to prevent redirect-based SSRF.
|
||||
|
||||
Without this, an attacker can host a public URL that 302-redirects
|
||||
to http://169.254.169.254/ and bypass the pre-flight is_safe_url check.
|
||||
|
||||
Must be async because httpx.AsyncClient awaits event hooks.
|
||||
"""
|
||||
if response.is_redirect and response.next_request:
|
||||
redirect_url = str(response.next_request.url)
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(redirect_url):
|
||||
raise ValueError(
|
||||
f"Blocked redirect to private/internal address: {redirect_url}"
|
||||
)
|
||||
|
||||
last_error = None
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
# Download the image with appropriate headers using async httpx
|
||||
# Enable follow_redirects to handle image CDNs that redirect (e.g., Imgur, Picsum)
|
||||
# SSRF: event_hooks validates each redirect target against private IP ranges
|
||||
async with httpx.AsyncClient(
|
||||
timeout=30.0,
|
||||
follow_redirects=True,
|
||||
event_hooks={"response": [_ssrf_redirect_guard]},
|
||||
) as client:
|
||||
response = await client.get(
|
||||
image_url,
|
||||
headers={
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
|
||||
"Accept": "image/*,*/*;q=0.8",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
# Save the image content
|
||||
destination.write_bytes(response.content)
|
||||
|
||||
return destination
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = 2 ** (attempt + 1) # 2s, 4s, 8s
|
||||
logger.warning("Image download failed (attempt %s/%s): %s", attempt + 1, max_retries, str(e)[:50])
|
||||
logger.warning("Retrying in %ss...", wait_time)
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
logger.error(
|
||||
"Image download failed after %s attempts: %s",
|
||||
max_retries,
|
||||
str(e)[:100],
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
raise last_error
|
||||
|
||||
|
||||
def _determine_mime_type(image_path: Path) -> str:
|
||||
"""
|
||||
Determine the MIME type of an image based on its file extension.
|
||||
|
||||
Args:
|
||||
image_path (Path): Path to the image file
|
||||
|
||||
Returns:
|
||||
str: The MIME type (defaults to image/jpeg if unknown)
|
||||
"""
|
||||
extension = image_path.suffix.lower()
|
||||
mime_types = {
|
||||
'.jpg': 'image/jpeg',
|
||||
'.jpeg': 'image/jpeg',
|
||||
'.png': 'image/png',
|
||||
'.gif': 'image/gif',
|
||||
'.bmp': 'image/bmp',
|
||||
'.webp': 'image/webp',
|
||||
'.svg': 'image/svg+xml'
|
||||
}
|
||||
return mime_types.get(extension, 'image/jpeg')
|
||||
|
||||
|
||||
def _image_to_base64_data_url(image_path: Path, mime_type: Optional[str] = None) -> str:
|
||||
"""
|
||||
Convert an image file to a base64-encoded data URL.
|
||||
|
||||
Args:
|
||||
image_path (Path): Path to the image file
|
||||
mime_type (Optional[str]): MIME type of the image (auto-detected if None)
|
||||
|
||||
Returns:
|
||||
str: Base64-encoded data URL (e.g., "data:image/jpeg;base64,...")
|
||||
"""
|
||||
# Read the image as bytes
|
||||
data = image_path.read_bytes()
|
||||
|
||||
# Encode to base64
|
||||
encoded = base64.b64encode(data).decode("ascii")
|
||||
|
||||
# Determine MIME type
|
||||
mime = mime_type or _determine_mime_type(image_path)
|
||||
|
||||
# Create data URL
|
||||
data_url = f"data:{mime};base64,{encoded}"
|
||||
|
||||
return data_url
|
||||
|
||||
|
||||
async def vision_analyze_tool(
|
||||
image_url: str,
|
||||
user_prompt: str,
|
||||
model: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
Analyze an image from a URL or local file path using vision AI.
|
||||
|
||||
This tool accepts either an HTTP/HTTPS URL or a local file path. For URLs,
|
||||
it downloads the image first. In both cases, the image is converted to base64
|
||||
and processed using Gemini 3 Flash Preview via OpenRouter API.
|
||||
|
||||
The user_prompt parameter is expected to be pre-formatted by the calling
|
||||
function (typically model_tools.py) to include both full description
|
||||
requests and specific questions.
|
||||
|
||||
Args:
|
||||
image_url (str): The URL or local file path of the image to analyze.
|
||||
Accepts http://, https:// URLs or absolute/relative file paths.
|
||||
user_prompt (str): The pre-formatted prompt for the vision model
|
||||
model (str): The vision model to use (default: google/gemini-3-flash-preview)
|
||||
|
||||
Returns:
|
||||
str: JSON string containing the analysis results with the following structure:
|
||||
{
|
||||
"success": bool,
|
||||
"analysis": str (defaults to error message if None)
|
||||
}
|
||||
|
||||
Raises:
|
||||
Exception: If download fails, analysis fails, or API key is not set
|
||||
|
||||
Note:
|
||||
- For URLs, temporary images are stored in ./temp_vision_images/ and cleaned up
|
||||
- For local file paths, the file is used directly and NOT deleted
|
||||
- Supports common image formats (JPEG, PNG, GIF, WebP, etc.)
|
||||
"""
|
||||
debug_call_data = {
|
||||
"parameters": {
|
||||
"image_url": image_url,
|
||||
"user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt,
|
||||
"model": model
|
||||
},
|
||||
"error": None,
|
||||
"success": False,
|
||||
"analysis_length": 0,
|
||||
"model_used": model,
|
||||
"image_size_bytes": 0
|
||||
}
|
||||
|
||||
temp_image_path = None
|
||||
# Track whether we should clean up the file after processing.
|
||||
# Local files (e.g. from the image cache) should NOT be deleted.
|
||||
should_cleanup = True
|
||||
|
||||
try:
|
||||
from tools.interrupt import is_interrupted
|
||||
if is_interrupted():
|
||||
return json.dumps({"success": False, "error": "Interrupted"})
|
||||
|
||||
logger.info("Analyzing image: %s", image_url[:60])
|
||||
logger.info("User prompt: %s", user_prompt[:100])
|
||||
|
||||
# Determine if this is a local file path or a remote URL
|
||||
local_path = Path(os.path.expanduser(image_url))
|
||||
if local_path.is_file():
|
||||
# Local file path (e.g. from platform image cache) -- skip download
|
||||
logger.info("Using local image file: %s", image_url)
|
||||
temp_image_path = local_path
|
||||
should_cleanup = False # Don't delete cached/local files
|
||||
elif _validate_image_url(image_url):
|
||||
# Remote URL -- download to a temporary location
|
||||
logger.info("Downloading image from URL...")
|
||||
temp_dir = Path("./temp_vision_images")
|
||||
temp_image_path = temp_dir / f"temp_image_{uuid.uuid4()}.jpg"
|
||||
await _download_image(image_url, temp_image_path)
|
||||
should_cleanup = True
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid image source. Provide an HTTP/HTTPS URL or a valid local file path."
|
||||
)
|
||||
|
||||
# Get image file size for logging
|
||||
image_size_bytes = temp_image_path.stat().st_size
|
||||
image_size_kb = image_size_bytes / 1024
|
||||
logger.info("Image ready (%.1f KB)", image_size_kb)
|
||||
|
||||
# Convert image to base64 data URL
|
||||
logger.info("Converting image to base64...")
|
||||
image_data_url = _image_to_base64_data_url(temp_image_path)
|
||||
# Calculate size in KB for better readability
|
||||
data_size_kb = len(image_data_url) / 1024
|
||||
logger.info("Image converted to base64 (%.1f KB)", data_size_kb)
|
||||
|
||||
debug_call_data["image_size_bytes"] = image_size_bytes
|
||||
|
||||
# Use the prompt as provided (model_tools.py now handles full description formatting)
|
||||
comprehensive_prompt = user_prompt
|
||||
|
||||
# Prepare the message with base64-encoded image
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": comprehensive_prompt
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_data_url
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
logger.info("Processing image with vision model...")
|
||||
|
||||
# Call the vision API via centralized router.
|
||||
# Read timeout from config.yaml (auxiliary.vision.timeout), default 30s.
|
||||
vision_timeout = 30.0
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
_cfg = load_config()
|
||||
_vt = _cfg.get("auxiliary", {}).get("vision", {}).get("timeout")
|
||||
if _vt is not None:
|
||||
vision_timeout = float(_vt)
|
||||
except Exception:
|
||||
pass
|
||||
call_kwargs = {
|
||||
"task": "vision",
|
||||
"messages": messages,
|
||||
"temperature": 0.1,
|
||||
"max_tokens": 2000,
|
||||
"timeout": vision_timeout,
|
||||
}
|
||||
if model:
|
||||
call_kwargs["model"] = model
|
||||
response = await async_call_llm(**call_kwargs)
|
||||
|
||||
# Extract the analysis
|
||||
analysis = response.choices[0].message.content.strip()
|
||||
analysis_length = len(analysis)
|
||||
|
||||
logger.info("Image analysis completed (%s characters)", analysis_length)
|
||||
|
||||
# Prepare successful response
|
||||
result = {
|
||||
"success": True,
|
||||
"analysis": analysis or "There was a problem with the request and the image could not be analyzed."
|
||||
}
|
||||
|
||||
debug_call_data["success"] = True
|
||||
debug_call_data["analysis_length"] = analysis_length
|
||||
|
||||
# Log debug information
|
||||
_debug.log_call("vision_analyze_tool", debug_call_data)
|
||||
_debug.save()
|
||||
|
||||
return json.dumps(result, indent=2, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error analyzing image: {str(e)}"
|
||||
logger.error("%s", error_msg, exc_info=True)
|
||||
|
||||
# Detect vision capability errors — give the model a clear message
|
||||
# so it can inform the user instead of a cryptic API error.
|
||||
err_str = str(e).lower()
|
||||
if any(hint in err_str for hint in (
|
||||
"402", "insufficient", "payment required", "credits", "billing",
|
||||
)):
|
||||
analysis = (
|
||||
"Insufficient credits or payment required. Please top up your "
|
||||
f"API provider account and try again. Error: {e}"
|
||||
)
|
||||
elif any(hint in err_str for hint in (
|
||||
"does not support", "not support image", "invalid_request",
|
||||
"content_policy", "image_url", "multimodal",
|
||||
"unrecognized request argument", "image input",
|
||||
)):
|
||||
analysis = (
|
||||
f"{model} does not support vision or our request was not "
|
||||
f"accepted by the server. Error: {e}"
|
||||
)
|
||||
else:
|
||||
analysis = (
|
||||
"There was a problem with the request and the image could not "
|
||||
f"be analyzed. Error: {e}"
|
||||
)
|
||||
|
||||
# Prepare error response
|
||||
result = {
|
||||
"success": False,
|
||||
"error": error_msg,
|
||||
"analysis": analysis,
|
||||
}
|
||||
|
||||
debug_call_data["error"] = error_msg
|
||||
_debug.log_call("vision_analyze_tool", debug_call_data)
|
||||
_debug.save()
|
||||
|
||||
return json.dumps(result, indent=2, ensure_ascii=False)
|
||||
|
||||
finally:
|
||||
# Clean up temporary image file (but NOT local/cached files)
|
||||
if should_cleanup and temp_image_path and temp_image_path.exists():
|
||||
try:
|
||||
temp_image_path.unlink()
|
||||
logger.debug("Cleaned up temporary image file")
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(
|
||||
"Could not delete temporary file: %s", cleanup_error, exc_info=True
|
||||
)
|
||||
|
||||
|
||||
def check_vision_requirements() -> bool:
|
||||
"""Check if the configured runtime vision path can resolve a client."""
|
||||
try:
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
_provider, client, _model = resolve_vision_provider_client()
|
||||
return client is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def get_debug_session_info() -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about the current debug session.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary containing debug session information
|
||||
"""
|
||||
return _debug.get_session_info()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
Simple test/demo when run directly
|
||||
"""
|
||||
print("👁️ Vision Tools Module")
|
||||
print("=" * 40)
|
||||
|
||||
# Check if vision model is available
|
||||
api_available = check_vision_requirements()
|
||||
|
||||
if not api_available:
|
||||
print("❌ No auxiliary vision model available")
|
||||
print("Configure a supported multimodal backend (OpenRouter, Nous, Codex, Anthropic, or a custom OpenAI-compatible endpoint).")
|
||||
exit(1)
|
||||
else:
|
||||
print("✅ Vision model available")
|
||||
|
||||
print("🛠️ Vision tools ready for use!")
|
||||
|
||||
# Show debug mode status
|
||||
if _debug.active:
|
||||
print(f"🐛 Debug mode ENABLED - Session ID: {_debug.session_id}")
|
||||
print(f" Debug logs will be saved to: ./logs/vision_tools_debug_{_debug.session_id}.json")
|
||||
else:
|
||||
print("🐛 Debug mode disabled (set VISION_TOOLS_DEBUG=true to enable)")
|
||||
|
||||
print("\nBasic usage:")
|
||||
print(" from vision_tools import vision_analyze_tool")
|
||||
print(" import asyncio")
|
||||
print("")
|
||||
print(" async def main():")
|
||||
print(" result = await vision_analyze_tool(")
|
||||
print(" image_url='https://example.com/image.jpg',")
|
||||
print(" user_prompt='What do you see in this image?'")
|
||||
print(" )")
|
||||
print(" print(result)")
|
||||
print(" asyncio.run(main())")
|
||||
|
||||
print("\nExample prompts:")
|
||||
print(" - 'What architectural style is this building?'")
|
||||
print(" - 'Describe the emotions and mood in this image'")
|
||||
print(" - 'What text can you read in this image?'")
|
||||
print(" - 'Identify any safety hazards visible'")
|
||||
print(" - 'What products or brands are shown?'")
|
||||
|
||||
print("\nDebug mode:")
|
||||
print(" # Enable debug logging")
|
||||
print(" export VISION_TOOLS_DEBUG=true")
|
||||
print(" # Debug logs capture all vision analysis calls and results")
|
||||
print(" # Logs saved to: ./logs/vision_tools_debug_UUID.json")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
from tools.registry import registry
|
||||
|
||||
VISION_ANALYZE_SCHEMA = {
|
||||
"name": "vision_analyze",
|
||||
"description": "Analyze images using AI vision. Provides a comprehensive description and answers a specific question about the image content.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image_url": {
|
||||
"type": "string",
|
||||
"description": "Image URL (http/https) or local file path to analyze."
|
||||
},
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "Your specific question or request about the image to resolve. The AI will automatically provide a complete image description AND answer your specific question."
|
||||
}
|
||||
},
|
||||
"required": ["image_url", "question"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _handle_vision_analyze(args: Dict[str, Any], **kw: Any) -> Awaitable[str]:
|
||||
image_url = args.get("image_url", "")
|
||||
question = args.get("question", "")
|
||||
full_prompt = (
|
||||
"Fully describe and explain everything about this image, then answer the "
|
||||
f"following question:\n\n{question}"
|
||||
)
|
||||
model = os.getenv("AUXILIARY_VISION_MODEL", "").strip() or None
|
||||
return vision_analyze_tool(image_url, full_prompt, model)
|
||||
|
||||
|
||||
registry.register(
|
||||
name="vision_analyze",
|
||||
toolset="vision",
|
||||
schema=VISION_ANALYZE_SCHEMA,
|
||||
handler=_handle_vision_analyze,
|
||||
check_fn=check_vision_requirements,
|
||||
is_async=True,
|
||||
emoji="👁️",
|
||||
)
|
||||
793
hermes_code/tools/voice_mode.py
Normal file
793
hermes_code/tools/voice_mode.py
Normal file
|
|
@ -0,0 +1,793 @@
|
|||
"""Voice Mode -- Push-to-talk audio recording and playback for the CLI.
|
||||
|
||||
Provides audio capture via sounddevice, WAV encoding via stdlib wave,
|
||||
STT dispatch via tools.transcription_tools, and TTS playback via
|
||||
sounddevice or system audio players.
|
||||
|
||||
Dependencies (optional):
|
||||
pip install sounddevice numpy
|
||||
or: pip install hermes-agent[voice]
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lazy audio imports -- never imported at module level to avoid crashing
|
||||
# in headless environments (SSH, Docker, WSL, no PortAudio).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _import_audio():
|
||||
"""Lazy-import sounddevice and numpy. Returns (sd, np).
|
||||
|
||||
Raises ImportError or OSError if the libraries are not available
|
||||
(e.g. PortAudio missing on headless servers).
|
||||
"""
|
||||
import sounddevice as sd
|
||||
import numpy as np
|
||||
return sd, np
|
||||
|
||||
|
||||
def _audio_available() -> bool:
|
||||
"""Return True if audio libraries can be imported."""
|
||||
try:
|
||||
_import_audio()
|
||||
return True
|
||||
except (ImportError, OSError):
|
||||
return False
|
||||
|
||||
|
||||
def detect_audio_environment() -> dict:
|
||||
"""Detect if the current environment supports audio I/O.
|
||||
|
||||
Returns dict with 'available' (bool) and 'warnings' (list of strings).
|
||||
"""
|
||||
warnings = []
|
||||
|
||||
# SSH detection
|
||||
if any(os.environ.get(v) for v in ('SSH_CLIENT', 'SSH_TTY', 'SSH_CONNECTION')):
|
||||
warnings.append("Running over SSH -- no audio devices available")
|
||||
|
||||
# Docker detection
|
||||
if os.path.exists('/.dockerenv'):
|
||||
warnings.append("Running inside Docker container -- no audio devices")
|
||||
|
||||
# WSL detection
|
||||
try:
|
||||
with open('/proc/version', 'r') as f:
|
||||
if 'microsoft' in f.read().lower():
|
||||
warnings.append("Running in WSL -- audio requires PulseAudio bridge to Windows")
|
||||
except (FileNotFoundError, PermissionError, OSError):
|
||||
pass
|
||||
|
||||
# Check audio libraries
|
||||
try:
|
||||
sd, _ = _import_audio()
|
||||
try:
|
||||
devices = sd.query_devices()
|
||||
if not devices:
|
||||
warnings.append("No audio input/output devices detected")
|
||||
except Exception:
|
||||
warnings.append("Audio subsystem error (PortAudio cannot query devices)")
|
||||
except ImportError:
|
||||
warnings.append("Audio libraries not installed (pip install sounddevice numpy)")
|
||||
except OSError:
|
||||
warnings.append(
|
||||
"PortAudio system library not found -- install it first:\n"
|
||||
" Linux: sudo apt-get install libportaudio2\n"
|
||||
" macOS: brew install portaudio\n"
|
||||
"Then retry /voice on."
|
||||
)
|
||||
|
||||
return {
|
||||
"available": len(warnings) == 0,
|
||||
"warnings": warnings,
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Recording parameters
|
||||
# ---------------------------------------------------------------------------
|
||||
SAMPLE_RATE = 16000 # Whisper native rate
|
||||
CHANNELS = 1 # Mono
|
||||
DTYPE = "int16" # 16-bit PCM
|
||||
SAMPLE_WIDTH = 2 # bytes per sample (int16)
|
||||
MAX_RECORDING_SECONDS = 120 # Safety cap
|
||||
|
||||
# Silence detection defaults
|
||||
SILENCE_RMS_THRESHOLD = 200 # RMS below this = silence (int16 range 0-32767)
|
||||
SILENCE_DURATION_SECONDS = 3.0 # Seconds of continuous silence before auto-stop
|
||||
|
||||
# Temp directory for voice recordings
|
||||
_TEMP_DIR = os.path.join(tempfile.gettempdir(), "hermes_voice")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Audio cues (beep tones)
|
||||
# ============================================================================
|
||||
def play_beep(frequency: int = 880, duration: float = 0.12, count: int = 1) -> None:
|
||||
"""Play a short beep tone using numpy + sounddevice.
|
||||
|
||||
Args:
|
||||
frequency: Tone frequency in Hz (default 880 = A5).
|
||||
duration: Duration of each beep in seconds.
|
||||
count: Number of beeps to play (with short gap between).
|
||||
"""
|
||||
try:
|
||||
sd, np = _import_audio()
|
||||
except (ImportError, OSError):
|
||||
return
|
||||
try:
|
||||
gap = 0.06 # seconds between beeps
|
||||
samples_per_beep = int(SAMPLE_RATE * duration)
|
||||
samples_per_gap = int(SAMPLE_RATE * gap)
|
||||
|
||||
parts = []
|
||||
for i in range(count):
|
||||
t = np.linspace(0, duration, samples_per_beep, endpoint=False)
|
||||
# Apply fade in/out to avoid click artifacts
|
||||
tone = np.sin(2 * np.pi * frequency * t)
|
||||
fade_len = min(int(SAMPLE_RATE * 0.01), samples_per_beep // 4)
|
||||
tone[:fade_len] *= np.linspace(0, 1, fade_len)
|
||||
tone[-fade_len:] *= np.linspace(1, 0, fade_len)
|
||||
parts.append((tone * 0.3 * 32767).astype(np.int16))
|
||||
if i < count - 1:
|
||||
parts.append(np.zeros(samples_per_gap, dtype=np.int16))
|
||||
|
||||
audio = np.concatenate(parts)
|
||||
sd.play(audio, samplerate=SAMPLE_RATE)
|
||||
# sd.wait() calls Event.wait() without timeout — hangs forever if the
|
||||
# audio device stalls. Poll with a 2s ceiling and force-stop.
|
||||
deadline = time.monotonic() + 2.0
|
||||
while sd.get_stream() and sd.get_stream().active and time.monotonic() < deadline:
|
||||
time.sleep(0.01)
|
||||
sd.stop()
|
||||
except Exception as e:
|
||||
logger.debug("Beep playback failed: %s", e)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# AudioRecorder
|
||||
# ============================================================================
|
||||
class AudioRecorder:
|
||||
"""Thread-safe audio recorder using sounddevice.InputStream.
|
||||
|
||||
Usage::
|
||||
|
||||
recorder = AudioRecorder()
|
||||
recorder.start(on_silence_stop=my_callback)
|
||||
# ... user speaks ...
|
||||
wav_path = recorder.stop() # returns path to WAV file
|
||||
# or
|
||||
recorder.cancel() # discard without saving
|
||||
|
||||
If ``on_silence_stop`` is provided, recording automatically stops when
|
||||
the user is silent for ``silence_duration`` seconds and calls the callback.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.Lock()
|
||||
self._stream: Any = None
|
||||
self._frames: List[Any] = []
|
||||
self._recording = False
|
||||
self._start_time: float = 0.0
|
||||
# Silence detection state
|
||||
self._has_spoken = False
|
||||
self._speech_start: float = 0.0 # When speech attempt began
|
||||
self._dip_start: float = 0.0 # When current below-threshold dip began
|
||||
self._min_speech_duration: float = 0.3 # Seconds of speech needed to confirm
|
||||
self._max_dip_tolerance: float = 0.3 # Max dip duration before resetting speech
|
||||
self._silence_start: float = 0.0
|
||||
self._resume_start: float = 0.0 # Tracks sustained speech after silence starts
|
||||
self._resume_dip_start: float = 0.0 # Dip tolerance tracker for resume detection
|
||||
self._on_silence_stop = None
|
||||
self._silence_threshold: int = SILENCE_RMS_THRESHOLD
|
||||
self._silence_duration: float = SILENCE_DURATION_SECONDS
|
||||
self._max_wait: float = 15.0 # Max seconds to wait for speech before auto-stop
|
||||
# Peak RMS seen during recording (for speech presence check in stop())
|
||||
self._peak_rms: int = 0
|
||||
# Live audio level (read by UI for visual feedback)
|
||||
self._current_rms: int = 0
|
||||
|
||||
# -- public properties ---------------------------------------------------
|
||||
|
||||
@property
|
||||
def is_recording(self) -> bool:
|
||||
return self._recording
|
||||
|
||||
@property
|
||||
def elapsed_seconds(self) -> float:
|
||||
if not self._recording:
|
||||
return 0.0
|
||||
return time.monotonic() - self._start_time
|
||||
|
||||
@property
|
||||
def current_rms(self) -> int:
|
||||
"""Current audio input RMS level (0-32767). Updated each audio chunk."""
|
||||
return self._current_rms
|
||||
|
||||
# -- public methods ------------------------------------------------------
|
||||
|
||||
def _ensure_stream(self) -> None:
|
||||
"""Create the audio InputStream once and keep it alive.
|
||||
|
||||
The stream stays open for the lifetime of the recorder. Between
|
||||
recordings the callback simply discards audio chunks (``_recording``
|
||||
is ``False``). This avoids the CoreAudio bug where closing and
|
||||
re-opening an ``InputStream`` hangs indefinitely on macOS.
|
||||
"""
|
||||
if self._stream is not None:
|
||||
return # already alive
|
||||
|
||||
sd, np = _import_audio()
|
||||
|
||||
def _callback(indata, frames, time_info, status): # noqa: ARG001
|
||||
if status:
|
||||
logger.debug("sounddevice status: %s", status)
|
||||
# When not recording the stream is idle — discard audio.
|
||||
if not self._recording:
|
||||
return
|
||||
self._frames.append(indata.copy())
|
||||
|
||||
# Compute RMS for level display and silence detection
|
||||
rms = int(np.sqrt(np.mean(indata.astype(np.float64) ** 2)))
|
||||
self._current_rms = rms
|
||||
if rms > self._peak_rms:
|
||||
self._peak_rms = rms
|
||||
|
||||
# Silence detection
|
||||
if self._on_silence_stop is not None:
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._start_time
|
||||
|
||||
if rms > self._silence_threshold:
|
||||
# Audio is above threshold -- this is speech (or noise).
|
||||
self._dip_start = 0.0 # Reset dip tracker
|
||||
if self._speech_start == 0.0:
|
||||
self._speech_start = now
|
||||
elif not self._has_spoken and now - self._speech_start >= self._min_speech_duration:
|
||||
self._has_spoken = True
|
||||
logger.debug("Speech confirmed (%.2fs above threshold)",
|
||||
now - self._speech_start)
|
||||
# After speech is confirmed, only reset silence timer if
|
||||
# speech is sustained (>0.3s above threshold). Brief
|
||||
# spikes from ambient noise should NOT reset the timer.
|
||||
if not self._has_spoken:
|
||||
self._silence_start = 0.0
|
||||
else:
|
||||
# Track resumed speech with dip tolerance.
|
||||
# Brief dips below threshold are normal during speech,
|
||||
# so we mirror the initial speech detection pattern:
|
||||
# start tracking, tolerate short dips, confirm after 0.3s.
|
||||
self._resume_dip_start = 0.0 # Above threshold — no dip
|
||||
if self._resume_start == 0.0:
|
||||
self._resume_start = now
|
||||
elif now - self._resume_start >= self._min_speech_duration:
|
||||
self._silence_start = 0.0
|
||||
self._resume_start = 0.0
|
||||
elif self._has_spoken:
|
||||
# Below threshold after speech confirmed.
|
||||
# Use dip tolerance before resetting resume tracker —
|
||||
# natural speech has brief dips below threshold.
|
||||
if self._resume_start > 0:
|
||||
if self._resume_dip_start == 0.0:
|
||||
self._resume_dip_start = now
|
||||
elif now - self._resume_dip_start >= self._max_dip_tolerance:
|
||||
# Sustained dip — user actually stopped speaking
|
||||
self._resume_start = 0.0
|
||||
self._resume_dip_start = 0.0
|
||||
elif self._speech_start > 0:
|
||||
# We were in a speech attempt but RMS dipped.
|
||||
# Tolerate brief dips (micro-pauses between syllables).
|
||||
if self._dip_start == 0.0:
|
||||
self._dip_start = now
|
||||
elif now - self._dip_start >= self._max_dip_tolerance:
|
||||
# Dip lasted too long -- genuine silence, reset
|
||||
logger.debug("Speech attempt reset (dip lasted %.2fs)",
|
||||
now - self._dip_start)
|
||||
self._speech_start = 0.0
|
||||
self._dip_start = 0.0
|
||||
|
||||
# Fire silence callback when:
|
||||
# 1. User spoke then went silent for silence_duration, OR
|
||||
# 2. No speech detected at all for max_wait seconds
|
||||
should_fire = False
|
||||
if self._has_spoken and rms <= self._silence_threshold:
|
||||
# User was speaking and now is silent
|
||||
if self._silence_start == 0.0:
|
||||
self._silence_start = now
|
||||
elif now - self._silence_start >= self._silence_duration:
|
||||
logger.info("Silence detected (%.1fs), auto-stopping",
|
||||
self._silence_duration)
|
||||
should_fire = True
|
||||
elif not self._has_spoken and elapsed >= self._max_wait:
|
||||
logger.info("No speech within %.0fs, auto-stopping",
|
||||
self._max_wait)
|
||||
should_fire = True
|
||||
|
||||
if should_fire:
|
||||
with self._lock:
|
||||
cb = self._on_silence_stop
|
||||
self._on_silence_stop = None # fire only once
|
||||
if cb:
|
||||
def _safe_cb():
|
||||
try:
|
||||
cb()
|
||||
except Exception as e:
|
||||
logger.error("Silence callback failed: %s", e, exc_info=True)
|
||||
threading.Thread(target=_safe_cb, daemon=True).start()
|
||||
|
||||
# Create stream — may block on CoreAudio (first call only).
|
||||
stream = None
|
||||
try:
|
||||
stream = sd.InputStream(
|
||||
samplerate=SAMPLE_RATE,
|
||||
channels=CHANNELS,
|
||||
dtype=DTYPE,
|
||||
callback=_callback,
|
||||
)
|
||||
stream.start()
|
||||
except Exception as e:
|
||||
if stream is not None:
|
||||
try:
|
||||
stream.close()
|
||||
except Exception:
|
||||
pass
|
||||
raise RuntimeError(
|
||||
f"Failed to open audio input stream: {e}. "
|
||||
"Check that a microphone is connected and accessible."
|
||||
) from e
|
||||
self._stream = stream
|
||||
|
||||
def start(self, on_silence_stop=None) -> None:
|
||||
"""Start capturing audio from the default input device.
|
||||
|
||||
The underlying InputStream is created once and kept alive across
|
||||
recordings. Subsequent calls simply reset detection state and
|
||||
toggle frame collection via ``_recording``.
|
||||
|
||||
Args:
|
||||
on_silence_stop: Optional callback invoked (in a daemon thread) when
|
||||
silence is detected after speech. The callback receives no arguments.
|
||||
Use this to auto-stop recording and trigger transcription.
|
||||
|
||||
Raises ``RuntimeError`` if sounddevice/numpy are not installed
|
||||
or if a recording is already in progress.
|
||||
"""
|
||||
try:
|
||||
_import_audio()
|
||||
except (ImportError, OSError) as e:
|
||||
raise RuntimeError(
|
||||
"Voice mode requires sounddevice and numpy.\n"
|
||||
"Install with: pip install sounddevice numpy\n"
|
||||
"Or: pip install hermes-agent[voice]"
|
||||
) from e
|
||||
|
||||
with self._lock:
|
||||
if self._recording:
|
||||
return # already recording
|
||||
|
||||
self._frames = []
|
||||
self._start_time = time.monotonic()
|
||||
self._has_spoken = False
|
||||
self._speech_start = 0.0
|
||||
self._dip_start = 0.0
|
||||
self._silence_start = 0.0
|
||||
self._resume_start = 0.0
|
||||
self._resume_dip_start = 0.0
|
||||
self._peak_rms = 0
|
||||
self._current_rms = 0
|
||||
self._on_silence_stop = on_silence_stop
|
||||
|
||||
# Ensure the persistent stream is alive (no-op after first call).
|
||||
self._ensure_stream()
|
||||
|
||||
with self._lock:
|
||||
self._recording = True
|
||||
logger.info("Voice recording started (rate=%d, channels=%d)", SAMPLE_RATE, CHANNELS)
|
||||
|
||||
def _close_stream_with_timeout(self, timeout: float = 3.0) -> None:
|
||||
"""Close the audio stream with a timeout to prevent CoreAudio hangs."""
|
||||
if self._stream is None:
|
||||
return
|
||||
|
||||
stream = self._stream
|
||||
self._stream = None
|
||||
|
||||
def _do_close():
|
||||
try:
|
||||
stream.stop()
|
||||
stream.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
t = threading.Thread(target=_do_close, daemon=True)
|
||||
t.start()
|
||||
# Poll in short intervals so Ctrl+C is not blocked
|
||||
deadline = __import__("time").monotonic() + timeout
|
||||
while t.is_alive() and __import__("time").monotonic() < deadline:
|
||||
t.join(timeout=0.1)
|
||||
if t.is_alive():
|
||||
logger.warning("Audio stream close timed out after %.1fs — forcing ahead", timeout)
|
||||
|
||||
def stop(self) -> Optional[str]:
|
||||
"""Stop recording and write captured audio to a WAV file.
|
||||
|
||||
The underlying stream is kept alive for reuse — only frame
|
||||
collection is stopped.
|
||||
|
||||
Returns:
|
||||
Path to the WAV file, or ``None`` if no audio was captured.
|
||||
"""
|
||||
with self._lock:
|
||||
if not self._recording:
|
||||
return None
|
||||
|
||||
self._recording = False
|
||||
self._current_rms = 0
|
||||
# Stream stays alive — no close needed.
|
||||
|
||||
if not self._frames:
|
||||
return None
|
||||
|
||||
# Concatenate frames and write WAV
|
||||
_, np = _import_audio()
|
||||
audio_data = np.concatenate(self._frames, axis=0)
|
||||
self._frames = []
|
||||
|
||||
elapsed = time.monotonic() - self._start_time
|
||||
logger.info("Voice recording stopped (%.1fs, %d samples)", elapsed, len(audio_data))
|
||||
|
||||
# Skip very short recordings (< 0.3s of audio)
|
||||
min_samples = int(SAMPLE_RATE * 0.3)
|
||||
if len(audio_data) < min_samples:
|
||||
logger.debug("Recording too short (%d samples), discarding", len(audio_data))
|
||||
return None
|
||||
|
||||
# Skip silent recordings using peak RMS (not overall average, which
|
||||
# gets diluted by silence at the end of the recording).
|
||||
if self._peak_rms < SILENCE_RMS_THRESHOLD:
|
||||
logger.info("Recording too quiet (peak RMS=%d < %d), discarding",
|
||||
self._peak_rms, SILENCE_RMS_THRESHOLD)
|
||||
return None
|
||||
|
||||
return self._write_wav(audio_data)
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Stop recording and discard all captured audio.
|
||||
|
||||
The underlying stream is kept alive for reuse.
|
||||
"""
|
||||
with self._lock:
|
||||
self._recording = False
|
||||
self._frames = []
|
||||
self._on_silence_stop = None
|
||||
self._current_rms = 0
|
||||
logger.info("Voice recording cancelled")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Release the audio stream. Call when voice mode is disabled."""
|
||||
with self._lock:
|
||||
self._recording = False
|
||||
self._frames = []
|
||||
self._on_silence_stop = None
|
||||
# Close stream OUTSIDE the lock to avoid deadlock with audio callback
|
||||
self._close_stream_with_timeout()
|
||||
logger.info("AudioRecorder shut down")
|
||||
|
||||
# -- private helpers -----------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _write_wav(audio_data) -> str:
|
||||
"""Write numpy int16 audio data to a WAV file.
|
||||
|
||||
Returns the file path.
|
||||
"""
|
||||
os.makedirs(_TEMP_DIR, exist_ok=True)
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
wav_path = os.path.join(_TEMP_DIR, f"recording_{timestamp}.wav")
|
||||
|
||||
with wave.open(wav_path, "wb") as wf:
|
||||
wf.setnchannels(CHANNELS)
|
||||
wf.setsampwidth(SAMPLE_WIDTH)
|
||||
wf.setframerate(SAMPLE_RATE)
|
||||
wf.writeframes(audio_data.tobytes())
|
||||
|
||||
file_size = os.path.getsize(wav_path)
|
||||
logger.info("WAV written: %s (%d bytes)", wav_path, file_size)
|
||||
return wav_path
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Whisper hallucination filter
|
||||
# ============================================================================
|
||||
# Whisper commonly hallucinates these phrases on silent/near-silent audio.
|
||||
WHISPER_HALLUCINATIONS = {
|
||||
"thank you.",
|
||||
"thank you",
|
||||
"thanks for watching.",
|
||||
"thanks for watching",
|
||||
"subscribe to my channel.",
|
||||
"subscribe to my channel",
|
||||
"like and subscribe.",
|
||||
"like and subscribe",
|
||||
"please subscribe.",
|
||||
"please subscribe",
|
||||
"thank you for watching.",
|
||||
"thank you for watching",
|
||||
"bye.",
|
||||
"bye",
|
||||
"you",
|
||||
"the end.",
|
||||
"the end",
|
||||
# Non-English hallucinations (common on silence)
|
||||
"продолжение следует",
|
||||
"продолжение следует...",
|
||||
"sous-titres",
|
||||
"sous-titres réalisés par la communauté d'amara.org",
|
||||
"sottotitoli creati dalla comunità amara.org",
|
||||
"untertitel von stephanie geiges",
|
||||
"amara.org",
|
||||
"www.mooji.org",
|
||||
"ご視聴ありがとうございました",
|
||||
}
|
||||
|
||||
# Regex patterns for repetitive hallucinations (e.g. "Thank you. Thank you. Thank you.")
|
||||
_HALLUCINATION_REPEAT_RE = re.compile(
|
||||
r'^(?:thank you|thanks|bye|you|ok|okay|the end|\.|\s|,|!)+$',
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def is_whisper_hallucination(transcript: str) -> bool:
|
||||
"""Check if a transcript is a known Whisper hallucination on silence."""
|
||||
cleaned = transcript.strip().lower()
|
||||
if not cleaned:
|
||||
return True
|
||||
# Exact match against known phrases
|
||||
if cleaned.rstrip('.!') in WHISPER_HALLUCINATIONS or cleaned in WHISPER_HALLUCINATIONS:
|
||||
return True
|
||||
# Repetitive patterns (e.g. "Thank you. Thank you. Thank you. you")
|
||||
if _HALLUCINATION_REPEAT_RE.match(cleaned):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# STT dispatch
|
||||
# ============================================================================
|
||||
def transcribe_recording(wav_path: str, model: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Transcribe a WAV recording using the existing Whisper pipeline.
|
||||
|
||||
Delegates to ``tools.transcription_tools.transcribe_audio()``.
|
||||
Filters out known Whisper hallucinations on silent audio.
|
||||
|
||||
Args:
|
||||
wav_path: Path to the WAV file.
|
||||
model: Whisper model name (default: from config or ``whisper-1``).
|
||||
|
||||
Returns:
|
||||
Dict with ``success``, ``transcript``, and optionally ``error``.
|
||||
"""
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
|
||||
result = transcribe_audio(wav_path, model=model)
|
||||
|
||||
# Filter out Whisper hallucinations (common on silent/near-silent audio)
|
||||
if result.get("success") and is_whisper_hallucination(result.get("transcript", "")):
|
||||
logger.info("Filtered Whisper hallucination: %r", result["transcript"])
|
||||
return {"success": True, "transcript": "", "filtered": True}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Audio playback (interruptable)
|
||||
# ============================================================================
|
||||
|
||||
# Global reference to the active playback process so it can be interrupted.
|
||||
_active_playback: Optional[subprocess.Popen] = None
|
||||
_playback_lock = threading.Lock()
|
||||
|
||||
|
||||
def stop_playback() -> None:
|
||||
"""Interrupt the currently playing audio (if any)."""
|
||||
global _active_playback
|
||||
with _playback_lock:
|
||||
proc = _active_playback
|
||||
_active_playback = None
|
||||
if proc and proc.poll() is None:
|
||||
try:
|
||||
proc.terminate()
|
||||
logger.info("Audio playback interrupted")
|
||||
except Exception:
|
||||
pass
|
||||
# Also stop sounddevice playback if active
|
||||
try:
|
||||
sd, _ = _import_audio()
|
||||
sd.stop()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def play_audio_file(file_path: str) -> bool:
|
||||
"""Play an audio file through the default output device.
|
||||
|
||||
Strategy:
|
||||
1. WAV files via ``sounddevice.play()`` when available.
|
||||
2. System commands: ``afplay`` (macOS), ``ffplay`` (cross-platform),
|
||||
``aplay`` (Linux ALSA).
|
||||
|
||||
Playback can be interrupted by calling ``stop_playback()``.
|
||||
|
||||
Returns:
|
||||
``True`` if playback succeeded, ``False`` otherwise.
|
||||
"""
|
||||
global _active_playback
|
||||
|
||||
if not os.path.isfile(file_path):
|
||||
logger.warning("Audio file not found: %s", file_path)
|
||||
return False
|
||||
|
||||
# Try sounddevice for WAV files
|
||||
if file_path.endswith(".wav"):
|
||||
try:
|
||||
sd, np = _import_audio()
|
||||
with wave.open(file_path, "rb") as wf:
|
||||
frames = wf.readframes(wf.getnframes())
|
||||
audio_data = np.frombuffer(frames, dtype=np.int16)
|
||||
sample_rate = wf.getframerate()
|
||||
|
||||
sd.play(audio_data, samplerate=sample_rate)
|
||||
# sd.wait() calls Event.wait() without timeout — hangs forever if
|
||||
# the audio device stalls. Poll with a ceiling and force-stop.
|
||||
duration_secs = len(audio_data) / sample_rate
|
||||
deadline = time.monotonic() + duration_secs + 2.0
|
||||
while sd.get_stream() and sd.get_stream().active and time.monotonic() < deadline:
|
||||
time.sleep(0.01)
|
||||
sd.stop()
|
||||
return True
|
||||
except (ImportError, OSError):
|
||||
pass # audio libs not available, fall through to system players
|
||||
except Exception as e:
|
||||
logger.debug("sounddevice playback failed: %s", e)
|
||||
|
||||
# Fall back to system audio players (using Popen for interruptability)
|
||||
system = platform.system()
|
||||
players = []
|
||||
|
||||
if system == "Darwin":
|
||||
players.append(["afplay", file_path])
|
||||
players.append(["ffplay", "-nodisp", "-autoexit", "-loglevel", "quiet", file_path])
|
||||
if system == "Linux":
|
||||
players.append(["aplay", "-q", file_path])
|
||||
|
||||
for cmd in players:
|
||||
exe = shutil.which(cmd[0])
|
||||
if exe:
|
||||
try:
|
||||
proc = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
with _playback_lock:
|
||||
_active_playback = proc
|
||||
proc.wait(timeout=300)
|
||||
with _playback_lock:
|
||||
_active_playback = None
|
||||
return True
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("System player %s timed out, killing process", cmd[0])
|
||||
proc.kill()
|
||||
proc.wait()
|
||||
with _playback_lock:
|
||||
_active_playback = None
|
||||
except Exception as e:
|
||||
logger.debug("System player %s failed: %s", cmd[0], e)
|
||||
with _playback_lock:
|
||||
_active_playback = None
|
||||
|
||||
logger.warning("No audio player available for %s", file_path)
|
||||
return False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Requirements check
|
||||
# ============================================================================
|
||||
def check_voice_requirements() -> Dict[str, Any]:
|
||||
"""Check if all voice mode requirements are met.
|
||||
|
||||
Returns:
|
||||
Dict with ``available``, ``audio_available``, ``stt_available``,
|
||||
``missing_packages``, and ``details``.
|
||||
"""
|
||||
# Determine STT provider availability
|
||||
from tools.transcription_tools import _get_provider, _load_stt_config, is_stt_enabled, _HAS_FASTER_WHISPER
|
||||
stt_config = _load_stt_config()
|
||||
stt_enabled = is_stt_enabled(stt_config)
|
||||
stt_provider = _get_provider(stt_config)
|
||||
stt_available = stt_enabled and stt_provider != "none"
|
||||
|
||||
missing: List[str] = []
|
||||
has_audio = _audio_available()
|
||||
|
||||
if not has_audio:
|
||||
missing.extend(["sounddevice", "numpy"])
|
||||
|
||||
# Environment detection
|
||||
env_check = detect_audio_environment()
|
||||
|
||||
available = has_audio and stt_available and env_check["available"]
|
||||
details_parts = []
|
||||
|
||||
if has_audio:
|
||||
details_parts.append("Audio capture: OK")
|
||||
else:
|
||||
details_parts.append("Audio capture: MISSING (pip install sounddevice numpy)")
|
||||
|
||||
if not stt_enabled:
|
||||
details_parts.append("STT provider: DISABLED in config (stt.enabled: false)")
|
||||
elif stt_provider == "local":
|
||||
details_parts.append("STT provider: OK (local faster-whisper)")
|
||||
elif stt_provider == "groq":
|
||||
details_parts.append("STT provider: OK (Groq)")
|
||||
elif stt_provider == "openai":
|
||||
details_parts.append("STT provider: OK (OpenAI)")
|
||||
else:
|
||||
details_parts.append(
|
||||
"STT provider: MISSING (pip install faster-whisper, "
|
||||
"or set GROQ_API_KEY / VOICE_TOOLS_OPENAI_KEY)"
|
||||
)
|
||||
|
||||
for warning in env_check["warnings"]:
|
||||
details_parts.append(f"Environment: {warning}")
|
||||
|
||||
return {
|
||||
"available": available,
|
||||
"audio_available": has_audio,
|
||||
"stt_available": stt_available,
|
||||
"missing_packages": missing,
|
||||
"details": "\n".join(details_parts),
|
||||
"environment": env_check,
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Temp file cleanup
|
||||
# ============================================================================
|
||||
def cleanup_temp_recordings(max_age_seconds: int = 3600) -> int:
|
||||
"""Remove old temporary voice recording files.
|
||||
|
||||
Args:
|
||||
max_age_seconds: Delete files older than this (default: 1 hour).
|
||||
|
||||
Returns:
|
||||
Number of files deleted.
|
||||
"""
|
||||
if not os.path.isdir(_TEMP_DIR):
|
||||
return 0
|
||||
|
||||
deleted = 0
|
||||
now = time.time()
|
||||
|
||||
for entry in os.scandir(_TEMP_DIR):
|
||||
if entry.is_file() and entry.name.startswith("recording_") and entry.name.endswith(".wav"):
|
||||
try:
|
||||
age = now - entry.stat().st_mtime
|
||||
if age > max_age_seconds:
|
||||
os.unlink(entry.path)
|
||||
deleted += 1
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
if deleted:
|
||||
logger.debug("Cleaned up %d old voice recordings", deleted)
|
||||
return deleted
|
||||
1727
hermes_code/tools/web_tools.py
Normal file
1727
hermes_code/tools/web_tools.py
Normal file
File diff suppressed because it is too large
Load diff
285
hermes_code/tools/website_policy.py
Normal file
285
hermes_code/tools/website_policy.py
Normal file
|
|
@ -0,0 +1,285 @@
|
|||
"""Website access policy helpers for URL-capable tools.
|
||||
|
||||
This module loads a user-managed website blocklist from ~/.hermes/config.yaml
|
||||
and optional shared list files. It is intentionally lightweight so web/browser
|
||||
tools can enforce URL policy without pulling in the heavier CLI config stack.
|
||||
|
||||
Policy is cached in memory with a short TTL so config changes take effect
|
||||
quickly without re-reading the file on every URL check.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import fnmatch
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_WEBSITE_BLOCKLIST = {
|
||||
"enabled": False,
|
||||
"domains": [],
|
||||
"shared_files": [],
|
||||
}
|
||||
|
||||
# Cache: parsed policy + timestamp. Avoids re-reading config.yaml on every
|
||||
# URL check (a web_crawl with 50 pages would otherwise mean 51 YAML parses).
|
||||
_CACHE_TTL_SECONDS = 30.0
|
||||
_cache_lock = threading.Lock()
|
||||
_cached_policy: Optional[Dict[str, Any]] = None
|
||||
_cached_policy_path: Optional[str] = None
|
||||
_cached_policy_time: float = 0.0
|
||||
|
||||
|
||||
def _get_hermes_home() -> Path:
|
||||
return Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
|
||||
|
||||
def _get_default_config_path() -> Path:
|
||||
return _get_hermes_home() / "config.yaml"
|
||||
|
||||
|
||||
class WebsitePolicyError(Exception):
|
||||
"""Raised when a website policy file is malformed."""
|
||||
|
||||
|
||||
def _normalize_host(host: str) -> str:
|
||||
return (host or "").strip().lower().rstrip(".")
|
||||
|
||||
|
||||
def _normalize_rule(rule: Any) -> Optional[str]:
|
||||
if not isinstance(rule, str):
|
||||
return None
|
||||
value = rule.strip().lower()
|
||||
if not value or value.startswith("#"):
|
||||
return None
|
||||
if "://" in value:
|
||||
parsed = urlparse(value)
|
||||
value = parsed.netloc or parsed.path
|
||||
value = value.split("/", 1)[0].strip().rstrip(".")
|
||||
if value.startswith("www."):
|
||||
value = value[4:]
|
||||
return value or None
|
||||
|
||||
|
||||
def _iter_blocklist_file_rules(path: Path) -> List[str]:
|
||||
"""Load rules from a shared blocklist file.
|
||||
|
||||
Missing or unreadable files log a warning and return an empty list
|
||||
rather than raising — a bad file path should not disable all web tools.
|
||||
"""
|
||||
try:
|
||||
raw = path.read_text(encoding="utf-8")
|
||||
except FileNotFoundError:
|
||||
logger.warning("Shared blocklist file not found (skipping): %s", path)
|
||||
return []
|
||||
except (OSError, UnicodeDecodeError) as exc:
|
||||
logger.warning("Failed to read shared blocklist file %s (skipping): %s", path, exc)
|
||||
return []
|
||||
|
||||
rules: List[str] = []
|
||||
for line in raw.splitlines():
|
||||
stripped = line.strip()
|
||||
if not stripped or stripped.startswith("#"):
|
||||
continue
|
||||
normalized = _normalize_rule(stripped)
|
||||
if normalized:
|
||||
rules.append(normalized)
|
||||
return rules
|
||||
|
||||
|
||||
def _load_policy_config(config_path: Optional[Path] = None) -> Dict[str, Any]:
|
||||
config_path = config_path or _get_default_config_path()
|
||||
if not config_path.exists():
|
||||
return dict(_DEFAULT_WEBSITE_BLOCKLIST)
|
||||
|
||||
try:
|
||||
import yaml
|
||||
except ImportError:
|
||||
logger.debug("PyYAML not installed — website blocklist disabled")
|
||||
return dict(_DEFAULT_WEBSITE_BLOCKLIST)
|
||||
|
||||
try:
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
config = yaml.safe_load(f) or {}
|
||||
except yaml.YAMLError as exc:
|
||||
raise WebsitePolicyError(f"Invalid config YAML at {config_path}: {exc}") from exc
|
||||
except OSError as exc:
|
||||
raise WebsitePolicyError(f"Failed to read config file {config_path}: {exc}") from exc
|
||||
if not isinstance(config, dict):
|
||||
raise WebsitePolicyError("config root must be a mapping")
|
||||
|
||||
security = config.get("security", {})
|
||||
if security is None:
|
||||
security = {}
|
||||
if not isinstance(security, dict):
|
||||
raise WebsitePolicyError("security must be a mapping")
|
||||
|
||||
website_blocklist = security.get("website_blocklist", {})
|
||||
if website_blocklist is None:
|
||||
website_blocklist = {}
|
||||
if not isinstance(website_blocklist, dict):
|
||||
raise WebsitePolicyError("security.website_blocklist must be a mapping")
|
||||
|
||||
policy = dict(_DEFAULT_WEBSITE_BLOCKLIST)
|
||||
policy.update(website_blocklist)
|
||||
return policy
|
||||
|
||||
|
||||
def load_website_blocklist(config_path: Optional[Path] = None) -> Dict[str, Any]:
|
||||
"""Load and return the parsed website blocklist policy.
|
||||
|
||||
Results are cached for ``_CACHE_TTL_SECONDS`` to avoid re-reading
|
||||
config.yaml on every URL check. Pass an explicit ``config_path``
|
||||
to bypass the cache (used by tests).
|
||||
"""
|
||||
global _cached_policy, _cached_policy_path, _cached_policy_time
|
||||
|
||||
resolved_path = str(config_path) if config_path else "__default__"
|
||||
now = time.monotonic()
|
||||
|
||||
# Return cached policy if still fresh and same path
|
||||
if config_path is None:
|
||||
with _cache_lock:
|
||||
if (
|
||||
_cached_policy is not None
|
||||
and _cached_policy_path == resolved_path
|
||||
and (now - _cached_policy_time) < _CACHE_TTL_SECONDS
|
||||
):
|
||||
return _cached_policy
|
||||
|
||||
config_path = config_path or _get_default_config_path()
|
||||
policy = _load_policy_config(config_path)
|
||||
|
||||
raw_domains = policy.get("domains", []) or []
|
||||
if not isinstance(raw_domains, list):
|
||||
raise WebsitePolicyError("security.website_blocklist.domains must be a list")
|
||||
|
||||
raw_shared_files = policy.get("shared_files", []) or []
|
||||
if not isinstance(raw_shared_files, list):
|
||||
raise WebsitePolicyError("security.website_blocklist.shared_files must be a list")
|
||||
|
||||
enabled = policy.get("enabled", True)
|
||||
if not isinstance(enabled, bool):
|
||||
raise WebsitePolicyError("security.website_blocklist.enabled must be a boolean")
|
||||
|
||||
rules: List[Dict[str, str]] = []
|
||||
seen: set[Tuple[str, str]] = set()
|
||||
|
||||
for raw_rule in raw_domains:
|
||||
normalized = _normalize_rule(raw_rule)
|
||||
if normalized and ("config", normalized) not in seen:
|
||||
rules.append({"pattern": normalized, "source": "config"})
|
||||
seen.add(("config", normalized))
|
||||
|
||||
for shared_file in raw_shared_files:
|
||||
if not isinstance(shared_file, str) or not shared_file.strip():
|
||||
continue
|
||||
path = Path(shared_file).expanduser()
|
||||
if not path.is_absolute():
|
||||
path = (_get_hermes_home() / path).resolve()
|
||||
for normalized in _iter_blocklist_file_rules(path):
|
||||
key = (str(path), normalized)
|
||||
if key in seen:
|
||||
continue
|
||||
rules.append({"pattern": normalized, "source": str(path)})
|
||||
seen.add(key)
|
||||
|
||||
result = {"enabled": enabled, "rules": rules}
|
||||
|
||||
# Cache the result (only for the default path — explicit paths are tests)
|
||||
if config_path == _get_default_config_path():
|
||||
with _cache_lock:
|
||||
_cached_policy = result
|
||||
_cached_policy_path = "__default__"
|
||||
_cached_policy_time = now
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def invalidate_cache() -> None:
|
||||
"""Force the next ``check_website_access`` call to re-read config."""
|
||||
global _cached_policy
|
||||
with _cache_lock:
|
||||
_cached_policy = None
|
||||
|
||||
|
||||
def _match_host_against_rule(host: str, pattern: str) -> bool:
|
||||
if not host or not pattern:
|
||||
return False
|
||||
if pattern.startswith("*."):
|
||||
return fnmatch.fnmatch(host, pattern)
|
||||
return host == pattern or host.endswith(f".{pattern}")
|
||||
|
||||
|
||||
def _extract_host_from_urlish(url: str) -> str:
|
||||
parsed = urlparse(url)
|
||||
host = _normalize_host(parsed.hostname or parsed.netloc)
|
||||
if host:
|
||||
return host
|
||||
|
||||
if "://" not in url:
|
||||
schemeless = urlparse(f"//{url}")
|
||||
host = _normalize_host(schemeless.hostname or schemeless.netloc)
|
||||
if host:
|
||||
return host
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def check_website_access(url: str, config_path: Optional[Path] = None) -> Optional[Dict[str, str]]:
|
||||
"""Check whether a URL is allowed by the website blocklist policy.
|
||||
|
||||
Returns ``None`` if access is allowed, or a dict with block metadata
|
||||
(``host``, ``rule``, ``source``, ``message``) if blocked.
|
||||
|
||||
Never raises on policy errors — logs a warning and returns ``None``
|
||||
(fail-open) so a config typo doesn't break all web tools. Pass
|
||||
``config_path`` explicitly (tests) to get strict error propagation.
|
||||
"""
|
||||
# Fast path: if no explicit config_path and the cached policy is disabled
|
||||
# or empty, skip all work (no YAML read, no host extraction).
|
||||
if config_path is None:
|
||||
with _cache_lock:
|
||||
if _cached_policy is not None and not _cached_policy.get("enabled"):
|
||||
return None
|
||||
|
||||
host = _extract_host_from_urlish(url)
|
||||
if not host:
|
||||
return None
|
||||
|
||||
try:
|
||||
policy = load_website_blocklist(config_path)
|
||||
except WebsitePolicyError as exc:
|
||||
if config_path is not None:
|
||||
raise # Tests pass explicit paths — let errors propagate
|
||||
logger.warning("Website policy config error (failing open): %s", exc)
|
||||
return None
|
||||
except Exception as exc:
|
||||
logger.warning("Unexpected error loading website policy (failing open): %s", exc)
|
||||
return None
|
||||
|
||||
if not policy.get("enabled"):
|
||||
return None
|
||||
|
||||
for rule in policy.get("rules", []):
|
||||
pattern = rule.get("pattern", "")
|
||||
if _match_host_against_rule(host, pattern):
|
||||
logger.info("Blocked URL %s — matched rule '%s' from %s",
|
||||
url, pattern, rule.get("source", "config"))
|
||||
return {
|
||||
"url": url,
|
||||
"host": host,
|
||||
"rule": pattern,
|
||||
"source": rule.get("source", "config"),
|
||||
"message": (
|
||||
f"Blocked by website policy: '{host}' matched rule '{pattern}'"
|
||||
f" from {rule.get('source', 'config')}"
|
||||
),
|
||||
}
|
||||
return None
|
||||
Loading…
Add table
Add a link
Reference in a new issue