Merge pull request #6 from NousResearch/fix-leakage

Fix VM instance sharing across tasks
This commit is contained in:
Teknium 2025-11-04 02:15:32 -08:00 committed by GitHub
commit 9573b2ac2d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 230 additions and 71 deletions

View file

@ -167,8 +167,8 @@ def _process_single_prompt(
ephemeral_system_prompt=config.get("ephemeral_system_prompt") ephemeral_system_prompt=config.get("ephemeral_system_prompt")
) )
# Run the agent # Run the agent with task_id to ensure each task gets its own isolated VM
result = agent.run_conversation(prompt) result = agent.run_conversation(prompt, task_id=f"task_{prompt_index}")
# Extract tool usage statistics # Extract tool usage statistics
tool_stats = _extract_tool_stats(result["messages"]) tool_stats = _extract_tool_stats(result["messages"])

View file

@ -28,7 +28,7 @@ Usage:
import json import json
import asyncio import asyncio
from typing import Dict, Any, List from typing import Dict, Any, List, Optional
from tools.web_tools import web_search_tool, web_extract_tool, web_crawl_tool, check_firecrawl_api_key from tools.web_tools import web_search_tool, web_extract_tool, web_crawl_tool, check_firecrawl_api_key
from tools.terminal_tool import terminal_tool, check_hecate_requirements, TERMINAL_TOOL_DESCRIPTION from tools.terminal_tool import terminal_tool, check_hecate_requirements, TERMINAL_TOOL_DESCRIPTION
@ -480,13 +480,14 @@ def handle_web_function_call(function_name: str, function_args: Dict[str, Any])
else: else:
return json.dumps({"error": f"Unknown web function: {function_name}"}) return json.dumps({"error": f"Unknown web function: {function_name}"})
def handle_terminal_function_call(function_name: str, function_args: Dict[str, Any]) -> str: def handle_terminal_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
""" """
Handle function calls for terminal tools. Handle function calls for terminal tools.
Args: Args:
function_name (str): Name of the terminal function to call function_name (str): Name of the terminal function to call
function_args (Dict): Arguments for the function function_args (Dict): Arguments for the function
task_id (str): Unique identifier for this task to isolate VMs between concurrent tasks (optional)
Returns: Returns:
str: Function result as JSON string str: Function result as JSON string
@ -498,7 +499,7 @@ def handle_terminal_function_call(function_name: str, function_args: Dict[str, A
idle_threshold = function_args.get("idle_threshold", 5.0) idle_threshold = function_args.get("idle_threshold", 5.0)
timeout = function_args.get("timeout") timeout = function_args.get("timeout")
return terminal_tool(command, input_keys, None, background, idle_threshold, timeout) return terminal_tool(command, input_keys, None, background, idle_threshold, timeout, task_id)
else: else:
return json.dumps({"error": f"Unknown terminal function: {function_name}"}) return json.dumps({"error": f"Unknown terminal function: {function_name}"})
@ -614,7 +615,7 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any]
return json.dumps({"error": f"Unknown image generation function: {function_name}"}) return json.dumps({"error": f"Unknown image generation function: {function_name}"})
def handle_function_call(function_name: str, function_args: Dict[str, Any]) -> str: def handle_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
""" """
Main function call dispatcher that routes calls to appropriate toolsets. Main function call dispatcher that routes calls to appropriate toolsets.
@ -625,6 +626,7 @@ def handle_function_call(function_name: str, function_args: Dict[str, Any]) -> s
Args: Args:
function_name (str): Name of the function to call function_name (str): Name of the function to call
function_args (Dict): Arguments for the function function_args (Dict): Arguments for the function
task_id (str): Unique identifier for this task to isolate VMs between concurrent tasks (optional)
Returns: Returns:
str: Function result as JSON string str: Function result as JSON string
@ -639,7 +641,7 @@ def handle_function_call(function_name: str, function_args: Dict[str, Any]) -> s
# Route terminal tools # Route terminal tools
elif function_name in ["terminal"]: elif function_name in ["terminal"]:
return handle_terminal_function_call(function_name, function_args) return handle_terminal_function_call(function_name, function_args, task_id)
# Route vision tools # Route vision tools
elif function_name in ["vision_analyze"]: elif function_name in ["vision_analyze"]:

View file

@ -43,6 +43,7 @@ else:
# Import our tool system # Import our tool system
from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements
from tools.terminal_tool import cleanup_vm
class AIAgent: class AIAgent:
@ -345,7 +346,8 @@ class AIAgent:
self, self,
user_message: str, user_message: str,
system_message: str = None, system_message: str = None,
conversation_history: List[Dict[str, Any]] = None conversation_history: List[Dict[str, Any]] = None,
task_id: str = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Run a complete conversation with tool calling until completion. Run a complete conversation with tool calling until completion.
@ -354,10 +356,14 @@ class AIAgent:
user_message (str): The user's message/question user_message (str): The user's message/question
system_message (str): Custom system message (optional, overrides ephemeral_system_prompt if provided) system_message (str): Custom system message (optional, overrides ephemeral_system_prompt if provided)
conversation_history (List[Dict]): Previous conversation messages (optional) conversation_history (List[Dict]): Previous conversation messages (optional)
task_id (str): Unique identifier for this task to isolate VMs between concurrent tasks (optional, auto-generated if not provided)
Returns: Returns:
Dict: Complete conversation result with final response and message history Dict: Complete conversation result with final response and message history
""" """
# Generate unique task_id if not provided to isolate VMs between concurrent tasks
import uuid
effective_task_id = task_id or str(uuid.uuid4())
# Initialize conversation # Initialize conversation
messages = conversation_history or [] messages = conversation_history or []
@ -472,8 +478,8 @@ class AIAgent:
tool_start_time = time.time() tool_start_time = time.time()
# Execute the tool # Execute the tool with task_id to isolate VMs between concurrent tasks
function_result = handle_function_call(function_name, function_args) function_result = handle_function_call(function_name, function_args, effective_task_id)
tool_duration = time.time() - tool_start_time tool_duration = time.time() - tool_start_time
result_preview = function_result[:200] if len(function_result) > 200 else function_result result_preview = function_result[:200] if len(function_result) > 200 else function_result
@ -541,6 +547,13 @@ class AIAgent:
# Save trajectory if enabled # Save trajectory if enabled
self._save_trajectory(messages, user_message, completed) self._save_trajectory(messages, user_message, completed)
# Clean up VM for this task after conversation completes
try:
cleanup_vm(effective_task_id)
except Exception as e:
if self.verbose_logging:
logging.warning(f"Failed to cleanup VM for task {effective_task_id}: {e}")
return { return {
"final_response": final_response, "final_response": final_response,
"messages": messages, "messages": messages,

View file

@ -4,8 +4,12 @@ Terminal Tool Module
This module provides a single terminal tool using Hecate's VM infrastructure. This module provides a single terminal tool using Hecate's VM infrastructure.
It wraps Hecate's functionality to provide a simple interface for executing commands It wraps Hecate's functionality to provide a simple interface for executing commands
on Morph VMs with automatic lifecycle management. VMs live for 5 minutes after last use. on Morph VMs with automatic lifecycle management.
Timer resets with each use.
VM Lifecycle:
- VMs have a TTL (time to live) set at creation (default: 20 minutes)
- VMs are also cleaned up locally after 5 minutes of inactivity
- Timer resets with each use
Available tool: Available tool:
- terminal_tool: Execute commands with optional interactive session support - terminal_tool: Execute commands with optional interactive session support
@ -24,6 +28,8 @@ import json
import os import os
import uuid import uuid
import threading import threading
import time
import atexit
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
# Detailed description for the terminal tool based on Hermes Terminal system prompt # Detailed description for the terminal tool based on Hermes Terminal system prompt
@ -75,9 +81,137 @@ When commands enter interactive mode (vim, nano, less, git prompts, package mana
# Global state for VM lifecycle management # Global state for VM lifecycle management
# These persist across tool calls to enable session continuity # These persist across tool calls to enable session continuity
_active_instance = None # Changed to dictionaries keyed by task_id to prevent leakage between concurrent tasks
_active_context = None _active_instances: Dict[str, Any] = {}
_active_contexts: Dict[str, Any] = {}
_last_activity: Dict[str, float] = {} # Track last activity time for each VM
_instance_lock = threading.Lock() _instance_lock = threading.Lock()
_cleanup_thread = None
_cleanup_running = False
def _cleanup_inactive_vms(vm_lifetime_seconds: int = 300):
    """
    Clean up VMs that have been inactive for longer than vm_lifetime_seconds.

    This function should be called periodically by a background thread.

    Stale tasks are detached from the tracking dictionaries while holding
    ``_instance_lock``; the (potentially slow) termination calls are made
    afterwards, without the lock, so concurrent ``terminal_tool`` calls for
    other tasks are not stalled behind a network round-trip. Because the
    tracking entries are removed *before* termination is attempted, a failed
    termination can never leave a stale VM registered for reuse by a later
    call with the same task_id.

    Args:
        vm_lifetime_seconds: Maximum lifetime in seconds for inactive VMs (default: 300)
    """
    current_time = time.time()
    detached = []  # (task_id, instance) pairs to terminate once the lock is released

    with _instance_lock:
        # Find all VMs that have been inactive for too long
        stale_task_ids = [
            task_id
            for task_id, last_time in _last_activity.items()
            if current_time - last_time > vm_lifetime_seconds
        ]

        # Remove from tracking dictionaries (cheap, in-memory only)
        for task_id in stale_task_ids:
            instance = _active_instances.pop(task_id, None)
            _active_contexts.pop(task_id, None)
            _last_activity.pop(task_id, None)
            if instance is not None:
                detached.append((task_id, instance))

    # Terminate the detached VM instances outside the lock
    for task_id, instance in detached:
        try:
            # The instance API is probed in order of preference; only the
            # first available shutdown method is invoked.
            for method_name in ("terminate", "stop", "delete"):
                if hasattr(instance, method_name):
                    getattr(instance, method_name)()
                    break
            print(f"[VM Cleanup] Terminated inactive VM for task: {task_id}")
        except Exception as e:
            # Best-effort: the VM is already untracked locally.
            # NOTE(review): a VM whose termination failed is presumably
            # reclaimed by the TTL set at instance creation - confirm.
            print(f"[VM Cleanup] Error cleaning up VM for task {task_id}: {e}")
def _cleanup_thread_worker():
    """
    Body of the background cleanup thread.

    While the module-level ``_cleanup_running`` flag is set, performs one
    inactive-VM sweep and then waits up to 60 seconds before the next one.
    The inactivity limit is re-read from ``HECATE_VM_LIFETIME_SECONDS`` on
    every sweep. Any exception raised during a sweep is printed and the loop
    continues.
    """
    global _cleanup_running

    while _cleanup_running:
        try:
            lifetime = int(os.getenv("HECATE_VM_LIFETIME_SECONDS", "300"))
            _cleanup_inactive_vms(lifetime)
        except Exception as e:
            print(f"[VM Cleanup] Error in cleanup thread: {e}")

        # Wait in one-second slices so a stop request is noticed promptly
        seconds_left = 60
        while seconds_left > 0 and _cleanup_running:
            time.sleep(1)
            seconds_left -= 1
def _start_cleanup_thread():
    """
    Launch the background cleanup thread unless one is already alive.

    The liveness check and the launch both happen under ``_instance_lock``.
    The worker is created as a daemon thread.
    """
    global _cleanup_thread, _cleanup_running

    with _instance_lock:
        worker_alive = _cleanup_thread is not None and _cleanup_thread.is_alive()
        if worker_alive:
            return

        # Set the run flag before the worker starts so its loop condition holds
        _cleanup_running = True
        worker = threading.Thread(target=_cleanup_thread_worker, daemon=True)
        _cleanup_thread = worker
        worker.start()
def _stop_cleanup_thread():
    """
    Ask the background cleanup thread to stop and wait briefly for it.

    Clears the ``_cleanup_running`` flag, then joins the worker thread (if
    one was ever created) for at most 5 seconds.
    """
    global _cleanup_running

    _cleanup_running = False

    worker = _cleanup_thread
    if worker is None:
        return
    worker.join(timeout=5)
def cleanup_vm(task_id: str):
    """
    Manually clean up a specific VM by task_id.

    This should be called when a task is completed.

    The task's entries are removed from the tracking dictionaries first,
    while holding ``_instance_lock``; the VM is then terminated without the
    lock held. This guarantees that:

    - a failed termination never leaves a stale VM registered, so a later
      call that reuses the same task_id always gets a fresh VM; and
    - other tasks are not blocked on the lock while the termination request
      is in flight.

    Errors raised by the termination call are printed, not propagated
    (best-effort cleanup). Calling this for an unknown task_id is a no-op.

    Args:
        task_id: The task ID of the VM to clean up
    """
    # Remove from tracking dictionaries (cheap, in-memory only)
    with _instance_lock:
        instance = _active_instances.pop(task_id, None)
        _active_contexts.pop(task_id, None)
        _last_activity.pop(task_id, None)

    if instance is None:
        return

    try:
        # Terminate the VM instance using the first shutdown method it offers
        for method_name in ("terminate", "stop", "delete"):
            if hasattr(instance, method_name):
                getattr(instance, method_name)()
                break
        print(f"[VM Cleanup] Manually terminated VM for task: {task_id}")
    except Exception as e:
        # Best-effort: the VM is already untracked locally.
        # NOTE(review): a VM whose termination failed is presumably
        # reclaimed by the TTL set at instance creation - confirm.
        print(f"[VM Cleanup] Error manually cleaning up VM for task {task_id}: {e}")
# Register cleanup on program exit: stop the background cleanup thread.
# This hook only signals the worker and joins it; it does not itself
# terminate VMs that are still tracked when the interpreter exits.
atexit.register(_stop_cleanup_thread)
def terminal_tool( def terminal_tool(
command: Optional[str] = None, command: Optional[str] = None,
@ -85,7 +219,8 @@ def terminal_tool(
session_id: Optional[str] = None, session_id: Optional[str] = None,
background: bool = False, background: bool = False,
idle_threshold: float = 5.0, idle_threshold: float = 5.0,
timeout: Optional[int] = None timeout: Optional[int] = None,
task_id: Optional[str] = None
) -> str: ) -> str:
""" """
Execute a command on a Morph VM with optional interactive session support. Execute a command on a Morph VM with optional interactive session support.
@ -101,6 +236,7 @@ def terminal_tool(
background: Whether to run the command in the background (default: False) background: Whether to run the command in the background (default: False)
idle_threshold: Seconds to wait for output before considering session idle (default: 5.0) idle_threshold: Seconds to wait for output before considering session idle (default: 5.0)
timeout: Command timeout in seconds (optional) timeout: Command timeout in seconds (optional)
task_id: Unique identifier for this task to isolate VMs between concurrent tasks (optional)
Returns: Returns:
str: JSON string containing command output, session info, exit code, and any errors str: JSON string containing command output, session info, exit code, and any errors
@ -120,7 +256,7 @@ def terminal_tool(
# Run a background task # Run a background task
>>> result = terminal_tool(command="sleep 60", background=True) >>> result = terminal_tool(command="sleep 60", background=True)
""" """
global _active_instance, _active_context global _active_instances, _active_contexts
try: try:
# Import required modules lazily so this module can be imported # Import required modules lazily so this module can be imported
@ -135,14 +271,13 @@ def terminal_tool(
return json.dumps({ return json.dumps({
"output": "", "output": "",
"screen": "", "screen": "",
"session_id": None,
"exit_code": -1, "exit_code": -1,
"error": f"Terminal tool is disabled due to import error: {import_error}", "error": f"Terminal tool is disabled due to import error: {import_error}"
"status": "disabled"
}) })
# Get configuration from environment # Get configuration from environment
vm_lifetime_seconds = int(os.getenv("HECATE_VM_LIFETIME_SECONDS", "300")) vm_lifetime_seconds = int(os.getenv("HECATE_VM_LIFETIME_SECONDS", "300"))
vm_ttl_seconds = int(os.getenv("HECATE_VM_TTL_SECONDS", "1200")) # 20 minutes default
snapshot_id = os.getenv("HECATE_DEFAULT_SNAPSHOT_ID", "snapshot_defv9tjg") snapshot_id = os.getenv("HECATE_DEFAULT_SNAPSHOT_ID", "snapshot_defv9tjg")
# Check API key # Check API key
@ -151,25 +286,37 @@ def terminal_tool(
return json.dumps({ return json.dumps({
"output": "", "output": "",
"screen": "", "screen": "",
"session_id": None,
"exit_code": -1, "exit_code": -1,
"error": "MORPH_API_KEY environment variable not set", "error": "MORPH_API_KEY environment variable not set"
"status": "disabled"
}) })
# Get or create VM instance and execution context # Use task_id to isolate VMs between concurrent tasks
# If no task_id provided, use "default" for backward compatibility
effective_task_id = task_id or "default"
# Start the cleanup thread if not already running
_start_cleanup_thread()
# Get or create VM instance and execution context per task
# This is critical for interactive session support - the context must persist! # This is critical for interactive session support - the context must persist!
with _instance_lock: with _instance_lock:
if _active_instance is None: if effective_task_id not in _active_instances:
morph_client = MorphCloudClient(api_key=morph_api_key) morph_client = MorphCloudClient(api_key=morph_api_key)
_active_instance = morph_client.instances.start(snapshot_id=snapshot_id) _active_instances[effective_task_id] = morph_client.instances.start(
snapshot_id=snapshot_id,
ttl_seconds=vm_ttl_seconds,
ttl_action="stop"
)
# Get or create persistent execution context # Get or create persistent execution context per task
if _active_context is None: if effective_task_id not in _active_contexts:
_active_context = ExecutionContext() _active_contexts[effective_task_id] = ExecutionContext()
instance = _active_instance # Update last activity time for this VM (resets the inactivity timer)
ctx = _active_context _last_activity[effective_task_id] = time.time()
instance = _active_instances[effective_task_id]
ctx = _active_contexts[effective_task_id]
# Build tool input based on provided parameters # Build tool input based on provided parameters
tool_input = {} tool_input = {}
@ -208,15 +355,13 @@ def terminal_tool(
ctx=ctx ctx=ctx
) )
# Format the result with all possible fields # Format the result with only essential fields for the LLM
# Map hecate's "stdout" to "output" for compatibility # Map hecate's "stdout" to "output" for compatibility
formatted_result = { formatted_result = {
"output": result.get("stdout", result.get("output", "")), "output": result.get("stdout", result.get("output", "")),
"screen": result.get("screen", ""), "screen": result.get("screen", ""),
"session_id": result.get("session_id"),
"exit_code": result.get("returncode", result.get("exit_code", -1)), "exit_code": result.get("returncode", result.get("exit_code", -1)),
"error": result.get("error"), "error": result.get("error")
"status": "active" if result.get("session_id") else "ended"
} }
return json.dumps(formatted_result) return json.dumps(formatted_result)
@ -225,10 +370,8 @@ def terminal_tool(
return json.dumps({ return json.dumps({
"output": "", "output": "",
"screen": "", "screen": "",
"session_id": None,
"exit_code": -1, "exit_code": -1,
"error": f"Failed to execute terminal command: {str(e)}", "error": f"Failed to execute terminal command: {str(e)}"
"status": "error"
}) })
def check_hecate_requirements() -> bool: def check_hecate_requirements() -> bool:
@ -304,5 +447,6 @@ if __name__ == "__main__":
print("\nEnvironment Variables:") print("\nEnvironment Variables:")
print(f" MORPH_API_KEY: {'Set' if os.getenv('MORPH_API_KEY') else 'Not set'}") print(f" MORPH_API_KEY: {'Set' if os.getenv('MORPH_API_KEY') else 'Not set'}")
print(f" OPENAI_API_KEY: {'Set' if os.getenv('OPENAI_API_KEY') else 'Not set (optional)'}") print(f" OPENAI_API_KEY: {'Set' if os.getenv('OPENAI_API_KEY') else 'Not set (optional)'}")
print(f" HECATE_VM_LIFETIME_SECONDS: {os.getenv('HECATE_VM_LIFETIME_SECONDS', '300')} (default: 300)") print(f" HECATE_VM_TTL_SECONDS: {os.getenv('HECATE_VM_TTL_SECONDS', '1200')} (default: 1200 / 20 minutes)")
print(f" HECATE_VM_LIFETIME_SECONDS: {os.getenv('HECATE_VM_LIFETIME_SECONDS', '300')} (default: 300 / 5 minutes)")
print(f" HECATE_DEFAULT_SNAPSHOT_ID: {os.getenv('HECATE_DEFAULT_SNAPSHOT_ID', 'snapshot_defv9tjg')} (default: snapshot_defv9tjg)") print(f" HECATE_DEFAULT_SNAPSHOT_ID: {os.getenv('HECATE_DEFAULT_SNAPSHOT_ID', 'snapshot_defv9tjg')} (default: snapshot_defv9tjg)")