Implement interrupt handling for long-running tool executions in AIAgent
- Added functionality to signal and terminate long-running terminal commands when a new user message is received, allowing for immediate agent response.
- Introduced a global interrupt event in the terminal tool to facilitate early termination of subprocesses.
- Updated the AIAgent class to handle interrupts gracefully: remaining tool calls are skipped, and placeholder tool results are added for them so the message sequence stays valid.
This commit is contained in:
parent
140d609e0c
commit
cfe2f3fe15
3 changed files with 94 additions and 14 deletions
|
|
@ -601,13 +601,9 @@ class GatewayRunner:
|
||||||
if adapter and hasattr(adapter, '_active_sessions') and source.chat_id in adapter._active_sessions:
|
if adapter and hasattr(adapter, '_active_sessions') and source.chat_id in adapter._active_sessions:
|
||||||
adapter._active_sessions[source.chat_id].clear()
|
adapter._active_sessions[source.chat_id].clear()
|
||||||
|
|
||||||
# Add an indicator to the response
|
# Don't send the interrupted response to the user — it's just noise
|
||||||
if response:
|
# like "Operation interrupted." They already know they sent a new
|
||||||
response = response + "\n\n---\n_[Interrupted - processing your new message]_"
|
# message, so go straight to processing it.
|
||||||
|
|
||||||
# Send the interrupted response first
|
|
||||||
if adapter and response:
|
|
||||||
await adapter.send(chat_id=source.chat_id, content=response)
|
|
||||||
|
|
||||||
# Now process the pending message with updated history
|
# Now process the pending message with updated history
|
||||||
updated_history = result.get("messages", history)
|
updated_history = result.get("messages", history)
|
||||||
|
|
|
||||||
23
run_agent.py
23
run_agent.py
|
|
@ -49,7 +49,7 @@ elif not os.getenv("HERMES_QUIET"):
|
||||||
|
|
||||||
# Import our tool system
|
# Import our tool system
|
||||||
from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements
|
from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements
|
||||||
from tools.terminal_tool import cleanup_vm
|
from tools.terminal_tool import cleanup_vm, set_interrupt_event as _set_terminal_interrupt
|
||||||
from tools.browser_tool import cleanup_browser
|
from tools.browser_tool import cleanup_browser
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
@ -1549,6 +1549,9 @@ class AIAgent:
|
||||||
Call this from another thread (e.g., input handler, message receiver)
|
Call this from another thread (e.g., input handler, message receiver)
|
||||||
to gracefully stop the agent and process a new message.
|
to gracefully stop the agent and process a new message.
|
||||||
|
|
||||||
|
Also signals long-running tool executions (e.g. terminal commands)
|
||||||
|
to terminate early, so the agent can respond immediately.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message: Optional new message that triggered the interrupt.
|
message: Optional new message that triggered the interrupt.
|
||||||
If provided, the agent will include this in its response context.
|
If provided, the agent will include this in its response context.
|
||||||
|
|
@ -1565,6 +1568,8 @@ class AIAgent:
|
||||||
"""
|
"""
|
||||||
self._interrupt_requested = True
|
self._interrupt_requested = True
|
||||||
self._interrupt_message = message
|
self._interrupt_message = message
|
||||||
|
# Signal the terminal tool to kill any running subprocess immediately
|
||||||
|
_set_terminal_interrupt(True)
|
||||||
if not self.quiet_mode:
|
if not self.quiet_mode:
|
||||||
print(f"\n⚡ Interrupt requested" + (f": '{message[:40]}...'" if message and len(message) > 40 else f": '{message}'" if message else ""))
|
print(f"\n⚡ Interrupt requested" + (f": '{message[:40]}...'" if message and len(message) > 40 else f": '{message}'" if message else ""))
|
||||||
|
|
||||||
|
|
@ -1572,6 +1577,7 @@ class AIAgent:
|
||||||
"""Clear any pending interrupt request."""
|
"""Clear any pending interrupt request."""
|
||||||
self._interrupt_requested = False
|
self._interrupt_requested = False
|
||||||
self._interrupt_message = None
|
self._interrupt_message = None
|
||||||
|
_set_terminal_interrupt(False)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_interrupted(self) -> bool:
|
def is_interrupted(self) -> bool:
|
||||||
|
|
@ -2309,6 +2315,21 @@ class AIAgent:
|
||||||
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
|
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
|
||||||
print(f" ✅ Tool {i} completed in {tool_duration:.2f}s - {response_preview}")
|
print(f" ✅ Tool {i} completed in {tool_duration:.2f}s - {response_preview}")
|
||||||
|
|
||||||
|
# Check for interrupt between tool calls - skip remaining
|
||||||
|
# tools so the agent can respond to the user immediately
|
||||||
|
if self._interrupt_requested and i < len(assistant_message.tool_calls):
|
||||||
|
remaining = len(assistant_message.tool_calls) - i
|
||||||
|
print(f"{self.log_prefix}⚡ Interrupt: skipping {remaining} remaining tool call(s)")
|
||||||
|
# Add placeholder results for skipped tool calls so the
|
||||||
|
# message sequence stays valid (assistant tool_calls need matching tool results)
|
||||||
|
for skipped_tc in assistant_message.tool_calls[i:]:
|
||||||
|
messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"content": "[Tool execution skipped - user sent a new message]",
|
||||||
|
"tool_call_id": skipped_tc.id
|
||||||
|
})
|
||||||
|
break
|
||||||
|
|
||||||
# Delay between tool calls
|
# Delay between tool calls
|
||||||
if self.tool_delay > 0 and i < len(assistant_message.tool_calls):
|
if self.tool_delay > 0 and i < len(assistant_message.tool_calls):
|
||||||
time.sleep(self.tool_delay)
|
time.sleep(self.tool_delay)
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ Usage:
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import threading
|
import threading
|
||||||
|
|
@ -39,6 +40,28 @@ import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
# Global interrupt event: set by the agent when a user interrupt arrives.
# The terminal tool polls this during command execution so it can kill
# long-running subprocesses immediately instead of blocking until timeout.
#
# NOTE(review): this is a single module-level event, so everything in the
# process shares it. If several agents run concurrently in one process, an
# interrupt raised for one would presumably also stop terminal commands
# started by the others -- confirm that this is intended.
# ---------------------------------------------------------------------------
_interrupt_event = threading.Event()


def set_interrupt_event(active: bool) -> None:
    """Signal or clear the global terminal interrupt.

    Called by the agent when a user interrupt arrives (``True``) and again
    when the interrupt has been handled (``False``).

    Args:
        active: ``True`` sets the event so polling code sees a pending
            interrupt; ``False`` clears it.
    """
    if active:
        _interrupt_event.set()
    else:
        _interrupt_event.clear()


def is_interrupted() -> bool:
    """Return ``True`` while an interrupt has been requested and not cleared."""
    return _interrupt_event.is_set()
|
||||||
|
|
||||||
|
|
||||||
# Add mini-swe-agent to path if not installed
|
# Add mini-swe-agent to path if not installed
|
||||||
mini_swe_path = Path(__file__).parent.parent / "mini-swe-agent" / "src"
|
mini_swe_path = Path(__file__).parent.parent / "mini-swe-agent" / "src"
|
||||||
if mini_swe_path.exists():
|
if mini_swe_path.exists():
|
||||||
|
|
@ -599,7 +622,13 @@ class _LocalEnvironment:
|
||||||
self.env = env or {}
|
self.env = env or {}
|
||||||
|
|
||||||
def execute(self, command: str, cwd: str = "", *, timeout: int | None = None) -> dict:
|
def execute(self, command: str, cwd: str = "", *, timeout: int | None = None) -> dict:
|
||||||
"""Execute a command locally with sudo support."""
|
"""
|
||||||
|
Execute a command locally with sudo support.
|
||||||
|
|
||||||
|
Uses Popen + polling so the global interrupt event can kill the
|
||||||
|
process early when the user sends a new message, instead of
|
||||||
|
blocking for the full timeout.
|
||||||
|
"""
|
||||||
work_dir = cwd or self.cwd or os.getcwd()
|
work_dir = cwd or self.cwd or os.getcwd()
|
||||||
effective_timeout = timeout or self.timeout
|
effective_timeout = timeout or self.timeout
|
||||||
|
|
||||||
|
|
@ -607,22 +636,56 @@ class _LocalEnvironment:
|
||||||
exec_command = _transform_sudo_command(command)
|
exec_command = _transform_sudo_command(command)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(
|
proc = subprocess.Popen(
|
||||||
exec_command,
|
exec_command,
|
||||||
shell=True,
|
shell=True,
|
||||||
text=True,
|
text=True,
|
||||||
cwd=work_dir,
|
cwd=work_dir,
|
||||||
env=os.environ | self.env,
|
env=os.environ | self.env,
|
||||||
timeout=effective_timeout,
|
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
errors="replace",
|
errors="replace",
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
stderr=subprocess.STDOUT,
|
stderr=subprocess.STDOUT,
|
||||||
stdin=subprocess.DEVNULL, # Prevent hanging on interactive prompts
|
stdin=subprocess.DEVNULL, # Prevent hanging on interactive prompts
|
||||||
|
# Start in a new process group so we can kill the whole tree
|
||||||
|
preexec_fn=os.setsid,
|
||||||
)
|
)
|
||||||
return {"output": result.stdout, "returncode": result.returncode}
|
|
||||||
except subprocess.TimeoutExpired:
|
deadline = time.monotonic() + effective_timeout
|
||||||
return {"output": f"Command timed out after {effective_timeout}s", "returncode": 124}
|
|
||||||
|
# Poll every 200ms so we notice interrupts quickly
|
||||||
|
while proc.poll() is None:
|
||||||
|
if _interrupt_event.is_set():
|
||||||
|
# User sent a new message — kill the process tree and return
|
||||||
|
# what we have so far
|
||||||
|
try:
|
||||||
|
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
|
||||||
|
except (ProcessLookupError, PermissionError):
|
||||||
|
proc.kill()
|
||||||
|
# Grab any partial output
|
||||||
|
partial, _ = proc.communicate(timeout=2)
|
||||||
|
output = partial or ""
|
||||||
|
return {
|
||||||
|
"output": output + "\n[Command interrupted — user sent a new message]",
|
||||||
|
"returncode": 130 # Standard interrupted exit code
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.monotonic() > deadline:
|
||||||
|
# Timeout — kill process tree
|
||||||
|
try:
|
||||||
|
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
|
||||||
|
except (ProcessLookupError, PermissionError):
|
||||||
|
proc.kill()
|
||||||
|
proc.communicate(timeout=2)
|
||||||
|
return {"output": f"Command timed out after {effective_timeout}s", "returncode": 124}
|
||||||
|
|
||||||
|
# Short sleep to avoid busy-waiting
|
||||||
|
time.sleep(0.2)
|
||||||
|
|
||||||
|
# Process finished normally — read all output
|
||||||
|
stdout, _ = proc.communicate()
|
||||||
|
return {"output": stdout or "", "returncode": proc.returncode}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"output": f"Execution error: {str(e)}", "returncode": 1}
|
return {"output": f"Execution error: {str(e)}", "returncode": 1}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue