Add tool progress notifications for messaging channels
- Introduced a new callback mechanism in the AIAgent class to send tool progress messages during execution, enhancing user feedback in messaging platforms. - Updated the GatewayRunner to support tool progress notifications, allowing users to enable or disable this feature via environment variables. - Enhanced the CLI setup wizard to prompt users for enabling tool progress messages and selecting the notification mode (all or new), improving configuration options. - Updated relevant documentation to reflect the new features and configuration settings for tool progress notifications.
This commit is contained in:
parent
a09b018bd5
commit
e7f0ffbf5d
4 changed files with 141 additions and 4 deletions
|
|
@ -349,6 +349,7 @@ class GatewayRunner:
|
||||||
This is run in a thread pool to not block the event loop.
|
This is run in a thread pool to not block the event loop.
|
||||||
"""
|
"""
|
||||||
from run_agent import AIAgent
|
from run_agent import AIAgent
|
||||||
|
import queue
|
||||||
|
|
||||||
# Determine toolset based on platform
|
# Determine toolset based on platform
|
||||||
toolset_map = {
|
toolset_map = {
|
||||||
|
|
@ -359,6 +360,76 @@ class GatewayRunner:
|
||||||
}
|
}
|
||||||
toolset = toolset_map.get(source.platform, "hermes-telegram")
|
toolset = toolset_map.get(source.platform, "hermes-telegram")
|
||||||
|
|
||||||
|
# Check if tool progress notifications are enabled
|
||||||
|
tool_progress_enabled = os.getenv("HERMES_TOOL_PROGRESS", "").lower() in ("1", "true", "yes")
|
||||||
|
progress_mode = os.getenv("HERMES_TOOL_PROGRESS_MODE", "new") # "all" or "new" (only new tools)
|
||||||
|
|
||||||
|
# Queue for progress messages (thread-safe)
|
||||||
|
progress_queue = queue.Queue() if tool_progress_enabled else None
|
||||||
|
last_tool = [None] # Mutable container for tracking in closure
|
||||||
|
|
||||||
|
def progress_callback(tool_name: str, preview: str = None):
|
||||||
|
"""Callback invoked by agent when a tool is called."""
|
||||||
|
if not progress_queue:
|
||||||
|
return
|
||||||
|
|
||||||
|
# "new" mode: only report when tool changes
|
||||||
|
if progress_mode == "new" and tool_name == last_tool[0]:
|
||||||
|
return
|
||||||
|
last_tool[0] = tool_name
|
||||||
|
|
||||||
|
# Build progress message
|
||||||
|
tool_emojis = {
|
||||||
|
"terminal": "💻",
|
||||||
|
"web_search": "🔍",
|
||||||
|
"web_extract": "📄",
|
||||||
|
"read_file": "📖",
|
||||||
|
"write_file": "✍️",
|
||||||
|
"list_directory": "📂",
|
||||||
|
"image_generate": "🎨",
|
||||||
|
"browser_navigate": "🌐",
|
||||||
|
"browser_click": "👆",
|
||||||
|
"moa_query": "🧠",
|
||||||
|
}
|
||||||
|
emoji = tool_emojis.get(tool_name, "⚙️")
|
||||||
|
|
||||||
|
if tool_name == "terminal" and preview:
|
||||||
|
msg = f"{emoji} `{preview}`..."
|
||||||
|
else:
|
||||||
|
msg = f"{emoji} {tool_name}..."
|
||||||
|
|
||||||
|
progress_queue.put(msg)
|
||||||
|
|
||||||
|
# Background task to send progress messages
|
||||||
|
async def send_progress_messages():
|
||||||
|
if not progress_queue:
|
||||||
|
return
|
||||||
|
|
||||||
|
adapter = self.adapters.get(source.platform)
|
||||||
|
if not adapter:
|
||||||
|
return
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Non-blocking check with small timeout
|
||||||
|
msg = progress_queue.get_nowait()
|
||||||
|
await adapter.send(chat_id=source.chat_id, content=msg)
|
||||||
|
await asyncio.sleep(0.5) # Small delay between messages
|
||||||
|
except queue.Empty:
|
||||||
|
await asyncio.sleep(0.3) # Check again soon
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Drain remaining messages
|
||||||
|
while not progress_queue.empty():
|
||||||
|
try:
|
||||||
|
msg = progress_queue.get_nowait()
|
||||||
|
await adapter.send(chat_id=source.chat_id, content=msg)
|
||||||
|
except:
|
||||||
|
break
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Gateway] Progress message error: {e}")
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
def run_sync():
|
def run_sync():
|
||||||
# Read from env var or use default (same as CLI)
|
# Read from env var or use default (same as CLI)
|
||||||
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "60"))
|
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "60"))
|
||||||
|
|
@ -370,6 +441,7 @@ class GatewayRunner:
|
||||||
enabled_toolsets=[toolset],
|
enabled_toolsets=[toolset],
|
||||||
ephemeral_system_prompt=context_prompt,
|
ephemeral_system_prompt=context_prompt,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
|
tool_progress_callback=progress_callback if tool_progress_enabled else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# If we have history, we need to restore it
|
# If we have history, we need to restore it
|
||||||
|
|
@ -379,9 +451,23 @@ class GatewayRunner:
|
||||||
result = agent.run_conversation(message)
|
result = agent.run_conversation(message)
|
||||||
return result.get("final_response", "(No response)")
|
return result.get("final_response", "(No response)")
|
||||||
|
|
||||||
# Run in thread pool to not block
|
# Start progress message sender if enabled
|
||||||
loop = asyncio.get_event_loop()
|
progress_task = None
|
||||||
response = await loop.run_in_executor(None, run_sync)
|
if tool_progress_enabled:
|
||||||
|
progress_task = asyncio.create_task(send_progress_messages())
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Run in thread pool to not block
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
response = await loop.run_in_executor(None, run_sync)
|
||||||
|
finally:
|
||||||
|
# Stop progress sender
|
||||||
|
if progress_task:
|
||||||
|
progress_task.cancel()
|
||||||
|
try:
|
||||||
|
await progress_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -203,11 +203,23 @@ OPTIONAL_ENV_VARS = {
|
||||||
},
|
},
|
||||||
# Agent configuration
|
# Agent configuration
|
||||||
"HERMES_MAX_ITERATIONS": {
|
"HERMES_MAX_ITERATIONS": {
|
||||||
"description": "Maximum tool-calling iterations per conversation (default: 25 for messaging, 10 for CLI)",
|
"description": "Maximum tool-calling iterations per conversation (default: 60)",
|
||||||
"prompt": "Max iterations",
|
"prompt": "Max iterations",
|
||||||
"url": None,
|
"url": None,
|
||||||
"password": False,
|
"password": False,
|
||||||
},
|
},
|
||||||
|
"HERMES_TOOL_PROGRESS": {
|
||||||
|
"description": "Send tool progress messages in messaging channels (true/false)",
|
||||||
|
"prompt": "Enable tool progress messages",
|
||||||
|
"url": None,
|
||||||
|
"password": False,
|
||||||
|
},
|
||||||
|
"HERMES_TOOL_PROGRESS_MODE": {
|
||||||
|
"description": "Progress mode: 'all' (every tool) or 'new' (only when tool changes)",
|
||||||
|
"prompt": "Progress mode (all/new)",
|
||||||
|
"url": None,
|
||||||
|
"password": False,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -713,6 +713,28 @@ def run_setup_wizard(args):
|
||||||
except ValueError:
|
except ValueError:
|
||||||
print_warning("Invalid number, keeping current value")
|
print_warning("Invalid number, keeping current value")
|
||||||
|
|
||||||
|
# Tool progress notifications (for messaging)
|
||||||
|
print_info("")
|
||||||
|
print_info("Tool Progress Notifications (Messaging only)")
|
||||||
|
print_info("Send status messages when the agent uses tools.")
|
||||||
|
print_info("Example: '💻 ls -la...' or '🔍 web_search...'")
|
||||||
|
|
||||||
|
current_progress = get_env_value('HERMES_TOOL_PROGRESS') or 'false'
|
||||||
|
if prompt_yes_no("Enable tool progress messages?", current_progress.lower() in ('1', 'true', 'yes')):
|
||||||
|
save_env_value("HERMES_TOOL_PROGRESS", "true")
|
||||||
|
|
||||||
|
# Progress mode
|
||||||
|
current_mode = get_env_value('HERMES_TOOL_PROGRESS_MODE') or 'new'
|
||||||
|
print_info(" Mode options:")
|
||||||
|
print_info(" 'new' - Only when switching tools (less spam)")
|
||||||
|
print_info(" 'all' - Every tool call")
|
||||||
|
mode = prompt(" Progress mode", current_mode)
|
||||||
|
if mode.lower() in ('all', 'new'):
|
||||||
|
save_env_value("HERMES_TOOL_PROGRESS_MODE", mode.lower())
|
||||||
|
print_success("Tool progress enabled")
|
||||||
|
else:
|
||||||
|
save_env_value("HERMES_TOOL_PROGRESS", "false")
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# Step 6: Context Compression
|
# Step 6: Context Compression
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
|
|
||||||
17
run_agent.py
17
run_agent.py
|
|
@ -600,6 +600,7 @@ class AIAgent:
|
||||||
providers_order: List[str] = None,
|
providers_order: List[str] = None,
|
||||||
provider_sort: str = None,
|
provider_sort: str = None,
|
||||||
session_id: str = None,
|
session_id: str = None,
|
||||||
|
tool_progress_callback: callable = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the AI Agent.
|
Initialize the AI Agent.
|
||||||
|
|
@ -623,6 +624,7 @@ class AIAgent:
|
||||||
providers_order (List[str]): OpenRouter providers to try in order (optional)
|
providers_order (List[str]): OpenRouter providers to try in order (optional)
|
||||||
provider_sort (str): Sort providers by price/throughput/latency (optional)
|
provider_sort (str): Sort providers by price/throughput/latency (optional)
|
||||||
session_id (str): Pre-generated session ID for logging (optional, auto-generated if not provided)
|
session_id (str): Pre-generated session ID for logging (optional, auto-generated if not provided)
|
||||||
|
tool_progress_callback (callable): Callback function(tool_name, args_preview) for progress notifications
|
||||||
"""
|
"""
|
||||||
self.model = model
|
self.model = model
|
||||||
self.max_iterations = max_iterations
|
self.max_iterations = max_iterations
|
||||||
|
|
@ -634,6 +636,8 @@ class AIAgent:
|
||||||
self.log_prefix_chars = log_prefix_chars
|
self.log_prefix_chars = log_prefix_chars
|
||||||
self.log_prefix = f"{log_prefix} " if log_prefix else ""
|
self.log_prefix = f"{log_prefix} " if log_prefix else ""
|
||||||
self.base_url = base_url or "" # Store for OpenRouter detection
|
self.base_url = base_url or "" # Store for OpenRouter detection
|
||||||
|
self.tool_progress_callback = tool_progress_callback
|
||||||
|
self._last_reported_tool = None # Track for "new tool" mode
|
||||||
|
|
||||||
# Store OpenRouter provider preferences
|
# Store OpenRouter provider preferences
|
||||||
self.providers_allowed = providers_allowed
|
self.providers_allowed = providers_allowed
|
||||||
|
|
@ -1793,6 +1797,19 @@ class AIAgent:
|
||||||
args_str = json.dumps(function_args, ensure_ascii=False)
|
args_str = json.dumps(function_args, ensure_ascii=False)
|
||||||
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
|
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
|
||||||
print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())}) - {args_preview}")
|
print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())}) - {args_preview}")
|
||||||
|
|
||||||
|
# Fire progress callback if registered (for messaging platforms)
|
||||||
|
if self.tool_progress_callback:
|
||||||
|
try:
|
||||||
|
# Build preview for terminal commands
|
||||||
|
if function_name == "terminal":
|
||||||
|
cmd = function_args.get("command", "")
|
||||||
|
preview = cmd[:50] + "..." if len(cmd) > 50 else cmd
|
||||||
|
else:
|
||||||
|
preview = None
|
||||||
|
self.tool_progress_callback(function_name, preview)
|
||||||
|
except Exception as cb_err:
|
||||||
|
logging.debug(f"Tool progress callback error: {cb_err}")
|
||||||
|
|
||||||
tool_start_time = time.time()
|
tool_start_time = time.time()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue