merge: resolve conflicts with origin/main

commit 0897e4350e

100 changed files with 11637 additions and 1337 deletions

@@ -65,6 +65,11 @@ import requests
from typing import Dict, Any, Optional, List
from pathlib import Path
from agent.auxiliary_client import call_llm

try:
    from tools.website_policy import check_website_access
except Exception:
    check_website_access = lambda url: None  # noqa: E731 — fail-open if policy module unavailable
from tools.browser_providers.base import CloudBrowserProvider
from tools.browser_providers.browserbase import BrowserbaseProvider
from tools.browser_providers.browser_use import BrowserUseProvider

@@ -550,6 +555,11 @@ def _get_session_info(task_id: Optional[str] = None) -> Dict[str, str]:
    session_info = provider.create_session(task_id)

    with _cleanup_lock:
        # Double-check: another thread may have created a session while we
        # were doing the network call. Use the existing one to avoid leaking
        # orphan cloud sessions.
        if task_id in _active_sessions:
            return _active_sessions[task_id]
        _active_sessions[task_id] = session_info

    return session_info

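The block above is the standard check-lock-check pattern: the expensive provider call happens outside the lock, and the cache is re-checked under the lock before publishing. A minimal standalone sketch of the same idea (hypothetical names, not this module's actual globals):

import threading
from typing import Callable, Dict

_sessions: Dict[str, dict] = {}
_sessions_lock = threading.Lock()

def get_or_create_session(task_id: str, create_session: Callable[[str], dict]) -> dict:
    # Slow provider call happens outside the lock so other tasks are not blocked.
    candidate = create_session(task_id)
    with _sessions_lock:
        # Re-check: another thread may have registered a session while we
        # were talking to the provider; prefer it so we do not keep an orphan.
        if task_id in _sessions:
            return _sessions[task_id]
        _sessions[task_id] = candidate
    return candidate
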
@@ -901,6 +911,15 @@ def browser_navigate(url: str, task_id: Optional[str] = None) -> str:
    Returns:
        JSON string with navigation result (includes stealth features info on first nav)
    """
    # Website policy check — block before navigating
    blocked = check_website_access(url)
    if blocked:
        return json.dumps({
            "success": False,
            "error": blocked["message"],
            "blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]},
        })

    effective_task_id = task_id or "default"

    # Get session info to check if this is a new session

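The contract assumed here (and again in web_extract_tool later in this diff) is that check_website_access returns None when a URL is allowed and a dict with host, rule, source, and message when blocked. A minimal stand-in with a hypothetical deny list, useful when tools.website_policy is unavailable, could look like:

from typing import Optional
from urllib.parse import urlparse

# Hypothetical deny list, for illustration only.
_DENIED_HOSTS = {"internal.corp.example", "169.254.169.254"}

def check_website_access(url: str) -> Optional[dict]:
    host = (urlparse(url).hostname or "").lower()
    if host in _DENIED_HOSTS:
        return {
            "host": host,
            "rule": f"deny:{host}",
            "source": "static-deny-list",
            "message": f"Access to {host} is blocked by website policy.",
        }
    return None  # allowed
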
@@ -16,13 +16,10 @@ The parent's context only sees the delegation call and the summary result,
never the child's intermediate tool calls or reasoning.
"""

import contextlib
import io
import json
import logging
logger = logging.getLogger(__name__)
import os
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional

@@ -150,7 +147,7 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
    return _callback


def _run_single_child(
def _build_child_agent(
    task_index: int,
    goal: str,
    context: Optional[str],

@@ -158,16 +155,15 @@ def _run_single_child(
    model: Optional[str],
    max_iterations: int,
    parent_agent,
    task_count: int = 1,
    # Credential overrides from delegation config (provider:model resolution)
    override_provider: Optional[str] = None,
    override_base_url: Optional[str] = None,
    override_api_key: Optional[str] = None,
    override_api_mode: Optional[str] = None,
) -> Dict[str, Any]:
):
    """
    Spawn and run a single child agent. Called from within a thread.
    Returns a structured result dict.
    Build a child AIAgent on the main thread (thread-safe construction).
    Returns the constructed child agent without running it.

    When override_* params are set (from delegation config), the child uses
    those credentials instead of inheriting from the parent. This enables

@ -176,8 +172,6 @@ def _run_single_child(
|
|||
"""
|
||||
from run_agent import AIAgent
|
||||
|
||||
child_start = time.monotonic()
|
||||
|
||||
# When no explicit toolsets given, inherit from parent's enabled toolsets
|
||||
# so disabled tools (e.g. web) don't leak to subagents.
|
||||
if toolsets:
|
||||
|
|
@ -188,65 +182,84 @@ def _run_single_child(
|
|||
child_toolsets = _strip_blocked_tools(DEFAULT_TOOLSETS)
|
||||
|
||||
child_prompt = _build_child_system_prompt(goal, context)
|
||||
# Extract parent's API key so subagents inherit auth (e.g. Nous Portal).
|
||||
parent_api_key = getattr(parent_agent, "api_key", None)
|
||||
if (not parent_api_key) and hasattr(parent_agent, "_client_kwargs"):
|
||||
parent_api_key = parent_agent._client_kwargs.get("api_key")
|
||||
|
||||
try:
|
||||
# Extract parent's API key so subagents inherit auth (e.g. Nous Portal).
|
||||
parent_api_key = getattr(parent_agent, "api_key", None)
|
||||
if (not parent_api_key) and hasattr(parent_agent, "_client_kwargs"):
|
||||
parent_api_key = parent_agent._client_kwargs.get("api_key")
|
||||
# Build progress callback to relay tool calls to parent display
|
||||
child_progress_cb = _build_child_progress_callback(task_index, parent_agent)
|
||||
|
||||
# Build progress callback to relay tool calls to parent display
|
||||
child_progress_cb = _build_child_progress_callback(task_index, parent_agent, task_count)
|
||||
# Share the parent's iteration budget so subagent tool calls
|
||||
# count toward the session-wide limit.
|
||||
shared_budget = getattr(parent_agent, "iteration_budget", None)
|
||||
|
||||
# Share the parent's iteration budget so subagent tool calls
|
||||
# count toward the session-wide limit.
|
||||
shared_budget = getattr(parent_agent, "iteration_budget", None)
|
||||
# Resolve effective credentials: config override > parent inherit
|
||||
effective_model = model or parent_agent.model
|
||||
effective_provider = override_provider or getattr(parent_agent, "provider", None)
|
||||
effective_base_url = override_base_url or parent_agent.base_url
|
||||
effective_api_key = override_api_key or parent_api_key
|
||||
effective_api_mode = override_api_mode or getattr(parent_agent, "api_mode", None)
|
||||
|
||||
# Resolve effective credentials: config override > parent inherit
|
||||
effective_model = model or parent_agent.model
|
||||
effective_provider = override_provider or getattr(parent_agent, "provider", None)
|
||||
effective_base_url = override_base_url or parent_agent.base_url
|
||||
effective_api_key = override_api_key or parent_api_key
|
||||
effective_api_mode = override_api_mode or getattr(parent_agent, "api_mode", None)
|
||||
child = AIAgent(
|
||||
base_url=effective_base_url,
|
||||
api_key=effective_api_key,
|
||||
model=effective_model,
|
||||
provider=effective_provider,
|
||||
api_mode=effective_api_mode,
|
||||
max_iterations=max_iterations,
|
||||
max_tokens=getattr(parent_agent, "max_tokens", None),
|
||||
reasoning_config=getattr(parent_agent, "reasoning_config", None),
|
||||
prefill_messages=getattr(parent_agent, "prefill_messages", None),
|
||||
enabled_toolsets=child_toolsets,
|
||||
quiet_mode=True,
|
||||
ephemeral_system_prompt=child_prompt,
|
||||
log_prefix=f"[subagent-{task_index}]",
|
||||
platform=parent_agent.platform,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
clarify_callback=None,
|
||||
session_db=getattr(parent_agent, '_session_db', None),
|
||||
providers_allowed=parent_agent.providers_allowed,
|
||||
providers_ignored=parent_agent.providers_ignored,
|
||||
providers_order=parent_agent.providers_order,
|
||||
provider_sort=parent_agent.provider_sort,
|
||||
tool_progress_callback=child_progress_cb,
|
||||
iteration_budget=shared_budget,
|
||||
)
|
||||
|
||||
child = AIAgent(
|
||||
base_url=effective_base_url,
|
||||
api_key=effective_api_key,
|
||||
model=effective_model,
|
||||
provider=effective_provider,
|
||||
api_mode=effective_api_mode,
|
||||
max_iterations=max_iterations,
|
||||
max_tokens=getattr(parent_agent, "max_tokens", None),
|
||||
reasoning_config=getattr(parent_agent, "reasoning_config", None),
|
||||
prefill_messages=getattr(parent_agent, "prefill_messages", None),
|
||||
enabled_toolsets=child_toolsets,
|
||||
quiet_mode=True,
|
||||
ephemeral_system_prompt=child_prompt,
|
||||
log_prefix=f"[subagent-{task_index}]",
|
||||
platform=parent_agent.platform,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
clarify_callback=None,
|
||||
session_db=getattr(parent_agent, '_session_db', None),
|
||||
providers_allowed=parent_agent.providers_allowed,
|
||||
providers_ignored=parent_agent.providers_ignored,
|
||||
providers_order=parent_agent.providers_order,
|
||||
provider_sort=parent_agent.provider_sort,
|
||||
tool_progress_callback=child_progress_cb,
|
||||
iteration_budget=shared_budget,
|
||||
)
|
||||
# Set delegation depth so children can't spawn grandchildren
|
||||
child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1
|
||||
|
||||
# Set delegation depth so children can't spawn grandchildren
|
||||
child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1
|
||||
|
||||
# Register child for interrupt propagation
|
||||
if hasattr(parent_agent, '_active_children'):
|
||||
# Register child for interrupt propagation
|
||||
if hasattr(parent_agent, '_active_children'):
|
||||
lock = getattr(parent_agent, '_active_children_lock', None)
|
||||
if lock:
|
||||
with lock:
|
||||
parent_agent._active_children.append(child)
|
||||
else:
|
||||
parent_agent._active_children.append(child)
|
||||
|
||||
# Run with stdout/stderr suppressed to prevent interleaved output
|
||||
devnull = io.StringIO()
|
||||
with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
|
||||
result = child.run_conversation(user_message=goal)
|
||||
return child
|
||||
|
||||
def _run_single_child(
|
||||
task_index: int,
|
||||
goal: str,
|
||||
child=None,
|
||||
parent_agent=None,
|
||||
**_kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run a pre-built child agent. Called from within a thread.
|
||||
Returns a structured result dict.
|
||||
"""
|
||||
child_start = time.monotonic()
|
||||
|
||||
# Get the progress callback from the child agent
|
||||
child_progress_cb = getattr(child, 'tool_progress_callback', None)
|
||||
|
||||
try:
|
||||
result = child.run_conversation(user_message=goal)
|
||||
|
||||
# Flush any remaining batched progress to gateway
|
||||
if child_progress_cb and hasattr(child_progress_cb, '_flush'):
|
||||
|
|
@ -355,11 +368,15 @@ def _run_single_child(
|
|||
# Unregister child from interrupt propagation
|
||||
if hasattr(parent_agent, '_active_children'):
|
||||
try:
|
||||
parent_agent._active_children.remove(child)
|
||||
lock = getattr(parent_agent, '_active_children_lock', None)
|
||||
if lock:
|
||||
with lock:
|
||||
parent_agent._active_children.remove(child)
|
||||
else:
|
||||
parent_agent._active_children.remove(child)
|
||||
except (ValueError, UnboundLocalError) as e:
|
||||
logger.debug("Could not remove child from active_children: %s", e)
|
||||
|
||||
|
||||
def delegate_task(
|
||||
goal: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
|
|
@ -428,51 +445,38 @@ def delegate_task(
|
|||
# Track goal labels for progress display (truncated for readability)
|
||||
task_labels = [t["goal"][:40] for t in task_list]
|
||||
|
||||
if n_tasks == 1:
|
||||
# Single task -- run directly (no thread pool overhead)
|
||||
t = task_list[0]
|
||||
result = _run_single_child(
|
||||
task_index=0,
|
||||
goal=t["goal"],
|
||||
context=t.get("context"),
|
||||
toolsets=t.get("toolsets") or toolsets,
|
||||
model=creds["model"],
|
||||
max_iterations=effective_max_iter,
|
||||
parent_agent=parent_agent,
|
||||
task_count=1,
|
||||
override_provider=creds["provider"],
|
||||
override_base_url=creds["base_url"],
|
||||
# Build all child agents on the main thread (thread-safe construction)
|
||||
children = []
|
||||
for i, t in enumerate(task_list):
|
||||
child = _build_child_agent(
|
||||
task_index=i, goal=t["goal"], context=t.get("context"),
|
||||
toolsets=t.get("toolsets") or toolsets, model=creds["model"],
|
||||
max_iterations=effective_max_iter, parent_agent=parent_agent,
|
||||
override_provider=creds["provider"], override_base_url=creds["base_url"],
|
||||
override_api_key=creds["api_key"],
|
||||
override_api_mode=creds["api_mode"],
|
||||
)
|
||||
children.append((i, t, child))
|
||||
|
||||
if n_tasks == 1:
|
||||
# Single task -- run directly (no thread pool overhead)
|
||||
_i, _t, child = children[0]
|
||||
result = _run_single_child(0, _t["goal"], child, parent_agent)
|
||||
results.append(result)
|
||||
else:
|
||||
# Batch -- run in parallel with per-task progress lines
|
||||
completed_count = 0
|
||||
spinner_ref = getattr(parent_agent, '_delegate_spinner', None)
|
||||
|
||||
# Save stdout/stderr before the executor — redirect_stdout in child
|
||||
# threads races on sys.stdout and can leave it as devnull permanently.
|
||||
_saved_stdout = sys.stdout
|
||||
_saved_stderr = sys.stderr
|
||||
|
||||
with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CHILDREN) as executor:
|
||||
futures = {}
|
||||
for i, t in enumerate(task_list):
|
||||
for i, t, child in children:
|
||||
future = executor.submit(
|
||||
_run_single_child,
|
||||
task_index=i,
|
||||
goal=t["goal"],
|
||||
context=t.get("context"),
|
||||
toolsets=t.get("toolsets") or toolsets,
|
||||
model=creds["model"],
|
||||
max_iterations=effective_max_iter,
|
||||
child=child,
|
||||
parent_agent=parent_agent,
|
||||
task_count=n_tasks,
|
||||
override_provider=creds["provider"],
|
||||
override_base_url=creds["base_url"],
|
||||
override_api_key=creds["api_key"],
|
||||
override_api_mode=creds["api_mode"],
|
||||
)
|
||||
futures[future] = i
|
||||
|
||||
|
|
@ -515,10 +519,6 @@ def delegate_task(
|
|||
except Exception as e:
|
||||
logger.debug("Spinner update_text failed: %s", e)
|
||||
|
||||
# Restore stdout/stderr in case redirect_stdout race left them as devnull
|
||||
sys.stdout = _saved_stdout
|
||||
sys.stderr = _saved_stderr
|
||||
|
||||
# Sort by task_index so results match input order
|
||||
results.sort(key=lambda r: r["task_index"])
|
||||
|
||||
|
|
|
|||
|
|
@ -82,6 +82,9 @@ def _build_provider_env_blocklist() -> frozenset:
|
|||
"FIREWORKS_API_KEY", # Fireworks AI
|
||||
"XAI_API_KEY", # xAI (Grok)
|
||||
"HELICONE_API_KEY", # LLM Observability proxy
|
||||
"PARALLEL_API_KEY",
|
||||
"FIRECRAWL_API_KEY",
|
||||
"FIRECRAWL_API_URL",
|
||||
# Gateway/runtime config not represented in OPTIONAL_ENV_VARS.
|
||||
"TELEGRAM_HOME_CHANNEL",
|
||||
"TELEGRAM_HOME_CHANNEL_NAME",
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ def _get_safe_write_root() -> Optional[str]:
|
|||
|
||||
def _is_write_denied(path: str) -> bool:
|
||||
"""Return True if path is on the write deny list."""
|
||||
resolved = os.path.realpath(os.path.expanduser(path))
|
||||
resolved = os.path.realpath(os.path.expanduser(str(path)))
|
||||
|
||||
# 1) Static deny list
|
||||
if resolved in WRITE_DENIED_PATHS:
|
||||
|
|
|
|||
|
|
@ -254,10 +254,9 @@ def _strategy_trimmed_boundary(content: str, pattern: str) -> List[Tuple[int, in
|
|||
|
||||
if '\n'.join(check_lines) == modified_pattern:
|
||||
# Found match - calculate original positions
|
||||
start_pos = sum(len(line) + 1 for line in content_lines[:i])
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
|
||||
if end_pos >= len(content):
|
||||
end_pos = len(content)
|
||||
start_pos, end_pos = _calculate_line_positions(
|
||||
content_lines, i, i + pattern_line_count, len(content)
|
||||
)
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
return matches
|
||||
|
|
@ -309,9 +308,10 @@ def _strategy_block_anchor(content: str, pattern: str) -> List[Tuple[int, int]]:
|
|||
|
||||
if similarity >= threshold:
|
||||
# Calculate positions using ORIGINAL lines to ensure correct character offsets in the file
|
||||
start_pos = sum(len(line) + 1 for line in orig_content_lines[:i])
|
||||
end_pos = sum(len(line) + 1 for line in orig_content_lines[:i + pattern_line_count]) - 1
|
||||
matches.append((start_pos, min(end_pos, len(content))))
|
||||
start_pos, end_pos = _calculate_line_positions(
|
||||
orig_content_lines, i, i + pattern_line_count, len(content)
|
||||
)
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
return matches
|
||||
|
||||
|
|
@ -343,10 +343,9 @@ def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]]
|
|||
|
||||
# Need at least 50% of lines to have high similarity
|
||||
if high_similarity_count >= len(pattern_lines) * 0.5:
|
||||
start_pos = sum(len(line) + 1 for line in content_lines[:i])
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
|
||||
if end_pos >= len(content):
|
||||
end_pos = len(content)
|
||||
start_pos, end_pos = _calculate_line_positions(
|
||||
content_lines, i, i + pattern_line_count, len(content)
|
||||
)
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
return matches
|
||||
|
|
@@ -356,6 +355,26 @@ def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]]
# Helper Functions
# =============================================================================

def _calculate_line_positions(content_lines: List[str], start_line: int,
                              end_line: int, content_length: int) -> Tuple[int, int]:
    """Calculate start and end character positions from line indices.

    Args:
        content_lines: List of lines (without newlines)
        start_line: Starting line index (0-based)
        end_line: Ending line index (exclusive, 0-based)
        content_length: Total length of the original content string

    Returns:
        Tuple of (start_pos, end_pos) in the original content
    """
    start_pos = sum(len(line) + 1 for line in content_lines[:start_line])
    end_pos = sum(len(line) + 1 for line in content_lines[:end_line]) - 1
    if end_pos >= content_length:
        end_pos = content_length
    return start_pos, end_pos


def _find_normalized_matches(content: str, content_lines: List[str],
                             content_normalized_lines: List[str],
                             pattern: str, pattern_normalized: str) -> List[Tuple[int, int]]:

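As a worked example of _calculate_line_positions defined above: every line contributes len(line) + 1 characters (the + 1 covers its newline), and the end position is clamped when the range reaches the final line, which has no trailing newline. The helper is repeated here so the example runs standalone:

from typing import List, Tuple

def _calculate_line_positions(content_lines: List[str], start_line: int,
                              end_line: int, content_length: int) -> Tuple[int, int]:
    # Copy of the helper above, so this example is self-contained.
    start_pos = sum(len(line) + 1 for line in content_lines[:start_line])
    end_pos = sum(len(line) + 1 for line in content_lines[:end_line]) - 1
    if end_pos >= content_length:
        end_pos = content_length
    return start_pos, end_pos

content = "ab\ncd\nef"
lines = content.split("\n")

# Lines 1..2 select just "cd": characters 3..5, i.e. content[3:5] == "cd".
assert _calculate_line_positions(lines, 1, 2, len(content)) == (3, 5)

# A range that runs through the last line is clamped to len(content),
# because the final line has no trailing newline to account for.
assert _calculate_line_positions(lines, 1, 3, len(content)) == (3, 8)
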
@ -383,13 +402,9 @@ def _find_normalized_matches(content: str, content_lines: List[str],
|
|||
|
||||
if block == pattern_normalized:
|
||||
# Found a match - calculate original positions
|
||||
start_pos = sum(len(line) + 1 for line in content_lines[:i])
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[:i + num_pattern_lines]) - 1
|
||||
|
||||
# Handle case where end is past content
|
||||
if end_pos >= len(content):
|
||||
end_pos = len(content)
|
||||
|
||||
start_pos, end_pos = _calculate_line_positions(
|
||||
content_lines, i, i + num_pattern_lines, len(content)
|
||||
)
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
return matches
|
||||
|
|
|
|||
|
|
@ -1624,6 +1624,72 @@ def get_mcp_status() -> List[dict]:
|
|||
return result
|
||||
|
||||
|
||||
def probe_mcp_server_tools() -> Dict[str, List[tuple]]:
|
||||
"""Temporarily connect to configured MCP servers and list their tools.
|
||||
|
||||
Designed for ``hermes tools`` interactive configuration — connects to each
|
||||
enabled server, grabs tool names and descriptions, then disconnects.
|
||||
Does NOT register tools in the Hermes registry.
|
||||
|
||||
Returns:
|
||||
Dict mapping server name to list of (tool_name, description) tuples.
|
||||
Servers that fail to connect are omitted from the result.
|
||||
"""
|
||||
if not _MCP_AVAILABLE:
|
||||
return {}
|
||||
|
||||
servers_config = _load_mcp_config()
|
||||
if not servers_config:
|
||||
return {}
|
||||
|
||||
enabled = {
|
||||
k: v for k, v in servers_config.items()
|
||||
if _parse_boolish(v.get("enabled", True), default=True)
|
||||
}
|
||||
if not enabled:
|
||||
return {}
|
||||
|
||||
_ensure_mcp_loop()
|
||||
|
||||
result: Dict[str, List[tuple]] = {}
|
||||
probed_servers: List[MCPServerTask] = []
|
||||
|
||||
async def _probe_all():
|
||||
names = list(enabled.keys())
|
||||
coros = []
|
||||
for name, cfg in enabled.items():
|
||||
ct = cfg.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
|
||||
coros.append(asyncio.wait_for(_connect_server(name, cfg), timeout=ct))
|
||||
|
||||
outcomes = await asyncio.gather(*coros, return_exceptions=True)
|
||||
|
||||
for name, outcome in zip(names, outcomes):
|
||||
if isinstance(outcome, Exception):
|
||||
logger.debug("Probe: failed to connect to '%s': %s", name, outcome)
|
||||
continue
|
||||
probed_servers.append(outcome)
|
||||
tools = []
|
||||
for t in outcome._tools:
|
||||
desc = getattr(t, "description", "") or ""
|
||||
tools.append((t.name, desc))
|
||||
result[name] = tools
|
||||
|
||||
# Shut down all probed connections
|
||||
await asyncio.gather(
|
||||
*(s.shutdown() for s in probed_servers),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
try:
|
||||
_run_on_mcp_loop(_probe_all(), timeout=120)
|
||||
except Exception as exc:
|
||||
logger.debug("MCP probe failed: %s", exc)
|
||||
finally:
|
||||
_stop_mcp_loop()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def shutdown_mcp_servers():
|
||||
"""Close all MCP server connections and stop the background loop.
|
||||
|
||||
|
|
|
|||
|
|
@ -23,11 +23,13 @@ Design:
|
|||
- Frozen snapshot pattern: system prompt is stable, tool responses show live state
|
||||
"""
|
||||
|
||||
import fcntl
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
|
|
@@ -120,14 +122,43 @@ class MemoryStore:
            "user": self._render_block("user", self.user_entries),
        }

    @staticmethod
    @contextmanager
    def _file_lock(path: Path):
        """Acquire an exclusive file lock for read-modify-write safety.

        Uses a separate .lock file so the memory file itself can still be
        atomically replaced via os.replace().
        """
        lock_path = path.with_suffix(path.suffix + ".lock")
        lock_path.parent.mkdir(parents=True, exist_ok=True)
        fd = open(lock_path, "w")
        try:
            fcntl.flock(fd, fcntl.LOCK_EX)
            yield
        finally:
            fcntl.flock(fd, fcntl.LOCK_UN)
            fd.close()

    @staticmethod
    def _path_for(target: str) -> Path:
        if target == "user":
            return MEMORY_DIR / "USER.md"
        return MEMORY_DIR / "MEMORY.md"

    def _reload_target(self, target: str):
        """Re-read entries from disk into in-memory state.

        Called under file lock to get the latest state before mutating.
        """
        fresh = self._read_file(self._path_for(target))
        fresh = list(dict.fromkeys(fresh))  # deduplicate
        self._set_entries(target, fresh)

    def save_to_disk(self, target: str):
        """Persist entries to the appropriate file. Called after every mutation."""
        MEMORY_DIR.mkdir(parents=True, exist_ok=True)

        if target == "memory":
            self._write_file(MEMORY_DIR / "MEMORY.md", self.memory_entries)
        elif target == "user":
            self._write_file(MEMORY_DIR / "USER.md", self.user_entries)
        self._write_file(self._path_for(target), self._entries_for(target))

    def _entries_for(self, target: str) -> List[str]:
        if target == "user":

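The reason for a separate .lock file is that the data file itself is swapped out with os.replace(); holding flock on a descriptor of the old inode would guard nothing once the file has been replaced. A minimal sketch of the matching atomic-write side (a hypothetical helper; the diff does not show _write_file itself):

import os
import tempfile
from pathlib import Path

def atomic_write_text(path: Path, text: str) -> None:
    # Write to a temp file in the same directory, then atomically swap it in.
    path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp = tempfile.mkstemp(dir=path.parent, prefix=path.name, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as handle:
            handle.write(text)
            handle.flush()
            os.fsync(handle.fileno())
        os.replace(tmp, path)  # atomic: readers see the old or new file, never a partial one
    except BaseException:
        os.unlink(tmp)
        raise
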
@ -162,33 +193,37 @@ class MemoryStore:
|
|||
if scan_error:
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
entries = self._entries_for(target)
|
||||
limit = self._char_limit(target)
|
||||
with self._file_lock(self._path_for(target)):
|
||||
# Re-read from disk under lock to pick up writes from other sessions
|
||||
self._reload_target(target)
|
||||
|
||||
# Reject exact duplicates
|
||||
if content in entries:
|
||||
return self._success_response(target, "Entry already exists (no duplicate added).")
|
||||
entries = self._entries_for(target)
|
||||
limit = self._char_limit(target)
|
||||
|
||||
# Calculate what the new total would be
|
||||
new_entries = entries + [content]
|
||||
new_total = len(ENTRY_DELIMITER.join(new_entries))
|
||||
# Reject exact duplicates
|
||||
if content in entries:
|
||||
return self._success_response(target, "Entry already exists (no duplicate added).")
|
||||
|
||||
if new_total > limit:
|
||||
current = self._char_count(target)
|
||||
return {
|
||||
"success": False,
|
||||
"error": (
|
||||
f"Memory at {current:,}/{limit:,} chars. "
|
||||
f"Adding this entry ({len(content)} chars) would exceed the limit. "
|
||||
f"Replace or remove existing entries first."
|
||||
),
|
||||
"current_entries": entries,
|
||||
"usage": f"{current:,}/{limit:,}",
|
||||
}
|
||||
# Calculate what the new total would be
|
||||
new_entries = entries + [content]
|
||||
new_total = len(ENTRY_DELIMITER.join(new_entries))
|
||||
|
||||
entries.append(content)
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
if new_total > limit:
|
||||
current = self._char_count(target)
|
||||
return {
|
||||
"success": False,
|
||||
"error": (
|
||||
f"Memory at {current:,}/{limit:,} chars. "
|
||||
f"Adding this entry ({len(content)} chars) would exceed the limit. "
|
||||
f"Replace or remove existing entries first."
|
||||
),
|
||||
"current_entries": entries,
|
||||
"usage": f"{current:,}/{limit:,}",
|
||||
}
|
||||
|
||||
entries.append(content)
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
|
||||
return self._success_response(target, "Entry added.")
|
||||
|
||||
|
|
@ -206,44 +241,47 @@ class MemoryStore:
|
|||
if scan_error:
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
entries = self._entries_for(target)
|
||||
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
|
||||
with self._file_lock(self._path_for(target)):
|
||||
self._reload_target(target)
|
||||
|
||||
if len(matches) == 0:
|
||||
return {"success": False, "error": f"No entry matched '{old_text}'."}
|
||||
entries = self._entries_for(target)
|
||||
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
|
||||
|
||||
if len(matches) > 1:
|
||||
# If all matches are identical (exact duplicates), operate on the first one
|
||||
unique_texts = set(e for _, e in matches)
|
||||
if len(unique_texts) > 1:
|
||||
previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches]
|
||||
if len(matches) == 0:
|
||||
return {"success": False, "error": f"No entry matched '{old_text}'."}
|
||||
|
||||
if len(matches) > 1:
|
||||
# If all matches are identical (exact duplicates), operate on the first one
|
||||
unique_texts = set(e for _, e in matches)
|
||||
if len(unique_texts) > 1:
|
||||
previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches]
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
|
||||
"matches": previews,
|
||||
}
|
||||
# All identical -- safe to replace just the first
|
||||
|
||||
idx = matches[0][0]
|
||||
limit = self._char_limit(target)
|
||||
|
||||
# Check that replacement doesn't blow the budget
|
||||
test_entries = entries.copy()
|
||||
test_entries[idx] = new_content
|
||||
new_total = len(ENTRY_DELIMITER.join(test_entries))
|
||||
|
||||
if new_total > limit:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
|
||||
"matches": previews,
|
||||
"error": (
|
||||
f"Replacement would put memory at {new_total:,}/{limit:,} chars. "
|
||||
f"Shorten the new content or remove other entries first."
|
||||
),
|
||||
}
|
||||
# All identical -- safe to replace just the first
|
||||
|
||||
idx = matches[0][0]
|
||||
limit = self._char_limit(target)
|
||||
|
||||
# Check that replacement doesn't blow the budget
|
||||
test_entries = entries.copy()
|
||||
test_entries[idx] = new_content
|
||||
new_total = len(ENTRY_DELIMITER.join(test_entries))
|
||||
|
||||
if new_total > limit:
|
||||
return {
|
||||
"success": False,
|
||||
"error": (
|
||||
f"Replacement would put memory at {new_total:,}/{limit:,} chars. "
|
||||
f"Shorten the new content or remove other entries first."
|
||||
),
|
||||
}
|
||||
|
||||
entries[idx] = new_content
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
entries[idx] = new_content
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
|
||||
return self._success_response(target, "Entry replaced.")
|
||||
|
||||
|
|
@ -253,28 +291,31 @@ class MemoryStore:
|
|||
if not old_text:
|
||||
return {"success": False, "error": "old_text cannot be empty."}
|
||||
|
||||
entries = self._entries_for(target)
|
||||
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
|
||||
with self._file_lock(self._path_for(target)):
|
||||
self._reload_target(target)
|
||||
|
||||
if len(matches) == 0:
|
||||
return {"success": False, "error": f"No entry matched '{old_text}'."}
|
||||
entries = self._entries_for(target)
|
||||
matches = [(i, e) for i, e in enumerate(entries) if old_text in e]
|
||||
|
||||
if len(matches) > 1:
|
||||
# If all matches are identical (exact duplicates), remove the first one
|
||||
unique_texts = set(e for _, e in matches)
|
||||
if len(unique_texts) > 1:
|
||||
previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches]
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
|
||||
"matches": previews,
|
||||
}
|
||||
# All identical -- safe to remove just the first
|
||||
if len(matches) == 0:
|
||||
return {"success": False, "error": f"No entry matched '{old_text}'."}
|
||||
|
||||
idx = matches[0][0]
|
||||
entries.pop(idx)
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
if len(matches) > 1:
|
||||
# If all matches are identical (exact duplicates), remove the first one
|
||||
unique_texts = set(e for _, e in matches)
|
||||
if len(unique_texts) > 1:
|
||||
previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches]
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Multiple entries matched '{old_text}'. Be more specific.",
|
||||
"matches": previews,
|
||||
}
|
||||
# All identical -- safe to remove just the first
|
||||
|
||||
idx = matches[0][0]
|
||||
entries.pop(idx)
|
||||
self._set_entries(target, entries)
|
||||
self.save_to_disk(target)
|
||||
|
||||
return self._success_response(target, "Entry removed.")
|
||||
|
||||
|
|
|
|||
|
|
@ -78,6 +78,11 @@ class ProcessSession:
|
|||
output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS)
|
||||
max_output_chars: int = MAX_OUTPUT_CHARS
|
||||
detached: bool = False # True if recovered from crash (no pipe)
|
||||
# Watcher/notification metadata (persisted for crash recovery)
|
||||
watcher_platform: str = ""
|
||||
watcher_chat_id: str = ""
|
||||
watcher_thread_id: str = ""
|
||||
watcher_interval: int = 0 # 0 = no watcher configured
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
_reader_thread: Optional[threading.Thread] = field(default=None, repr=False)
|
||||
_pty: Any = field(default=None, repr=False) # ptyprocess handle (when use_pty=True)
|
||||
|
|
@ -709,6 +714,10 @@ class ProcessRegistry:
|
|||
"started_at": s.started_at,
|
||||
"task_id": s.task_id,
|
||||
"session_key": s.session_key,
|
||||
"watcher_platform": s.watcher_platform,
|
||||
"watcher_chat_id": s.watcher_chat_id,
|
||||
"watcher_thread_id": s.watcher_thread_id,
|
||||
"watcher_interval": s.watcher_interval,
|
||||
})
|
||||
|
||||
# Atomic write to avoid corruption on crash
|
||||
|
|
@ -755,12 +764,27 @@ class ProcessRegistry:
|
|||
cwd=entry.get("cwd"),
|
||||
started_at=entry.get("started_at", time.time()),
|
||||
detached=True, # Can't read output, but can report status + kill
|
||||
watcher_platform=entry.get("watcher_platform", ""),
|
||||
watcher_chat_id=entry.get("watcher_chat_id", ""),
|
||||
watcher_thread_id=entry.get("watcher_thread_id", ""),
|
||||
watcher_interval=entry.get("watcher_interval", 0),
|
||||
)
|
||||
with self._lock:
|
||||
self._running[session.id] = session
|
||||
recovered += 1
|
||||
logger.info("Recovered detached process: %s (pid=%d)", session.command[:60], pid)
|
||||
|
||||
# Re-enqueue watcher so gateway can resume notifications
|
||||
if session.watcher_interval > 0:
|
||||
self.pending_watchers.append({
|
||||
"session_id": session.id,
|
||||
"check_interval": session.watcher_interval,
|
||||
"session_key": session.session_key,
|
||||
"platform": session.watcher_platform,
|
||||
"chat_id": session.watcher_chat_id,
|
||||
"thread_id": session.watcher_thread_id,
|
||||
})
|
||||
|
||||
# Clear the checkpoint (will be rewritten as processes finish)
|
||||
try:
|
||||
from utils import atomic_json_write
|
||||
|
|
|
|||
|
|
@@ -355,20 +355,31 @@ async def _send_telegram(token, chat_id, message, media_files=None, thread_id=No
    """Send via Telegram Bot API (one-shot, no polling needed).

    Applies markdown→MarkdownV2 formatting (same as the gateway adapter)
    so that bold, links, and headers render correctly.
    so that bold, links, and headers render correctly. If the message
    already contains HTML tags, it is sent with ``parse_mode='HTML'``
    instead, bypassing MarkdownV2 conversion.
    """
    try:
        from telegram import Bot
        from telegram.constants import ParseMode

        # Reuse the gateway adapter's format_message for markdown→MarkdownV2
        try:
            from gateway.platforms.telegram import TelegramAdapter, _escape_mdv2, _strip_mdv2
            _adapter = TelegramAdapter.__new__(TelegramAdapter)
            formatted = _adapter.format_message(message)
        except Exception:
            # Fallback: send as-is if formatting unavailable
        # Auto-detect HTML tags — if present, skip MarkdownV2 and send as HTML.
        # Inspired by github.com/ashaney — PR #1568.
        _has_html = bool(re.search(r'<[a-zA-Z/][^>]*>', message))

        if _has_html:
            formatted = message
            send_parse_mode = ParseMode.HTML
        else:
            # Reuse the gateway adapter's format_message for markdown→MarkdownV2
            try:
                from gateway.platforms.telegram import TelegramAdapter, _escape_mdv2, _strip_mdv2
                _adapter = TelegramAdapter.__new__(TelegramAdapter)
                formatted = _adapter.format_message(message)
            except Exception:
                # Fallback: send as-is if formatting unavailable
                formatted = message
            send_parse_mode = ParseMode.MARKDOWN_V2

        bot = Bot(token=token)
        int_chat_id = int(chat_id)

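The <[a-zA-Z/][^>]*> heuristic treats any tag-like token as HTML: the character right after < must be a letter or /, so <b>done</b> flips the whole message to HTML mode while a bare comparison stays on the MarkdownV2 path. A quick check of the behavior:

import re

_TAG = re.compile(r'<[a-zA-Z/][^>]*>')

assert _TAG.search("<b>done</b>")                # opening tag -> HTML mode
assert _TAG.search("see </pre> above")           # closing tag counts too
assert not _TAG.search("a < b and b > c")        # bare comparison, no tag
assert not _TAG.search("plain *markdown* text")  # stays on the MarkdownV2 path
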
@ -384,16 +395,19 @@ async def _send_telegram(token, chat_id, message, media_files=None, thread_id=No
|
|||
try:
|
||||
last_msg = await bot.send_message(
|
||||
chat_id=int_chat_id, text=formatted,
|
||||
parse_mode=ParseMode.MARKDOWN_V2, **thread_kwargs
|
||||
parse_mode=send_parse_mode, **thread_kwargs
|
||||
)
|
||||
except Exception as md_error:
|
||||
# MarkdownV2 failed, fall back to plain text
|
||||
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower():
|
||||
logger.warning("MarkdownV2 parse failed in _send_telegram, falling back to plain text: %s", md_error)
|
||||
try:
|
||||
from gateway.platforms.telegram import _strip_mdv2
|
||||
plain = _strip_mdv2(formatted)
|
||||
except Exception:
|
||||
# Parse failed, fall back to plain text
|
||||
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower() or "html" in str(md_error).lower():
|
||||
logger.warning("Parse mode %s failed in _send_telegram, falling back to plain text: %s", send_parse_mode, md_error)
|
||||
if not _has_html:
|
||||
try:
|
||||
from gateway.platforms.telegram import _strip_mdv2
|
||||
plain = _strip_mdv2(formatted)
|
||||
except Exception:
|
||||
plain = message
|
||||
else:
|
||||
plain = message
|
||||
last_msg = await bot.send_message(
|
||||
chat_id=int_chat_id, text=plain,
|
||||
|
|
@ -565,50 +579,55 @@ async def _send_email(extra, chat_id, message):
|
|||
return {"error": f"Email send failed: {e}"}
|
||||
|
||||
|
||||
async def _send_sms(api_key, chat_id, message):
|
||||
"""Send via Telnyx SMS REST API (one-shot, no persistent connection needed)."""
|
||||
async def _send_sms(auth_token, chat_id, message):
|
||||
"""Send a single SMS via Twilio REST API.
|
||||
|
||||
Uses HTTP Basic auth (Account SID : Auth Token) and form-encoded POST.
|
||||
Chunking is handled by _send_to_platform() before this is called.
|
||||
"""
|
||||
try:
|
||||
import aiohttp
|
||||
except ImportError:
|
||||
return {"error": "aiohttp not installed. Run: pip install aiohttp"}
|
||||
|
||||
import base64
|
||||
|
||||
account_sid = os.getenv("TWILIO_ACCOUNT_SID", "")
|
||||
from_number = os.getenv("TWILIO_PHONE_NUMBER", "")
|
||||
if not account_sid or not auth_token or not from_number:
|
||||
return {"error": "SMS not configured (TWILIO_ACCOUNT_SID, TWILIO_AUTH_TOKEN, TWILIO_PHONE_NUMBER required)"}
|
||||
|
||||
# Strip markdown — SMS renders it as literal characters
|
||||
message = re.sub(r"\*\*(.+?)\*\*", r"\1", message, flags=re.DOTALL)
|
||||
message = re.sub(r"\*(.+?)\*", r"\1", message, flags=re.DOTALL)
|
||||
message = re.sub(r"__(.+?)__", r"\1", message, flags=re.DOTALL)
|
||||
message = re.sub(r"_(.+?)_", r"\1", message, flags=re.DOTALL)
|
||||
message = re.sub(r"```[a-z]*\n?", "", message)
|
||||
message = re.sub(r"`(.+?)`", r"\1", message)
|
||||
message = re.sub(r"^#{1,6}\s+", "", message, flags=re.MULTILINE)
|
||||
message = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", message)
|
||||
message = re.sub(r"\n{3,}", "\n\n", message)
|
||||
message = message.strip()
|
||||
|
||||
try:
|
||||
from_number = os.getenv("TELNYX_FROM_NUMBERS", "").split(",")[0].strip()
|
||||
if not from_number:
|
||||
return {"error": "TELNYX_FROM_NUMBERS not configured"}
|
||||
if not api_key:
|
||||
api_key = os.getenv("TELNYX_API_KEY", "")
|
||||
if not api_key:
|
||||
return {"error": "TELNYX_API_KEY not configured"}
|
||||
creds = f"{account_sid}:{auth_token}"
|
||||
encoded = base64.b64encode(creds.encode("ascii")).decode("ascii")
|
||||
url = f"https://api.twilio.com/2010-04-01/Accounts/{account_sid}/Messages.json"
|
||||
headers = {"Authorization": f"Basic {encoded}"}
|
||||
|
||||
# Strip markdown for SMS
|
||||
text = re.sub(r"\*\*(.+?)\*\*", r"\1", message, flags=re.DOTALL)
|
||||
text = re.sub(r"\*(.+?)\*", r"\1", text, flags=re.DOTALL)
|
||||
text = re.sub(r"```[a-z]*\n?", "", text)
|
||||
text = re.sub(r"`(.+?)`", r"\1", text)
|
||||
text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE)
|
||||
text = text.strip()
|
||||
|
||||
# Chunk to 1600 chars
|
||||
chunks = [text[i:i+1600] for i in range(0, len(text), 1600)] if len(text) > 1600 else [text]
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
message_ids = []
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for chunk in chunks:
|
||||
payload = {"from": from_number, "to": chat_id, "text": chunk}
|
||||
async with session.post(
|
||||
"https://api.telnyx.com/v2/messages",
|
||||
json=payload,
|
||||
headers=headers,
|
||||
) as resp:
|
||||
body = await resp.json()
|
||||
if resp.status >= 400:
|
||||
return {"error": f"Telnyx API error ({resp.status}): {body}"}
|
||||
message_ids.append(body.get("data", {}).get("id", ""))
|
||||
return {"success": True, "platform": "sms", "chat_id": chat_id, "message_ids": message_ids}
|
||||
form_data = aiohttp.FormData()
|
||||
form_data.add_field("From", from_number)
|
||||
form_data.add_field("To", chat_id)
|
||||
form_data.add_field("Body", message)
|
||||
|
||||
async with session.post(url, data=form_data, headers=headers) as resp:
|
||||
body = await resp.json()
|
||||
if resp.status >= 400:
|
||||
error_msg = body.get("message", str(body))
|
||||
return {"error": f"Twilio API error ({resp.status}): {error_msg}"}
|
||||
msg_sid = body.get("sid", "")
|
||||
return {"success": True, "platform": "sms", "chat_id": chat_id, "message_id": msg_sid}
|
||||
except Exception as e:
|
||||
return {"error": f"SMS send failed: {e}"}
|
||||
|
||||
|
|
|
|||
|
|
@ -1082,13 +1082,23 @@ def terminal_tool(
|
|||
result_data["check_interval_note"] = (
|
||||
f"Requested {check_interval}s raised to minimum 30s"
|
||||
)
|
||||
watcher_platform = os.getenv("HERMES_SESSION_PLATFORM", "")
|
||||
watcher_chat_id = os.getenv("HERMES_SESSION_CHAT_ID", "")
|
||||
watcher_thread_id = os.getenv("HERMES_SESSION_THREAD_ID", "")
|
||||
|
||||
# Store on session for checkpoint persistence
|
||||
proc_session.watcher_platform = watcher_platform
|
||||
proc_session.watcher_chat_id = watcher_chat_id
|
||||
proc_session.watcher_thread_id = watcher_thread_id
|
||||
proc_session.watcher_interval = effective_interval
|
||||
|
||||
process_registry.pending_watchers.append({
|
||||
"session_id": proc_session.id,
|
||||
"check_interval": effective_interval,
|
||||
"session_key": session_key,
|
||||
"platform": os.getenv("HERMES_SESSION_PLATFORM", ""),
|
||||
"chat_id": os.getenv("HERMES_SESSION_CHAT_ID", ""),
|
||||
"thread_id": os.getenv("HERMES_SESSION_THREAD_ID", ""),
|
||||
"platform": watcher_platform,
|
||||
"chat_id": watcher_chat_id,
|
||||
"thread_id": watcher_thread_id,
|
||||
})
|
||||
|
||||
return json.dumps(result_data, ensure_ascii=False)
|
||||
|
|
|
|||
|
|
@@ -3,16 +3,16 @@
Standalone Web Tools Module

This module provides generic web tools that work with multiple backend providers.
Currently uses Firecrawl as the backend, and the interface makes it easy to swap
providers without changing the function signatures.
Backend is selected during ``hermes tools`` setup (web.backend in config.yaml).

Available tools:
- web_search_tool: Search the web for information
- web_extract_tool: Extract content from specific web pages
- web_crawl_tool: Crawl websites with specific instructions
- web_crawl_tool: Crawl websites with specific instructions (Firecrawl only)

Backend compatibility:
- Firecrawl: https://docs.firecrawl.dev/introduction
- Firecrawl: https://docs.firecrawl.dev/introduction (search, extract, crawl)
- Parallel: https://docs.parallel.ai (search, extract)

LLM Processing:
- Uses OpenRouter API with Gemini 3 Flash Preview for intelligent content extraction

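Whichever backend is selected, web_search_tool returns a JSON string in one normalized shape. A usage sketch, assuming the module imports as tools.web_tools (the actual module path is not shown in this diff):

import json
from tools.web_tools import web_search_tool  # hypothetical import path

raw = web_search_tool("firecrawl scrape formats documentation", limit=3)
data = json.loads(raw)
if data.get("success"):
    for hit in data["data"]["web"]:
        # Each hit carries title, url, description, and a 1-based position.
        print(hit["position"], hit["title"], hit["url"])
else:
    print("search failed:", data.get("error"))
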
@@ -46,12 +46,50 @@ import os
import re
import asyncio
from typing import List, Dict, Any, Optional
import httpx
from firecrawl import Firecrawl
from agent.auxiliary_client import async_call_llm
from tools.debug_helpers import DebugSession
from tools.website_policy import check_website_access

logger = logging.getLogger(__name__)


# ─── Backend Selection ────────────────────────────────────────────────────────

def _load_web_config() -> dict:
    """Load the ``web:`` section from ~/.hermes/config.yaml."""
    try:
        from hermes_cli.config import load_config
        return load_config().get("web", {})
    except (ImportError, Exception):
        return {}


def _get_backend() -> str:
    """Determine which web backend to use.

    Reads ``web.backend`` from config.yaml (set by ``hermes tools``).
    Falls back to whichever API key is present for users who configured
    keys manually without running setup.
    """
    configured = _load_web_config().get("backend", "").lower().strip()
    if configured in ("parallel", "firecrawl", "tavily"):
        return configured
    # Fallback for manual / legacy config — use whichever key is present.
    has_firecrawl = bool(os.getenv("FIRECRAWL_API_KEY") or os.getenv("FIRECRAWL_API_URL"))
    has_parallel = bool(os.getenv("PARALLEL_API_KEY"))
    has_tavily = bool(os.getenv("TAVILY_API_KEY"))
    if has_tavily and not has_firecrawl and not has_parallel:
        return "tavily"
    if has_parallel and not has_firecrawl:
        return "parallel"
    # Default to firecrawl (backward compat, or when both are set)
    return "firecrawl"

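Restated as a standalone sketch with the environment passed in explicitly, the resolution order above can be exercised directly (a re-statement for illustration, not new behavior):

import os

def resolve_backend(web_config: dict, env: dict = None) -> str:
    env = os.environ if env is None else env
    # 1) Explicit web.backend written by `hermes tools` wins.
    configured = (web_config.get("backend") or "").lower().strip()
    if configured in ("parallel", "firecrawl", "tavily"):
        return configured
    # 2) Otherwise infer from whichever API key happens to be set.
    has_firecrawl = bool(env.get("FIRECRAWL_API_KEY") or env.get("FIRECRAWL_API_URL"))
    has_parallel = bool(env.get("PARALLEL_API_KEY"))
    has_tavily = bool(env.get("TAVILY_API_KEY"))
    if has_tavily and not has_firecrawl and not has_parallel:
        return "tavily"
    if has_parallel and not has_firecrawl:
        return "parallel"
    return "firecrawl"  # default, and the tie-breaker when several keys are present

assert resolve_backend({"backend": "tavily"}, {}) == "tavily"
assert resolve_backend({}, {"PARALLEL_API_KEY": "pk-example"}) == "parallel"
assert resolve_backend({}, {"FIRECRAWL_API_KEY": "fc-example", "PARALLEL_API_KEY": "pk-example"}) == "firecrawl"
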
# ─── Firecrawl Client ────────────────────────────────────────────────────────

_firecrawl_client = None

def _get_firecrawl_client():

@ -80,6 +118,129 @@ def _get_firecrawl_client():
|
|||
_firecrawl_client = Firecrawl(**kwargs)
|
||||
return _firecrawl_client
|
||||
|
||||
|
||||
# ─── Parallel Client ─────────────────────────────────────────────────────────
|
||||
|
||||
_parallel_client = None
|
||||
_async_parallel_client = None
|
||||
|
||||
def _get_parallel_client():
|
||||
"""Get or create the Parallel sync client (lazy initialization).
|
||||
|
||||
Requires PARALLEL_API_KEY environment variable.
|
||||
"""
|
||||
from parallel import Parallel
|
||||
global _parallel_client
|
||||
if _parallel_client is None:
|
||||
api_key = os.getenv("PARALLEL_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"PARALLEL_API_KEY environment variable not set. "
|
||||
"Get your API key at https://parallel.ai"
|
||||
)
|
||||
_parallel_client = Parallel(api_key=api_key)
|
||||
return _parallel_client
|
||||
|
||||
|
||||
def _get_async_parallel_client():
|
||||
"""Get or create the Parallel async client (lazy initialization).
|
||||
|
||||
Requires PARALLEL_API_KEY environment variable.
|
||||
"""
|
||||
from parallel import AsyncParallel
|
||||
global _async_parallel_client
|
||||
if _async_parallel_client is None:
|
||||
api_key = os.getenv("PARALLEL_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"PARALLEL_API_KEY environment variable not set. "
|
||||
"Get your API key at https://parallel.ai"
|
||||
)
|
||||
_async_parallel_client = AsyncParallel(api_key=api_key)
|
||||
return _async_parallel_client
|
||||
|
||||
# ─── Tavily Client ───────────────────────────────────────────────────────────
|
||||
|
||||
_TAVILY_BASE_URL = "https://api.tavily.com"
|
||||
|
||||
|
||||
def _tavily_request(endpoint: str, payload: dict) -> dict:
|
||||
"""Send a POST request to the Tavily API.
|
||||
|
||||
Auth is provided via ``api_key`` in the JSON body (no header-based auth).
|
||||
Raises ``ValueError`` if ``TAVILY_API_KEY`` is not set.
|
||||
"""
|
||||
api_key = os.getenv("TAVILY_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"TAVILY_API_KEY environment variable not set. "
|
||||
"Get your API key at https://app.tavily.com/home"
|
||||
)
|
||||
payload["api_key"] = api_key
|
||||
url = f"{_TAVILY_BASE_URL}/{endpoint.lstrip('/')}"
|
||||
logger.info("Tavily %s request to %s", endpoint, url)
|
||||
response = httpx.post(url, json=payload, timeout=60)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def _normalize_tavily_search_results(response: dict) -> dict:
|
||||
"""Normalize Tavily /search response to the standard web search format.
|
||||
|
||||
Tavily returns ``{results: [{title, url, content, score, ...}]}``.
|
||||
We map to ``{success, data: {web: [{title, url, description, position}]}}``.
|
||||
"""
|
||||
web_results = []
|
||||
for i, result in enumerate(response.get("results", [])):
|
||||
web_results.append({
|
||||
"title": result.get("title", ""),
|
||||
"url": result.get("url", ""),
|
||||
"description": result.get("content", ""),
|
||||
"position": i + 1,
|
||||
})
|
||||
return {"success": True, "data": {"web": web_results}}
|
||||
|
||||
|
||||
def _normalize_tavily_documents(response: dict, fallback_url: str = "") -> List[Dict[str, Any]]:
|
||||
"""Normalize Tavily /extract or /crawl response to the standard document format.
|
||||
|
||||
Maps results to ``{url, title, content, raw_content, metadata}`` and
|
||||
includes any ``failed_results`` / ``failed_urls`` as error entries.
|
||||
"""
|
||||
documents: List[Dict[str, Any]] = []
|
||||
for result in response.get("results", []):
|
||||
url = result.get("url", fallback_url)
|
||||
raw = result.get("raw_content", "") or result.get("content", "")
|
||||
documents.append({
|
||||
"url": url,
|
||||
"title": result.get("title", ""),
|
||||
"content": raw,
|
||||
"raw_content": raw,
|
||||
"metadata": {"sourceURL": url, "title": result.get("title", "")},
|
||||
})
|
||||
# Handle failed results
|
||||
for fail in response.get("failed_results", []):
|
||||
documents.append({
|
||||
"url": fail.get("url", fallback_url),
|
||||
"title": "",
|
||||
"content": "",
|
||||
"raw_content": "",
|
||||
"error": fail.get("error", "extraction failed"),
|
||||
"metadata": {"sourceURL": fail.get("url", fallback_url)},
|
||||
})
|
||||
for fail_url in response.get("failed_urls", []):
|
||||
url_str = fail_url if isinstance(fail_url, str) else str(fail_url)
|
||||
documents.append({
|
||||
"url": url_str,
|
||||
"title": "",
|
||||
"content": "",
|
||||
"raw_content": "",
|
||||
"error": "extraction failed",
|
||||
"metadata": {"sourceURL": url_str},
|
||||
})
|
||||
return documents
|
||||
|
||||
|
||||
DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION = 5000
|
||||
|
||||
# Allow per-task override via env var
|
||||
|
|
@ -427,13 +588,89 @@ def clean_base64_images(text: str) -> str:
|
|||
return cleaned_text
|
||||
|
||||
|
||||
# ─── Parallel Search & Extract Helpers ────────────────────────────────────────
|
||||
|
||||
def _parallel_search(query: str, limit: int = 5) -> dict:
|
||||
"""Search using the Parallel SDK and return results as a dict."""
|
||||
from tools.interrupt import is_interrupted
|
||||
if is_interrupted():
|
||||
return {"error": "Interrupted", "success": False}
|
||||
|
||||
mode = os.getenv("PARALLEL_SEARCH_MODE", "agentic").lower().strip()
|
||||
if mode not in ("fast", "one-shot", "agentic"):
|
||||
mode = "agentic"
|
||||
|
||||
logger.info("Parallel search: '%s' (mode=%s, limit=%d)", query, mode, limit)
|
||||
response = _get_parallel_client().beta.search(
|
||||
search_queries=[query],
|
||||
objective=query,
|
||||
mode=mode,
|
||||
max_results=min(limit, 20),
|
||||
)
|
||||
|
||||
web_results = []
|
||||
for i, result in enumerate(response.results or []):
|
||||
excerpts = result.excerpts or []
|
||||
web_results.append({
|
||||
"url": result.url or "",
|
||||
"title": result.title or "",
|
||||
"description": " ".join(excerpts) if excerpts else "",
|
||||
"position": i + 1,
|
||||
})
|
||||
|
||||
return {"success": True, "data": {"web": web_results}}
|
||||
|
||||
|
||||
async def _parallel_extract(urls: List[str]) -> List[Dict[str, Any]]:
|
||||
"""Extract content from URLs using the Parallel async SDK.
|
||||
|
||||
Returns a list of result dicts matching the structure expected by the
|
||||
LLM post-processing pipeline (url, title, content, metadata).
|
||||
"""
|
||||
from tools.interrupt import is_interrupted
|
||||
if is_interrupted():
|
||||
return [{"url": u, "error": "Interrupted", "title": ""} for u in urls]
|
||||
|
||||
logger.info("Parallel extract: %d URL(s)", len(urls))
|
||||
response = await _get_async_parallel_client().beta.extract(
|
||||
urls=urls,
|
||||
full_content=True,
|
||||
)
|
||||
|
||||
results = []
|
||||
for result in response.results or []:
|
||||
content = result.full_content or ""
|
||||
if not content:
|
||||
content = "\n\n".join(result.excerpts or [])
|
||||
url = result.url or ""
|
||||
title = result.title or ""
|
||||
results.append({
|
||||
"url": url,
|
||||
"title": title,
|
||||
"content": content,
|
||||
"raw_content": content,
|
||||
"metadata": {"sourceURL": url, "title": title},
|
||||
})
|
||||
|
||||
for error in response.errors or []:
|
||||
results.append({
|
||||
"url": error.url or "",
|
||||
"title": "",
|
||||
"content": "",
|
||||
"error": error.content or error.error_type or "extraction failed",
|
||||
"metadata": {"sourceURL": error.url or ""},
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def web_search_tool(query: str, limit: int = 5) -> str:
|
||||
"""
|
||||
Search the web for information using available search API backend.
|
||||
|
||||
|
||||
This function provides a generic interface for web search that can work
|
||||
with multiple backends. Currently uses Firecrawl.
|
||||
|
||||
with multiple backends (Parallel or Firecrawl).
|
||||
|
||||
Note: This function returns search result metadata only (URLs, titles, descriptions).
|
||||
Use web_extract_tool to get full content from specific URLs.
|
||||
|
||||
|
|
@ -477,17 +714,44 @@ def web_search_tool(query: str, limit: int = 5) -> str:
|
|||
if is_interrupted():
|
||||
return json.dumps({"error": "Interrupted", "success": False})
|
||||
|
||||
# Dispatch to the configured backend
|
||||
backend = _get_backend()
|
||||
if backend == "parallel":
|
||||
response_data = _parallel_search(query, limit)
|
||||
debug_call_data["results_count"] = len(response_data.get("data", {}).get("web", []))
|
||||
result_json = json.dumps(response_data, indent=2, ensure_ascii=False)
|
||||
debug_call_data["final_response_size"] = len(result_json)
|
||||
_debug.log_call("web_search_tool", debug_call_data)
|
||||
_debug.save()
|
||||
return result_json
|
||||
|
||||
if backend == "tavily":
|
||||
logger.info("Tavily search: '%s' (limit: %d)", query, limit)
|
||||
raw = _tavily_request("search", {
|
||||
"query": query,
|
||||
"max_results": min(limit, 20),
|
||||
"include_raw_content": False,
|
||||
"include_images": False,
|
||||
})
|
||||
response_data = _normalize_tavily_search_results(raw)
|
||||
debug_call_data["results_count"] = len(response_data.get("data", {}).get("web", []))
|
||||
result_json = json.dumps(response_data, indent=2, ensure_ascii=False)
|
||||
debug_call_data["final_response_size"] = len(result_json)
|
||||
_debug.log_call("web_search_tool", debug_call_data)
|
||||
_debug.save()
|
||||
return result_json
|
||||
|
||||
logger.info("Searching the web for: '%s' (limit: %d)", query, limit)
|
||||
|
||||
|
||||
response = _get_firecrawl_client().search(
|
||||
query=query,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
|
||||
# The response is a SearchData object with web, news, and images attributes
|
||||
# When not scraping, the results are directly in these attributes
|
||||
web_results = []
|
||||
|
||||
|
||||
# Check if response has web attribute (SearchData object)
|
||||
if hasattr(response, 'web'):
|
||||
# Response is a SearchData object with web attribute
|
||||
|
|
@ -595,100 +859,137 @@ async def web_extract_tool(
|
|||
|
||||
try:
logger.info("Extracting content from %d URL(s)", len(urls))

# Determine requested formats for Firecrawl v2
formats: List[str] = []
if format == "markdown":
formats = ["markdown"]
elif format == "html":
formats = ["html"]
else:
# Default: request markdown for LLM-readiness and include html as backup
formats = ["markdown", "html"]

# Always use individual scraping for simplicity and reliability
# Batch scraping adds complexity without much benefit for small numbers of URLs
results: List[Dict[str, Any]] = []

from tools.interrupt import is_interrupted as _is_interrupted
for url in urls:
if _is_interrupted():
results.append({"url": url, "error": "Interrupted", "title": ""})
continue

try:
logger.info("Scraping: %s", url)
scrape_result = _get_firecrawl_client().scrape(
url=url,
formats=formats
)

# Process the result - properly handle object serialization
metadata = {}
title = ""
content_markdown = None
content_html = None

# Extract data from the scrape result
if hasattr(scrape_result, 'model_dump'):
# Pydantic model - use model_dump to get dict
result_dict = scrape_result.model_dump()
content_markdown = result_dict.get('markdown')
content_html = result_dict.get('html')
metadata = result_dict.get('metadata', {})
elif hasattr(scrape_result, '__dict__'):
# Regular object with attributes
content_markdown = getattr(scrape_result, 'markdown', None)
content_html = getattr(scrape_result, 'html', None)

# Handle metadata - convert to dict if it's an object
metadata_obj = getattr(scrape_result, 'metadata', {})
if hasattr(metadata_obj, 'model_dump'):
metadata = metadata_obj.model_dump()
elif hasattr(metadata_obj, '__dict__'):
metadata = metadata_obj.__dict__
elif isinstance(metadata_obj, dict):
metadata = metadata_obj
else:
metadata = {}
elif isinstance(scrape_result, dict):
# Already a dictionary
content_markdown = scrape_result.get('markdown')
content_html = scrape_result.get('html')
metadata = scrape_result.get('metadata', {})

# Ensure metadata is a dict (not an object)
if not isinstance(metadata, dict):
if hasattr(metadata, 'model_dump'):
metadata = metadata.model_dump()
elif hasattr(metadata, '__dict__'):
metadata = metadata.__dict__
else:
metadata = {}

# Get title from metadata
title = metadata.get("title", "")

# Choose content based on requested format
chosen_content = content_markdown if (format == "markdown" or (format is None and content_markdown)) else content_html or content_markdown or ""

results.append({
"url": metadata.get("sourceURL", url),
"title": title,
"content": chosen_content,
"raw_content": chosen_content,
"metadata": metadata # Now guaranteed to be a dict
})

except Exception as scrape_err:
logger.debug("Scrape failed for %s: %s", url, scrape_err)
results.append({
"url": url,
"title": "",
"content": "",
"raw_content": "",
"error": str(scrape_err)
})
# Dispatch to the configured backend
backend = _get_backend()

if backend == "parallel":
results = await _parallel_extract(urls)
elif backend == "tavily":
logger.info("Tavily extract: %d URL(s)", len(urls))
raw = _tavily_request("extract", {
"urls": urls,
"include_images": False,
})
results = _normalize_tavily_documents(raw, fallback_url=urls[0] if urls else "")
else:
# ── Firecrawl extraction ──
# Determine requested formats for Firecrawl v2
formats: List[str] = []
if format == "markdown":
formats = ["markdown"]
elif format == "html":
formats = ["html"]
else:
# Default: request markdown for LLM-readiness and include html as backup
formats = ["markdown", "html"]

# Always use individual scraping for simplicity and reliability
# Batch scraping adds complexity without much benefit for small numbers of URLs
results: List[Dict[str, Any]] = []

from tools.interrupt import is_interrupted as _is_interrupted
for url in urls:
if _is_interrupted():
results.append({"url": url, "error": "Interrupted", "title": ""})
continue

# Website policy check — block before fetching
blocked = check_website_access(url)
if blocked:
logger.info("Blocked web_extract for %s by rule %s", blocked["host"], blocked["rule"])
results.append({
"url": url, "title": "", "content": "",
"error": blocked["message"],
"blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]},
})
continue

try:
logger.info("Scraping: %s", url)
scrape_result = _get_firecrawl_client().scrape(
url=url,
formats=formats
)

# Process the result - properly handle object serialization
metadata = {}
title = ""
content_markdown = None
content_html = None

# Extract data from the scrape result
if hasattr(scrape_result, 'model_dump'):
# Pydantic model - use model_dump to get dict
result_dict = scrape_result.model_dump()
content_markdown = result_dict.get('markdown')
content_html = result_dict.get('html')
metadata = result_dict.get('metadata', {})
elif hasattr(scrape_result, '__dict__'):
# Regular object with attributes
content_markdown = getattr(scrape_result, 'markdown', None)
content_html = getattr(scrape_result, 'html', None)

# Handle metadata - convert to dict if it's an object
metadata_obj = getattr(scrape_result, 'metadata', {})
if hasattr(metadata_obj, 'model_dump'):
metadata = metadata_obj.model_dump()
elif hasattr(metadata_obj, '__dict__'):
metadata = metadata_obj.__dict__
elif isinstance(metadata_obj, dict):
metadata = metadata_obj
else:
metadata = {}
elif isinstance(scrape_result, dict):
# Already a dictionary
content_markdown = scrape_result.get('markdown')
content_html = scrape_result.get('html')
metadata = scrape_result.get('metadata', {})

# Ensure metadata is a dict (not an object)
if not isinstance(metadata, dict):
if hasattr(metadata, 'model_dump'):
metadata = metadata.model_dump()
elif hasattr(metadata, '__dict__'):
metadata = metadata.__dict__
else:
metadata = {}

# Get title from metadata
title = metadata.get("title", "")

# Re-check final URL after redirect
final_url = metadata.get("sourceURL", url)
final_blocked = check_website_access(final_url)
if final_blocked:
logger.info("Blocked redirected web_extract for %s by rule %s", final_blocked["host"], final_blocked["rule"])
results.append({
"url": final_url, "title": title, "content": "", "raw_content": "",
"error": final_blocked["message"],
"blocked_by_policy": {"host": final_blocked["host"], "rule": final_blocked["rule"], "source": final_blocked["source"]},
})
continue

# Choose content based on requested format
chosen_content = content_markdown if (format == "markdown" or (format is None and content_markdown)) else content_html or content_markdown or ""

results.append({
"url": final_url,
"title": title,
"content": chosen_content,
"raw_content": chosen_content,
"metadata": metadata # Now guaranteed to be a dict
})

except Exception as scrape_err:
logger.debug("Scrape failed for %s: %s", url, scrape_err)
results.append({
"url": url,
"title": "",
"content": "",
"raw_content": "",
"error": str(scrape_err)
})

response = {"results": results}
@@ -778,6 +1079,7 @@ async def web_extract_tool(
"title": r.get("title", ""),
"content": r.get("content", ""),
"error": r.get("error"),
**({ "blocked_by_policy": r["blocked_by_policy"]} if "blocked_by_policy" in r else {}),
}
for r in response.get("results", [])
]
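For reference, a policy-blocked URL surfaces in these trimmed results as an entry with empty content, the policy message as its error, and a blocked_by_policy object. An illustrative sketch of one such entry (the host, rule, and message values are hypothetical, not output from a real run):

{
    "url": "https://tracker.example.com/page",
    "title": "",
    "content": "",
    "error": "Blocked by website policy: 'tracker.example.com' matched rule 'example.com' from config",
    "blocked_by_policy": {"host": "tracker.example.com", "rule": "example.com", "source": "config"},
}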

@@ -862,6 +1164,91 @@ async def web_crawl_tool(
}

try:
backend = _get_backend()

# Tavily supports crawl via its /crawl endpoint
if backend == "tavily":
# Ensure URL has protocol
if not url.startswith(('http://', 'https://')):
url = f'https://{url}'

# Website policy check
blocked = check_website_access(url)
if blocked:
logger.info("Blocked web_crawl for %s by rule %s", blocked["host"], blocked["rule"])
return json.dumps({"results": [{"url": url, "title": "", "content": "", "error": blocked["message"],
"blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]}}]}, ensure_ascii=False)

from tools.interrupt import is_interrupted as _is_int
if _is_int():
return json.dumps({"error": "Interrupted", "success": False})

logger.info("Tavily crawl: %s", url)
payload: Dict[str, Any] = {
"url": url,
"limit": 20,
"extract_depth": depth,
}
if instructions:
payload["instructions"] = instructions
raw = _tavily_request("crawl", payload)
results = _normalize_tavily_documents(raw, fallback_url=url)

response = {"results": results}
# Fall through to the shared LLM processing and trimming below
# (skip the Firecrawl-specific crawl logic)
pages_crawled = len(response.get('results', []))
logger.info("Crawled %d pages", pages_crawled)
debug_call_data["pages_crawled"] = pages_crawled
debug_call_data["original_response_size"] = len(json.dumps(response))

# Process each result with LLM if enabled
if use_llm_processing:
logger.info("Processing crawled content with LLM (parallel)...")
debug_call_data["processing_applied"].append("llm_processing")

async def _process_tavily_crawl(result):
page_url = result.get('url', 'Unknown URL')
title = result.get('title', '')
content = result.get('content', '')
if not content:
return result, None, "no_content"
original_size = len(content)
processed = await process_content_with_llm(content, page_url, title, model, min_length)
if processed:
result['raw_content'] = content
result['content'] = processed
metrics = {"url": page_url, "original_size": original_size, "processed_size": len(processed),
"compression_ratio": len(processed) / original_size if original_size else 1.0, "model_used": model}
return result, metrics, "processed"
metrics = {"url": page_url, "original_size": original_size, "processed_size": original_size,
"compression_ratio": 1.0, "model_used": None, "reason": "content_too_short"}
return result, metrics, "too_short"

tasks = [_process_tavily_crawl(r) for r in response.get('results', [])]
processed_results = await asyncio.gather(*tasks)
for result, metrics, status in processed_results:
if status == "processed":
debug_call_data["compression_metrics"].append(metrics)
debug_call_data["pages_processed_with_llm"] += 1

trimmed_results = [{"url": r.get("url", ""), "title": r.get("title", ""), "content": r.get("content", ""), "error": r.get("error"),
**({ "blocked_by_policy": r["blocked_by_policy"]} if "blocked_by_policy" in r else {})} for r in response.get("results", [])]
result_json = json.dumps({"results": trimmed_results}, indent=2, ensure_ascii=False)
cleaned_result = clean_base64_images(result_json)
debug_call_data["final_response_size"] = len(cleaned_result)
_debug.log_call("web_crawl_tool", debug_call_data)
_debug.save()
return cleaned_result

# web_crawl requires Firecrawl — Parallel has no crawl API
if not (os.getenv("FIRECRAWL_API_KEY") or os.getenv("FIRECRAWL_API_URL")):
return json.dumps({
"error": "web_crawl requires Firecrawl. Set FIRECRAWL_API_KEY, "
"or use web_search + web_extract instead.",
"success": False,
}, ensure_ascii=False)

# Ensure URL has protocol
if not url.startswith(('http://', 'https://')):
url = f'https://{url}'

@@ -870,6 +1257,13 @@ async def web_crawl_tool(
instructions_text = f" with instructions: '{instructions}'" if instructions else ""
logger.info("Crawling %s%s", url, instructions_text)

# Website policy check — block before crawling
blocked = check_website_access(url)
if blocked:
logger.info("Blocked web_crawl for %s by rule %s", blocked["host"], blocked["rule"])
return json.dumps({"results": [{"url": url, "title": "", "content": "", "error": blocked["message"],
"blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]}}]}, ensure_ascii=False)

# Use Firecrawl's v2 crawl functionality
# Docs: https://docs.firecrawl.dev/features/crawl
# The crawl() method automatically waits for completion and returns all data

@@ -975,6 +1369,17 @@ async def web_crawl_tool(
page_url = metadata.get("sourceURL", metadata.get("url", "Unknown URL"))
title = metadata.get("title", "")

# Re-check crawled page URL against policy
page_blocked = check_website_access(page_url)
if page_blocked:
logger.info("Blocked crawled page %s by rule %s", page_blocked["host"], page_blocked["rule"])
pages.append({
"url": page_url, "title": title, "content": "", "raw_content": "",
"error": page_blocked["message"],
"blocked_by_policy": {"host": page_blocked["host"], "rule": page_blocked["rule"], "source": page_blocked["source"]},
})
continue

# Choose content (prefer markdown)
content = content_markdown or content_html or ""

@@ -1070,9 +1475,11 @@ async def web_crawl_tool(
# Trim output to minimal fields per entry: title, content, error
trimmed_results = [
{
"url": r.get("url", ""),
"title": r.get("title", ""),
"content": r.get("content", ""),
"error": r.get("error")
"error": r.get("error"),
**({ "blocked_by_policy": r["blocked_by_policy"]} if "blocked_by_policy" in r else {}),
}
for r in response.get("results", [])
]

@@ -1106,13 +1513,23 @@ async def web_crawl_tool(
def check_firecrawl_api_key() -> bool:
"""
Check if the Firecrawl API key is available in environment variables.

Returns:
bool: True if API key is set, False otherwise
"""
return bool(os.getenv("FIRECRAWL_API_KEY"))


def check_web_api_key() -> bool:
"""Check if any web backend API key is available (Parallel, Firecrawl, or Tavily)."""
return bool(
os.getenv("PARALLEL_API_KEY")
or os.getenv("FIRECRAWL_API_KEY")
or os.getenv("FIRECRAWL_API_URL")
or os.getenv("TAVILY_API_KEY")
)
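A quick illustration of the gating logic above (the key value is a placeholder; any one of the four environment variables read by the function is enough for the check to pass):

import os
os.environ["TAVILY_API_KEY"] = "tvly-placeholder"  # hypothetical value
assert check_web_api_key()  # a single configured backend enables the web toolset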


def check_auxiliary_model() -> bool:
"""Check if an auxiliary text model is available for LLM content processing."""
try:

@@ -1139,26 +1556,32 @@ if __name__ == "__main__":
print("=" * 40)

# Check if API keys are available
firecrawl_available = check_firecrawl_api_key()
web_available = check_web_api_key()
nous_available = check_auxiliary_model()

if not firecrawl_available:
print("❌ FIRECRAWL_API_KEY environment variable not set")
print("Please set your API key: export FIRECRAWL_API_KEY='your-key-here'")
print("Get API key at: https://firecrawl.dev/")

if web_available:
backend = _get_backend()
print(f"✅ Web backend: {backend}")
if backend == "parallel":
print(" Using Parallel API (https://parallel.ai)")
elif backend == "tavily":
print(" Using Tavily API (https://tavily.com)")
else:
print(" Using Firecrawl API (https://firecrawl.dev)")
else:
print("✅ Firecrawl API key found")

print("❌ No web search backend configured")
print("Set PARALLEL_API_KEY, TAVILY_API_KEY, or FIRECRAWL_API_KEY")

if not nous_available:
print("❌ No auxiliary model available for LLM content processing")
print("Set OPENROUTER_API_KEY, configure Nous Portal, or set OPENAI_BASE_URL + OPENAI_API_KEY")
print("⚠️ Without an auxiliary model, LLM content processing will be disabled")
else:
print(f"✅ Auxiliary model available: {DEFAULT_SUMMARIZER_MODEL}")

if not firecrawl_available:

if not web_available:
exit(1)

print("🛠️ Web tools ready for use!")

if nous_available:

@@ -1256,8 +1679,8 @@ registry.register(
toolset="web",
schema=WEB_SEARCH_SCHEMA,
handler=lambda args, **kw: web_search_tool(args.get("query", ""), limit=5),
check_fn=check_firecrawl_api_key,
requires_env=["FIRECRAWL_API_KEY"],
check_fn=check_web_api_key,
requires_env=["PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "TAVILY_API_KEY"],
emoji="🔍",
)
registry.register(

@@ -1266,8 +1689,8 @@ registry.register(
schema=WEB_EXTRACT_SCHEMA,
handler=lambda args, **kw: web_extract_tool(
args.get("urls", [])[:5] if isinstance(args.get("urls"), list) else [], "markdown"),
check_fn=check_firecrawl_api_key,
requires_env=["FIRECRAWL_API_KEY"],
check_fn=check_web_api_key,
requires_env=["PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "TAVILY_API_KEY"],
is_async=True,
emoji="📄",
)
tools/website_policy.py (new file, 285 lines)
@@ -0,0 +1,285 @@
"""Website access policy helpers for URL-capable tools.

This module loads a user-managed website blocklist from ~/.hermes/config.yaml
and optional shared list files. It is intentionally lightweight so web/browser
tools can enforce URL policy without pulling in the heavier CLI config stack.

Policy is cached in memory with a short TTL so config changes take effect
quickly without re-reading the file on every URL check.
"""

from __future__ import annotations

import fnmatch
import logging
import os
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse

logger = logging.getLogger(__name__)

_DEFAULT_WEBSITE_BLOCKLIST = {
    "enabled": False,
    "domains": [],
    "shared_files": [],
}

# Cache: parsed policy + timestamp. Avoids re-reading config.yaml on every
# URL check (a web_crawl with 50 pages would otherwise mean 51 YAML parses).
_CACHE_TTL_SECONDS = 30.0
_cache_lock = threading.Lock()
_cached_policy: Optional[Dict[str, Any]] = None
_cached_policy_path: Optional[str] = None
_cached_policy_time: float = 0.0


def _get_hermes_home() -> Path:
    return Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))


def _get_default_config_path() -> Path:
    return _get_hermes_home() / "config.yaml"


class WebsitePolicyError(Exception):
    """Raised when a website policy file is malformed."""


def _normalize_host(host: str) -> str:
    return (host or "").strip().lower().rstrip(".")


def _normalize_rule(rule: Any) -> Optional[str]:
    if not isinstance(rule, str):
        return None
    value = rule.strip().lower()
    if not value or value.startswith("#"):
        return None
    if "://" in value:
        parsed = urlparse(value)
        value = parsed.netloc or parsed.path
    value = value.split("/", 1)[0].strip().rstrip(".")
    if value.startswith("www."):
        value = value[4:]
    return value or None
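A few illustrative input/output pairs for _normalize_rule, following the steps above (the values are hypothetical):

# _normalize_rule("HTTPS://WWW.Example.COM/path?q=1")  -> "example.com"
# _normalize_rule("news.example.org.")                 -> "news.example.org"
# _normalize_rule("# commented-out entry")             -> None
# _normalize_rule(42)                                  -> None  (non-string rules are ignored)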


def _iter_blocklist_file_rules(path: Path) -> List[str]:
    """Load rules from a shared blocklist file.

    Missing or unreadable files log a warning and return an empty list
    rather than raising — a bad file path should not disable all web tools.
    """
    try:
        raw = path.read_text(encoding="utf-8")
    except FileNotFoundError:
        logger.warning("Shared blocklist file not found (skipping): %s", path)
        return []
    except (OSError, UnicodeDecodeError) as exc:
        logger.warning("Failed to read shared blocklist file %s (skipping): %s", path, exc)
        return []

    rules: List[str] = []
    for line in raw.splitlines():
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            continue
        normalized = _normalize_rule(stripped)
        if normalized:
            rules.append(normalized)
    return rules
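A shared blocklist file is plain text with one rule per line; blank lines and lines starting with # are skipped, and each remaining line goes through _normalize_rule (so full URLs are reduced to their host and a leading www. is dropped). A hypothetical blocklists/team.txt:

# Internal do-not-fetch list (example entries only)
example-tracker.com
https://www.example-ads.net/some/path
*.cdn.example.org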


def _load_policy_config(config_path: Optional[Path] = None) -> Dict[str, Any]:
    config_path = config_path or _get_default_config_path()
    if not config_path.exists():
        return dict(_DEFAULT_WEBSITE_BLOCKLIST)

    try:
        import yaml
    except ImportError:
        logger.debug("PyYAML not installed — website blocklist disabled")
        return dict(_DEFAULT_WEBSITE_BLOCKLIST)

    try:
        with open(config_path, encoding="utf-8") as f:
            config = yaml.safe_load(f) or {}
    except yaml.YAMLError as exc:
        raise WebsitePolicyError(f"Invalid config YAML at {config_path}: {exc}") from exc
    except OSError as exc:
        raise WebsitePolicyError(f"Failed to read config file {config_path}: {exc}") from exc
    if not isinstance(config, dict):
        raise WebsitePolicyError("config root must be a mapping")

    security = config.get("security", {})
    if security is None:
        security = {}
    if not isinstance(security, dict):
        raise WebsitePolicyError("security must be a mapping")

    website_blocklist = security.get("website_blocklist", {})
    if website_blocklist is None:
        website_blocklist = {}
    if not isinstance(website_blocklist, dict):
        raise WebsitePolicyError("security.website_blocklist must be a mapping")

    policy = dict(_DEFAULT_WEBSITE_BLOCKLIST)
    policy.update(website_blocklist)
    return policy


def load_website_blocklist(config_path: Optional[Path] = None) -> Dict[str, Any]:
    """Load and return the parsed website blocklist policy.

    Results are cached for ``_CACHE_TTL_SECONDS`` to avoid re-reading
    config.yaml on every URL check. Pass an explicit ``config_path``
    to bypass the cache (used by tests).
    """
    global _cached_policy, _cached_policy_path, _cached_policy_time

    resolved_path = str(config_path) if config_path else "__default__"
    now = time.monotonic()

    # Return cached policy if still fresh and same path
    if config_path is None:
        with _cache_lock:
            if (
                _cached_policy is not None
                and _cached_policy_path == resolved_path
                and (now - _cached_policy_time) < _CACHE_TTL_SECONDS
            ):
                return _cached_policy

    config_path = config_path or _get_default_config_path()
    policy = _load_policy_config(config_path)

    raw_domains = policy.get("domains", []) or []
    if not isinstance(raw_domains, list):
        raise WebsitePolicyError("security.website_blocklist.domains must be a list")

    raw_shared_files = policy.get("shared_files", []) or []
    if not isinstance(raw_shared_files, list):
        raise WebsitePolicyError("security.website_blocklist.shared_files must be a list")

    enabled = policy.get("enabled", True)
    if not isinstance(enabled, bool):
        raise WebsitePolicyError("security.website_blocklist.enabled must be a boolean")

    rules: List[Dict[str, str]] = []
    seen: set[Tuple[str, str]] = set()

    for raw_rule in raw_domains:
        normalized = _normalize_rule(raw_rule)
        if normalized and ("config", normalized) not in seen:
            rules.append({"pattern": normalized, "source": "config"})
            seen.add(("config", normalized))

    for shared_file in raw_shared_files:
        if not isinstance(shared_file, str) or not shared_file.strip():
            continue
        path = Path(shared_file).expanduser()
        if not path.is_absolute():
            path = (_get_hermes_home() / path).resolve()
        for normalized in _iter_blocklist_file_rules(path):
            key = (str(path), normalized)
            if key in seen:
                continue
            rules.append({"pattern": normalized, "source": str(path)})
            seen.add(key)

    result = {"enabled": enabled, "rules": rules}

    # Cache the result (only for the default path — explicit paths are tests)
    if config_path == _get_default_config_path():
        with _cache_lock:
            _cached_policy = result
            _cached_policy_path = "__default__"
            _cached_policy_time = now

    return result


def invalidate_cache() -> None:
    """Force the next ``check_website_access`` call to re-read config."""
    global _cached_policy
    with _cache_lock:
        _cached_policy = None
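After editing the blocklist in config.yaml, a caller can either wait out the 30-second TTL or force an immediate re-read; a small sketch (the URL is a placeholder):

from tools.website_policy import check_website_access, invalidate_cache

invalidate_cache()                                      # drop the cached policy
verdict = check_website_access("https://example.com/")  # this check re-reads config.yaml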


def _match_host_against_rule(host: str, pattern: str) -> bool:
    if not host or not pattern:
        return False
    if pattern.startswith("*."):
        return fnmatch.fnmatch(host, pattern)
    return host == pattern or host.endswith(f".{pattern}")
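Matching semantics, shown on hypothetical hosts:

# _match_host_against_rule("example.com", "example.com")        -> True   (exact)
# _match_host_against_rule("api.example.com", "example.com")    -> True   (subdomain suffix ".example.com")
# _match_host_against_rule("notexample.com", "example.com")     -> False
# _match_host_against_rule("api.example.com", "*.example.com")  -> True   (fnmatch wildcard)
# _match_host_against_rule("example.com", "*.example.com")      -> False  (wildcard does not cover the bare domain)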


def _extract_host_from_urlish(url: str) -> str:
    parsed = urlparse(url)
    host = _normalize_host(parsed.hostname or parsed.netloc)
    if host:
        return host

    if "://" not in url:
        schemeless = urlparse(f"//{url}")
        host = _normalize_host(schemeless.hostname or schemeless.netloc)
        if host:
            return host

    return ""


def check_website_access(url: str, config_path: Optional[Path] = None) -> Optional[Dict[str, str]]:
    """Check whether a URL is allowed by the website blocklist policy.

    Returns ``None`` if access is allowed, or a dict with block metadata
    (``host``, ``rule``, ``source``, ``message``) if blocked.

    Never raises on policy errors — logs a warning and returns ``None``
    (fail-open) so a config typo doesn't break all web tools. Pass
    ``config_path`` explicitly (tests) to get strict error propagation.
    """
    # Fast path: if no explicit config_path and the cached policy is disabled
    # or empty, skip all work (no YAML read, no host extraction).
    if config_path is None:
        with _cache_lock:
            if _cached_policy is not None and not _cached_policy.get("enabled"):
                return None

    host = _extract_host_from_urlish(url)
    if not host:
        return None

    try:
        policy = load_website_blocklist(config_path)
    except WebsitePolicyError as exc:
        if config_path is not None:
            raise  # Tests pass explicit paths — let errors propagate
        logger.warning("Website policy config error (failing open): %s", exc)
        return None
    except Exception as exc:
        logger.warning("Unexpected error loading website policy (failing open): %s", exc)
        return None

    if not policy.get("enabled"):
        return None

    for rule in policy.get("rules", []):
        pattern = rule.get("pattern", "")
        if _match_host_against_rule(host, pattern):
            logger.info("Blocked URL %s — matched rule '%s' from %s",
                        url, pattern, rule.get("source", "config"))
            return {
                "url": url,
                "host": host,
                "rule": pattern,
                "source": rule.get("source", "config"),
                "message": (
                    f"Blocked by website policy: '{host}' matched rule '{pattern}'"
                    f" from {rule.get('source', 'config')}"
                ),
            }
    return None
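Typical caller-side usage, mirroring how the web and crawl tools above use it (the URL is a placeholder):

from tools.website_policy import check_website_access

blocked = check_website_access("https://tracker.example.com/pixel")
if blocked:
    # Surface the policy decision to the caller instead of fetching the page
    print(blocked["message"], blocked["rule"], blocked["source"])
else:
    pass  # proceed with the fetch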