feat: add /voice command for auto voice reply in Telegram gateway
- /voice on: reply with voice when the user sends a voice message
- /voice tts: reply with voice to all messages
- /voice off: disable voice replies (text only)
- /voice status: show the current mode
- Per-chat state is persisted to gateway_voice_mode.json
- Dedup: skip the auto-reply if the agent already called the text_to_speech tool
- drop_pending_updates=True so stale Telegram messages are ignored on restart
- 25 tests covering the command handler, reply logic, and edge cases
This commit is contained in:
parent
8aab13d12d
commit
d80da5ddd8
3 changed files with 434 additions and 5 deletions
|
|
@ -150,7 +150,10 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
# Start polling in background
|
||||
await self._app.initialize()
|
||||
await self._app.start()
|
||||
await self._app.updater.start_polling(allowed_updates=Update.ALL_TYPES)
|
||||
await self._app.updater.start_polling(
|
||||
allowed_updates=Update.ALL_TYPES,
|
||||
drop_pending_updates=True,
|
||||
)
|
||||
|
||||
# Register bot commands so Telegram shows a hint menu when users type /
|
||||
try:
|
||||
|
|
@ -174,6 +177,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
BotCommand("insights", "Show usage insights and analytics"),
|
||||
BotCommand("update", "Update Hermes to the latest version"),
|
||||
BotCommand("reload_mcp", "Reload MCP servers from config"),
|
||||
BotCommand("voice", "Toggle voice reply mode"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
])
|
||||
except Exception as e:
|
||||
|
|
|
|||
148
gateway/run.py
148
gateway/run.py
|
|
@ -14,12 +14,15 @@ Usage:
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import signal
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
|
@ -280,6 +283,9 @@ class GatewayRunner:
|
|||
from gateway.hooks import HookRegistry
|
||||
self.hooks = HookRegistry()
|
||||
|
||||
# Per-chat voice reply mode: "off" | "voice_only" | "all"
|
||||
self._voice_mode: Dict[str, str] = self._load_voice_modes()
|
||||
|
||||
def _get_or_create_gateway_honcho(self, session_key: str):
|
||||
"""Return a persistent Honcho manager/config pair for this gateway session."""
|
||||
if not hasattr(self, "_honcho_managers"):
|
||||
|
|
@ -335,6 +341,27 @@ class GatewayRunner:
|
|||
for session_key in list(managers.keys()):
|
||||
self._shutdown_gateway_honcho(session_key)
|
||||
|
||||
# -- Voice mode persistence ------------------------------------------
|
||||
|
||||
# JSON file that stores the per-chat voice reply modes written by
# _save_voice_modes and read back by _load_voice_modes.
_VOICE_MODE_PATH = _hermes_home / "gateway_voice_mode.json"
|
||||
|
||||
def _load_voice_modes(self) -> Dict[str, str]:
    """Load the persisted per-chat voice reply modes.

    Reads ``self._VOICE_MODE_PATH`` and returns a mapping of chat id to
    mode (``"voice_only"`` or ``"all"``).

    The file is treated as best-effort state: if it is missing,
    unreadable, not valid UTF-8/JSON, or does not contain a JSON object,
    an empty mapping is returned so every chat starts in text-only mode.
    Entries whose value is not a recognised mode are dropped rather than
    propagated into the in-memory state.
    """
    try:
        raw = self._VOICE_MODE_PATH.read_text(encoding="utf-8")
        data = json.loads(raw)
    except (OSError, ValueError):
        # OSError covers FileNotFoundError/permission problems; ValueError
        # covers both json.JSONDecodeError and UnicodeDecodeError.
        return {}

    # A hand-edited or corrupted file may hold a list, string or null;
    # callers rely on dict semantics (.get / item assignment).
    if not isinstance(data, dict):
        return {}

    return {
        str(chat_id): mode
        for chat_id, mode in data.items()
        if mode in ("voice_only", "all")
    }
|
||||
|
||||
def _save_voice_modes(self) -> None:
    """Persist ``self._voice_mode`` to ``self._VOICE_MODE_PATH`` as JSON.

    The payload is written to a temporary file in the same directory and
    then moved over the target with ``os.replace``, so a crash or a full
    disk mid-write cannot leave a truncated state file behind.

    Persistence is best-effort: failures are logged as warnings and never
    raised, so a command handler calling this keeps working with the
    in-memory state.
    """
    target = self._VOICE_MODE_PATH
    tmp_name = None
    try:
        target.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(self._voice_mode, indent=2)
        # Same directory as the target so os.replace stays on one filesystem.
        fd, tmp_name = tempfile.mkstemp(
            dir=str(target.parent),
            prefix=target.name + ".",
            suffix=".tmp",
        )
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            handle.write(payload)
        os.replace(tmp_name, target)
        tmp_name = None  # moved into place; nothing left to clean up
    except (OSError, TypeError, ValueError) as e:
        # TypeError/ValueError: json.dumps rejected the in-memory mapping.
        logger.warning("Failed to save voice modes: %s", e)
    finally:
        if tmp_name is not None:
            try:
                os.unlink(tmp_name)
            except OSError:
                pass
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
def _flush_memories_for_session(self, old_session_id: str):
|
||||
"""Prompt the agent to save memories/skills before context is lost.
|
||||
|
||||
|
|
@ -887,7 +914,7 @@ class GatewayRunner:
|
|||
7. Return response
|
||||
"""
|
||||
source = event.source
|
||||
|
||||
|
||||
# Check if user is authorized
|
||||
if not self._is_user_authorized(source):
|
||||
logger.warning("Unauthorized user: %s (%s) on %s", source.user_id, source.user_name, source.platform.value)
|
||||
|
|
@ -939,7 +966,7 @@ class GatewayRunner:
|
|||
"personality", "retry", "undo", "sethome", "set-home",
|
||||
"compress", "usage", "insights", "reload-mcp", "reload_mcp",
|
||||
"update", "title", "resume", "provider", "rollback",
|
||||
"background", "reasoning"}
|
||||
"background", "reasoning", "voice"}
|
||||
if command and command in _known_commands:
|
||||
await self.hooks.emit(f"command:{command}", {
|
||||
"platform": source.platform.value if source.platform else "",
|
||||
|
|
@ -1010,7 +1037,11 @@ class GatewayRunner:
|
|||
|
||||
if command == "reasoning":
|
||||
return await self._handle_reasoning_command(event)
|
||||
|
||||
|
||||
if command == "voice":
|
||||
return await self._handle_voice_command(event)
|
||||
|
||||
|
||||
# User-defined quick commands (bypass agent loop, no LLM call)
|
||||
if command:
|
||||
quick_commands = self.config.get("quick_commands", {})
|
||||
|
|
@ -1568,7 +1599,28 @@ class GatewayRunner:
|
|||
session_entry.session_key,
|
||||
last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
|
||||
)
|
||||
|
||||
|
||||
# Auto voice reply: send TTS audio before the text response
|
||||
chat_id = source.chat_id
|
||||
voice_mode = self._voice_mode.get(chat_id, "off")
|
||||
is_voice_input = (event.message_type == MessageType.VOICE)
|
||||
should_voice_reply = (
|
||||
(voice_mode == "all")
|
||||
or (voice_mode == "voice_only" and is_voice_input)
|
||||
)
|
||||
if should_voice_reply and response and not response.startswith("Error:"):
|
||||
# Skip if agent already called TTS tool (avoid double voice)
|
||||
has_agent_tts = any(
|
||||
msg.get("role") == "assistant"
|
||||
and any(
|
||||
tc.get("function", {}).get("name") == "text_to_speech"
|
||||
for tc in (msg.get("tool_calls") or [])
|
||||
)
|
||||
for msg in agent_messages
|
||||
)
|
||||
if not has_agent_tts:
|
||||
await self._send_voice_reply(event, response)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -1677,6 +1729,7 @@ class GatewayRunner:
|
|||
"`/reasoning [level|show|hide]` — Set reasoning effort or toggle display",
|
||||
"`/rollback [number]` — List or restore filesystem checkpoints",
|
||||
"`/background <prompt>` — Run a prompt in a separate background session",
|
||||
"`/voice [on|off|tts|status]` — Toggle voice reply mode",
|
||||
"`/reload-mcp` — Reload MCP servers from config",
|
||||
"`/update` — Update Hermes Agent to the latest version",
|
||||
"`/help` — Show this message",
|
||||
|
|
@ -2052,6 +2105,93 @@ class GatewayRunner:
|
|||
f"Cron jobs and cross-platform messages will be delivered here."
|
||||
)
|
||||
|
||||
async def _handle_voice_command(self, event: MessageEvent) -> str:
    """Handle the ``/voice [on|off|tts|status]`` command for one chat.

    The mode is stored per ``event.source.chat_id`` in ``self._voice_mode``
    and persisted via ``self._save_voice_modes()`` after every change:

    * ``on`` / ``enable``   -> ``"voice_only"`` (voice reply to voice input)
    * ``tts``               -> ``"all"`` (voice reply to every message)
    * ``off`` / ``disable`` -> entry removed (text only)
    * ``status``            -> report the current mode, no change
    * no argument           -> toggle: off becomes ``"voice_only"``,
      anything else becomes off

    An unrecognised argument leaves the mode untouched and returns a usage
    hint, so a typo such as ``/voice stauts`` does not flip the setting.

    Returns the text to send back to the user.
    """
    args = event.get_command_args().strip().lower()
    chat_id = event.source.chat_id

    if args in ("on", "enable"):
        self._voice_mode[chat_id] = "voice_only"
        self._save_voice_modes()
        return (
            "Voice mode enabled.\n"
            "I'll reply with voice when you send voice messages.\n"
            "Use /voice tts to get voice replies for all messages."
        )

    if args in ("off", "disable"):
        self._voice_mode.pop(chat_id, None)
        self._save_voice_modes()
        return "Voice mode disabled. Text-only replies."

    if args == "tts":
        self._voice_mode[chat_id] = "all"
        self._save_voice_modes()
        return (
            "Auto-TTS enabled.\n"
            "All replies will include a voice message."
        )

    if args == "status":
        mode = self._voice_mode.get(chat_id, "off")
        labels = {
            "off": "Off (text only)",
            "voice_only": "On (voice reply to voice messages)",
            "all": "TTS (voice reply to all messages)",
        }
        return f"Voice mode: {labels.get(mode, mode)}"

    if args:
        # Anything else is a typo or unsupported option: do not guess.
        return "Unknown option. Usage: /voice [on|off|tts|status]"

    # Bare "/voice" toggles: off -> on, on/all -> off.
    if self._voice_mode.get(chat_id, "off") == "off":
        self._voice_mode[chat_id] = "voice_only"
        self._save_voice_modes()
        return "Voice mode enabled."

    self._voice_mode.pop(chat_id, None)
    self._save_voice_modes()
    return "Voice mode disabled."
|
||||
|
||||
async def _send_voice_reply(self, event: MessageEvent, text: str) -> None:
    """Generate TTS audio for *text* and send it as a voice message.

    Called before the text reply is returned. The first 4000 characters
    of *text* are stripped of markdown and synthesised to a temporary
    ``.ogg`` file, which is sent through the platform adapter's
    ``send_voice`` as a reply to ``event.message_id`` (in the same thread
    when ``event.source.thread_id`` is set).

    This is best-effort: every failure is logged as a warning and
    swallowed so the text reply is still delivered. The temporary audio
    file is removed on every exit path, including TTS failure and a
    raising ``send_voice``.
    """
    import uuid

    ogg_path = None
    try:
        from tools.tts_tool import text_to_speech_tool, _strip_markdown_for_tts

        tts_text = _strip_markdown_for_tts(text[:4000])
        if not tts_text:
            return

        # Resolve the adapter first: no point synthesising audio that
        # cannot be delivered on this platform.
        adapter = self.adapters.get(event.source.platform)
        if not adapter or not hasattr(adapter, "send_voice"):
            return

        # A random component avoids two replies generated in the same
        # second overwriting each other's audio file.
        ogg_path = os.path.join(
            tempfile.gettempdir(), "hermes_voice",
            f"tts_reply_{uuid.uuid4().hex}.ogg",
        )
        os.makedirs(os.path.dirname(ogg_path), exist_ok=True)

        # The TTS tool is blocking; keep it off the event loop.
        result_json = await asyncio.to_thread(
            text_to_speech_tool, text=tts_text, output_path=ogg_path
        )
        result = json.loads(result_json)

        if not result.get("success") or not os.path.isfile(ogg_path):
            logger.warning("Auto voice reply TTS failed: %s", result.get("error"))
            return

        _thread_md = (
            {"thread_id": event.source.thread_id}
            if event.source.thread_id else None
        )
        await adapter.send_voice(
            event.source.chat_id,
            audio_path=ogg_path,
            reply_to=event.message_id,
            metadata=_thread_md,
        )
    except Exception as e:
        logger.warning("Auto voice reply failed: %s", e)
    finally:
        if ogg_path is not None:
            try:
                os.unlink(ogg_path)
            except OSError:
                # Already gone or never created; nothing to clean up.
                pass
|
||||
|
||||
async def _handle_rollback_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /rollback command — list or restore filesystem checkpoints."""
|
||||
from tools.checkpoint_manager import CheckpointManager, format_checkpoint_list
|
||||
|
|
|
|||
285
tests/gateway/test_voice_command.py
Normal file
285
tests/gateway/test_voice_command.py
Normal file
|
|
@ -0,0 +1,285 @@
|
|||
"""Tests for the /voice command and auto voice reply in the gateway."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from gateway.platforms.base import MessageEvent, MessageType, SessionSource
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_event(text: str = "", message_type=MessageType.TEXT, chat_id="123") -> MessageEvent:
    """Build a MessageEvent that looks like it came from Telegram.

    The source uses a MagicMock platform whose ``value`` is "telegram",
    user "user1", no thread, and the event carries message id "msg42".
    """
    origin = SessionSource(chat_id=chat_id, user_id="user1", platform=MagicMock())
    origin.platform.value = "telegram"
    origin.thread_id = None

    message = MessageEvent(text=text, message_type=message_type, source=origin)
    message.message_id = "msg42"
    return message
|
||||
|
||||
|
||||
def _make_runner(tmp_path):
    """Return a GatewayRunner instance built without running ``__init__``.

    Only the attributes set here are populated; the voice-mode file is
    redirected into *tmp_path*.
    """
    from gateway.run import GatewayRunner

    bare = object.__new__(GatewayRunner)
    initial_state = {
        "adapters": {},
        "_voice_mode": {},
        "_VOICE_MODE_PATH": tmp_path / "gateway_voice_mode.json",
        "_session_db": None,
        "session_store": MagicMock(),
    }
    for attr_name, attr_value in initial_state.items():
        setattr(bare, attr_name, attr_value)
    return bare
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# /voice command handler
|
||||
# =====================================================================
|
||||
|
||||
class TestHandleVoiceCommand:
    """Tests for GatewayRunner._handle_voice_command and its persistence."""

    @pytest.fixture
    def runner(self, tmp_path):
        return _make_runner(tmp_path)

    @pytest.mark.asyncio
    async def test_voice_on(self, runner):
        reply = await runner._handle_voice_command(_make_event("/voice on"))
        assert "enabled" in reply.lower()
        assert runner._voice_mode["123"] == "voice_only"

    @pytest.mark.asyncio
    async def test_voice_off(self, runner):
        runner._voice_mode["123"] = "voice_only"
        reply = await runner._handle_voice_command(_make_event("/voice off"))
        assert "disabled" in reply.lower()
        assert "123" not in runner._voice_mode

    @pytest.mark.asyncio
    async def test_voice_tts(self, runner):
        reply = await runner._handle_voice_command(_make_event("/voice tts"))
        assert "tts" in reply.lower()
        assert runner._voice_mode["123"] == "all"

    @pytest.mark.asyncio
    async def test_voice_status_off(self, runner):
        reply = await runner._handle_voice_command(_make_event("/voice status"))
        assert "off" in reply.lower()

    @pytest.mark.asyncio
    async def test_voice_status_on(self, runner):
        runner._voice_mode["123"] = "voice_only"
        reply = await runner._handle_voice_command(_make_event("/voice status"))
        assert "voice reply" in reply.lower()

    @pytest.mark.asyncio
    async def test_toggle_off_to_on(self, runner):
        # Bare "/voice" with no stored mode switches voice replies on.
        reply = await runner._handle_voice_command(_make_event("/voice"))
        assert "enabled" in reply.lower()
        assert runner._voice_mode["123"] == "voice_only"

    @pytest.mark.asyncio
    async def test_toggle_on_to_off(self, runner):
        # Bare "/voice" with a stored mode switches voice replies off.
        runner._voice_mode["123"] = "voice_only"
        reply = await runner._handle_voice_command(_make_event("/voice"))
        assert "disabled" in reply.lower()
        assert "123" not in runner._voice_mode

    @pytest.mark.asyncio
    async def test_persistence_saved(self, runner):
        await runner._handle_voice_command(_make_event("/voice on"))
        state_file = runner._VOICE_MODE_PATH
        assert state_file.exists()
        stored = json.loads(state_file.read_text())
        assert stored["123"] == "voice_only"

    @pytest.mark.asyncio
    async def test_persistence_loaded(self, runner):
        runner._VOICE_MODE_PATH.write_text(json.dumps({"456": "all"}))
        assert runner._load_voice_modes() == {"456": "all"}

    @pytest.mark.asyncio
    async def test_per_chat_isolation(self, runner):
        first = _make_event("/voice on", chat_id="aaa")
        second = _make_event("/voice tts", chat_id="bbb")
        for incoming in (first, second):
            await runner._handle_voice_command(incoming)
        assert runner._voice_mode["aaa"] == "voice_only"
        assert runner._voice_mode["bbb"] == "all"
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Auto voice reply decision logic
|
||||
# =====================================================================
|
||||
|
||||
class TestAutoVoiceReply:
    """Test the should_voice_reply decision logic (extracted from _handle_message)."""

    @staticmethod
    def _assistant_tool_call(tool_name):
        """Return a one-message history where the assistant calls *tool_name*."""
        return [{
            "role": "assistant",
            "tool_calls": [{
                "id": "call_1",
                "type": "function",
                "function": {"name": tool_name, "arguments": "{}"},
            }],
        }]

    def _should_reply(self, voice_mode, message_type, agent_messages=None, response="Hello!"):
        """Replicate the auto voice reply decision from _handle_message."""
        # Empty and error responses never get a voice reply.
        if not response or response.startswith("Error:"):
            return False

        if voice_mode == "all":
            wanted = True
        elif voice_mode == "voice_only":
            wanted = message_type == MessageType.VOICE
        else:
            wanted = False
        if not wanted:
            return False

        # Dedup check: the agent already produced audio via the TTS tool.
        for entry in agent_messages or []:
            if entry.get("role") != "assistant":
                continue
            for call in entry.get("tool_calls") or []:
                if call.get("function", {}).get("name") == "text_to_speech":
                    return False

        return True

    def test_voice_only_voice_input(self):
        assert self._should_reply("voice_only", MessageType.VOICE) is True

    def test_voice_only_text_input(self):
        assert self._should_reply("voice_only", MessageType.TEXT) is False

    def test_all_mode_text_input(self):
        assert self._should_reply("all", MessageType.TEXT) is True

    def test_all_mode_voice_input(self):
        assert self._should_reply("all", MessageType.VOICE) is True

    def test_off_mode(self):
        assert self._should_reply("off", MessageType.VOICE) is False
        assert self._should_reply("off", MessageType.TEXT) is False

    def test_error_response_skipped(self):
        assert self._should_reply("all", MessageType.TEXT, response="Error: boom") is False

    def test_empty_response_skipped(self):
        assert self._should_reply("all", MessageType.TEXT, response="") is False

    def test_dedup_skips_when_agent_called_tts(self):
        history = self._assistant_tool_call("text_to_speech")
        assert self._should_reply("all", MessageType.TEXT, agent_messages=history) is False

    def test_no_dedup_for_other_tools(self):
        history = self._assistant_tool_call("web_search")
        assert self._should_reply("all", MessageType.TEXT, agent_messages=history) is True
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# _send_voice_reply
|
||||
# =====================================================================
|
||||
|
||||
class TestSendVoiceReply:
    """Tests for GatewayRunner._send_voice_reply with the TTS tool patched out."""

    @pytest.fixture
    def runner(self, tmp_path):
        return _make_runner(tmp_path)

    @pytest.mark.asyncio
    async def test_calls_tts_and_send_voice(self, runner):
        event = _make_event()
        adapter = AsyncMock()
        adapter.send_voice = AsyncMock()
        runner.adapters[event.source.platform] = adapter

        tts_ok = json.dumps({"success": True, "file_path": "/tmp/test.ogg"})

        with patch("tools.tts_tool.text_to_speech_tool", return_value=tts_ok), \
             patch("tools.tts_tool._strip_markdown_for_tts", side_effect=lambda t: t), \
             patch("os.path.isfile", return_value=True), \
             patch("os.unlink"), \
             patch("os.makedirs"):
            await runner._send_voice_reply(event, "Hello world")

        adapter.send_voice.assert_called_once()
        positional, _keywords = adapter.send_voice.call_args
        assert positional[0] == "123"  # chat_id

    @pytest.mark.asyncio
    async def test_empty_text_after_strip_skips(self, runner):
        event = _make_event()

        with patch("tools.tts_tool.text_to_speech_tool") as tts_spy, \
             patch("tools.tts_tool._strip_markdown_for_tts", return_value=""):
            await runner._send_voice_reply(event, "```code only```")

        tts_spy.assert_not_called()

    @pytest.mark.asyncio
    async def test_tts_failure_no_crash(self, runner):
        event = _make_event()
        adapter = AsyncMock()
        runner.adapters[event.source.platform] = adapter
        tts_failed = json.dumps({"success": False, "error": "API error"})

        with patch("tools.tts_tool.text_to_speech_tool", return_value=tts_failed), \
             patch("tools.tts_tool._strip_markdown_for_tts", side_effect=lambda t: t), \
             patch("os.path.isfile", return_value=False), \
             patch("os.makedirs"):
            await runner._send_voice_reply(event, "Hello")

        adapter.send_voice.assert_not_called()

    @pytest.mark.asyncio
    async def test_exception_caught(self, runner):
        event = _make_event()
        with patch("tools.tts_tool.text_to_speech_tool", side_effect=RuntimeError("boom")), \
             patch("tools.tts_tool._strip_markdown_for_tts", side_effect=lambda t: t), \
             patch("os.makedirs"):
            # Should not raise
            await runner._send_voice_reply(event, "Hello")
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Help text + known commands
|
||||
# =====================================================================
|
||||
|
||||
class TestVoiceInHelp:
    """Source-level checks that /voice appears in the help and dispatch code."""

    def test_voice_in_help_output(self):
        import inspect
        from gateway.run import GatewayRunner

        help_source = inspect.getsource(GatewayRunner._handle_help_command)
        assert "/voice" in help_source

    def test_voice_is_known_command(self):
        import inspect
        from gateway.run import GatewayRunner

        dispatch_source = inspect.getsource(GatewayRunner._handle_message)
        assert '"voice"' in dispatch_source
|
||||
Loading…
Add table
Add a link
Reference in a new issue