fix: persist clean voice transcripts and /voice off state
- keep CLI voice prefixes API-local while storing the original user text
- persist explicit gateway off state and restore adapter auto-TTS suppression on restart
- add regression coverage for both behaviors
This commit is contained in:
parent
523a1b6faf
commit
7b10881b9e
5 changed files with 192 additions and 29 deletions
13
cli.py
13
cli.py
|
|
@ -4218,9 +4218,8 @@ class HermesCLI:
|
|||
text_queue.put(delta)
|
||||
|
||||
# When voice mode is active, prepend a brief instruction so the
|
||||
# model responds concisely. The prefix is API-call-local only —
|
||||
# we strip it from the returned history so it never persists to
|
||||
# session DB or resumed sessions.
|
||||
# model responds concisely. The prefix is API-call-local only —
|
||||
# run_conversation persists the original clean user message.
|
||||
_voice_prefix = ""
|
||||
if self._voice_mode and isinstance(message, str):
|
||||
_voice_prefix = (
|
||||
|
|
@ -4236,6 +4235,7 @@ class HermesCLI:
|
|||
conversation_history=self.conversation_history[:-1], # Exclude the message we just added
|
||||
stream_callback=stream_callback,
|
||||
task_id=self.session_id,
|
||||
persist_user_message=message if _voice_prefix else None,
|
||||
)
|
||||
|
||||
# Start agent in background thread
|
||||
|
|
@ -4302,13 +4302,6 @@ class HermesCLI:
|
|||
# Update history with full conversation
|
||||
self.conversation_history = result.get("messages", self.conversation_history) if result else self.conversation_history
|
||||
|
||||
# Strip voice prefix from history so it never persists
|
||||
if _voice_prefix and self.conversation_history:
|
||||
for msg in self.conversation_history:
|
||||
if msg.get("role") == "user" and isinstance(msg.get("content"), str):
|
||||
if msg["content"].startswith(_voice_prefix):
|
||||
msg["content"] = msg["content"][len(_voice_prefix):]
|
||||
|
||||
# Get the final response
|
||||
response = result.get("final_response", "") if result else ""
|
||||
|
||||
|
|
|
|||
|
|
@ -348,10 +348,20 @@ class GatewayRunner:
|
|||
|
||||
def _load_voice_modes(self) -> Dict[str, str]:
|
||||
try:
|
||||
return json.loads(self._VOICE_MODE_PATH.read_text())
|
||||
data = json.loads(self._VOICE_MODE_PATH.read_text())
|
||||
except (FileNotFoundError, json.JSONDecodeError, OSError):
|
||||
return {}
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return {}
|
||||
|
||||
valid_modes = {"off", "voice_only", "all"}
|
||||
return {
|
||||
str(chat_id): mode
|
||||
for chat_id, mode in data.items()
|
||||
if mode in valid_modes
|
||||
}
|
||||
|
||||
def _save_voice_modes(self) -> None:
|
||||
try:
|
||||
self._VOICE_MODE_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -361,6 +371,26 @@ class GatewayRunner:
|
|||
except OSError as e:
|
||||
logger.warning("Failed to save voice modes: %s", e)
|
||||
|
||||
def _set_adapter_auto_tts_disabled(self, adapter, chat_id: str, disabled: bool) -> None:
|
||||
"""Update an adapter's in-memory auto-TTS suppression set if present."""
|
||||
disabled_chats = getattr(adapter, "_auto_tts_disabled_chats", None)
|
||||
if not isinstance(disabled_chats, set):
|
||||
return
|
||||
if disabled:
|
||||
disabled_chats.add(chat_id)
|
||||
else:
|
||||
disabled_chats.discard(chat_id)
|
||||
|
||||
def _sync_voice_mode_state_to_adapter(self, adapter) -> None:
|
||||
"""Restore persisted /voice off state into a live platform adapter."""
|
||||
disabled_chats = getattr(adapter, "_auto_tts_disabled_chats", None)
|
||||
if not isinstance(disabled_chats, set):
|
||||
return
|
||||
disabled_chats.clear()
|
||||
disabled_chats.update(
|
||||
chat_id for chat_id, mode in self._voice_mode.items() if mode == "off"
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
def _flush_memories_for_session(self, old_session_id: str):
|
||||
|
|
@ -666,6 +696,7 @@ class GatewayRunner:
|
|||
success = await adapter.connect()
|
||||
if success:
|
||||
self.adapters[platform] = adapter
|
||||
self._sync_voice_mode_state_to_adapter(adapter)
|
||||
connected_count += 1
|
||||
logger.info("✓ %s connected", platform.value)
|
||||
else:
|
||||
|
|
@ -2140,23 +2171,23 @@ class GatewayRunner:
|
|||
self._voice_mode[chat_id] = "voice_only"
|
||||
self._save_voice_modes()
|
||||
if adapter:
|
||||
adapter._auto_tts_disabled_chats.discard(chat_id)
|
||||
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
|
||||
return (
|
||||
"Voice mode enabled.\n"
|
||||
"I'll reply with voice when you send voice messages.\n"
|
||||
"Use /voice tts to get voice replies for all messages."
|
||||
)
|
||||
elif args in ("off", "disable"):
|
||||
self._voice_mode.pop(chat_id, None)
|
||||
self._voice_mode[chat_id] = "off"
|
||||
self._save_voice_modes()
|
||||
if adapter:
|
||||
adapter._auto_tts_disabled_chats.add(chat_id)
|
||||
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
|
||||
return "Voice mode disabled. Text-only replies."
|
||||
elif args == "tts":
|
||||
self._voice_mode[chat_id] = "all"
|
||||
self._save_voice_modes()
|
||||
if adapter:
|
||||
adapter._auto_tts_disabled_chats.discard(chat_id)
|
||||
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
|
||||
return (
|
||||
"Auto-TTS enabled.\n"
|
||||
"All replies will include a voice message."
|
||||
|
|
@ -2195,13 +2226,13 @@ class GatewayRunner:
|
|||
self._voice_mode[chat_id] = "voice_only"
|
||||
self._save_voice_modes()
|
||||
if adapter:
|
||||
adapter._auto_tts_disabled_chats.discard(chat_id)
|
||||
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
|
||||
return "Voice mode enabled."
|
||||
else:
|
||||
self._voice_mode.pop(chat_id, None)
|
||||
self._voice_mode[chat_id] = "off"
|
||||
self._save_voice_modes()
|
||||
if adapter:
|
||||
adapter._auto_tts_disabled_chats.add(chat_id)
|
||||
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
|
||||
return "Voice mode disabled."
|
||||
|
||||
async def _handle_voice_channel_join(self, event: MessageEvent) -> str:
|
||||
|
|
@ -2238,7 +2269,7 @@ class GatewayRunner:
|
|||
adapter._voice_text_channels[guild_id] = int(event.source.chat_id)
|
||||
self._voice_mode[event.source.chat_id] = "all"
|
||||
self._save_voice_modes()
|
||||
adapter._auto_tts_disabled_chats.discard(event.source.chat_id)
|
||||
self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=False)
|
||||
return (
|
||||
f"Joined voice channel **{voice_channel.name}**.\n"
|
||||
f"I'll speak my replies and listen to you. Use /voice leave to disconnect."
|
||||
|
|
@ -2263,8 +2294,9 @@ class GatewayRunner:
|
|||
except Exception as e:
|
||||
logger.warning("Error leaving voice channel: %s", e)
|
||||
# Always clean up state even if leave raised an exception
|
||||
self._voice_mode.pop(event.source.chat_id, None)
|
||||
self._voice_mode[event.source.chat_id] = "off"
|
||||
self._save_voice_modes()
|
||||
self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=True)
|
||||
if hasattr(adapter, "_voice_input_callback"):
|
||||
adapter._voice_input_callback = None
|
||||
return "Left voice channel."
|
||||
|
|
@ -2274,8 +2306,10 @@ class GatewayRunner:
|
|||
|
||||
Cleans up runner-side voice_mode state that the adapter cannot reach.
|
||||
"""
|
||||
self._voice_mode.pop(chat_id, None)
|
||||
self._voice_mode[chat_id] = "off"
|
||||
self._save_voice_modes()
|
||||
adapter = self.adapters.get(Platform.DISCORD)
|
||||
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
|
||||
|
||||
async def _handle_voice_channel_input(
|
||||
self, guild_id: int, user_id: int, transcript: str
|
||||
|
|
|
|||
37
run_agent.py
37
run_agent.py
|
|
@ -497,6 +497,12 @@ class AIAgent:
|
|||
# Initialized here so _vprint can reference it before run_conversation.
|
||||
self._stream_callback = None
|
||||
|
||||
# Optional current-turn user-message override used when the API-facing
|
||||
# user message intentionally differs from the persisted transcript
|
||||
# (e.g. CLI voice mode adds a temporary prefix for the live call only).
|
||||
self._persist_user_message_idx = None
|
||||
self._persist_user_message_override = None
|
||||
|
||||
# Initialize LLM client via centralized provider router.
|
||||
# The router handles auth resolution, base URL, headers, and
|
||||
# Codex/Anthropic wrapping for all known providers.
|
||||
|
|
@ -998,11 +1004,30 @@ class AIAgent:
|
|||
if self.verbose_logging:
|
||||
logging.warning(f"Failed to cleanup browser for task {task_id}: {e}")
|
||||
|
||||
def _apply_persist_user_message_override(self, messages: List[Dict]) -> None:
|
||||
"""Rewrite the current-turn user message before persistence/return.
|
||||
|
||||
Some call paths need an API-only user-message variant without letting
|
||||
that synthetic text leak into persisted transcripts or resumed session
|
||||
history. When an override is configured for the active turn, mutate the
|
||||
in-memory messages list in place so both persistence and returned
|
||||
history stay clean.
|
||||
"""
|
||||
idx = getattr(self, "_persist_user_message_idx", None)
|
||||
override = getattr(self, "_persist_user_message_override", None)
|
||||
if override is None or idx is None:
|
||||
return
|
||||
if 0 <= idx < len(messages):
|
||||
msg = messages[idx]
|
||||
if isinstance(msg, dict) and msg.get("role") == "user":
|
||||
msg["content"] = override
|
||||
|
||||
def _persist_session(self, messages: List[Dict], conversation_history: List[Dict] = None):
|
||||
"""Save session state to both JSON log and SQLite on any exit path.
|
||||
|
||||
Ensures conversations are never lost, even on errors or early returns.
|
||||
"""
|
||||
self._apply_persist_user_message_override(messages)
|
||||
self._session_messages = messages
|
||||
self._save_session_log(messages)
|
||||
self._flush_messages_to_session_db(messages, conversation_history)
|
||||
|
|
@ -1016,6 +1041,7 @@ class AIAgent:
|
|||
"""
|
||||
if not self._session_db:
|
||||
return
|
||||
self._apply_persist_user_message_override(messages)
|
||||
try:
|
||||
start_idx = len(conversation_history) if conversation_history else 0
|
||||
flush_from = max(start_idx, self._last_flushed_db_idx)
|
||||
|
|
@ -4065,6 +4091,7 @@ class AIAgent:
|
|||
conversation_history: List[Dict[str, Any]] = None,
|
||||
task_id: str = None,
|
||||
stream_callback: Optional[callable] = None,
|
||||
persist_user_message: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run a complete conversation with tool calling until completion.
|
||||
|
|
@ -4077,6 +4104,9 @@ class AIAgent:
|
|||
stream_callback: Optional callback invoked with each text delta during streaming.
|
||||
Used by the TTS pipeline to start audio generation before the full response.
|
||||
When None (default), API calls use the standard non-streaming path.
|
||||
persist_user_message: Optional clean user message to store in
|
||||
transcripts/history when user_message contains API-only
|
||||
synthetic prefixes.
|
||||
|
||||
Returns:
|
||||
Dict: Complete conversation result with final response and message history
|
||||
|
|
@ -4087,6 +4117,8 @@ class AIAgent:
|
|||
|
||||
# Store stream callback for _interruptible_api_call to pick up
|
||||
self._stream_callback = stream_callback
|
||||
self._persist_user_message_idx = None
|
||||
self._persist_user_message_override = persist_user_message
|
||||
# Generate unique task_id if not provided to isolate VMs between concurrent tasks
|
||||
effective_task_id = task_id or str(uuid.uuid4())
|
||||
|
||||
|
|
@ -4121,7 +4153,7 @@ class AIAgent:
|
|||
|
||||
# Preserve the original user message before nudge injection.
|
||||
# Honcho should receive the actual user input, not system nudges.
|
||||
original_user_message = user_message
|
||||
original_user_message = persist_user_message if persist_user_message is not None else user_message
|
||||
|
||||
# Periodic memory nudge: remind the model to consider saving memories.
|
||||
# Counter resets whenever the memory tool is actually used.
|
||||
|
|
@ -4159,7 +4191,7 @@ class AIAgent:
|
|||
_recall_mode = (self._honcho_config.recall_mode if self._honcho_config else "hybrid")
|
||||
if self._honcho and self._honcho_session_key and _recall_mode != "tools":
|
||||
try:
|
||||
prefetched_context = self._honcho_prefetch(user_message)
|
||||
prefetched_context = self._honcho_prefetch(original_user_message)
|
||||
if prefetched_context:
|
||||
if not conversation_history:
|
||||
self._honcho_context = prefetched_context
|
||||
|
|
@ -4172,6 +4204,7 @@ class AIAgent:
|
|||
user_msg = {"role": "user", "content": user_message}
|
||||
messages.append(user_msg)
|
||||
current_turn_user_idx = len(messages) - 1
|
||||
self._persist_user_message_idx = current_turn_user_idx
|
||||
|
||||
if not self.quiet_mode:
|
||||
print(f"💬 Starting conversation: '{user_message[:60]}{'...' if len(user_message) > 60 else ''}'")
|
||||
|
|
|
|||
|
|
@ -3,12 +3,53 @@
|
|||
import json
|
||||
import os
|
||||
import queue
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
def _ensure_discord_mock():
|
||||
"""Install a lightweight discord mock when discord.py isn't available."""
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
return
|
||||
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.Client = MagicMock
|
||||
discord_mod.File = MagicMock
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
|
||||
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3)
|
||||
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4)
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.Embed = MagicMock
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
discord_mod.opus = SimpleNamespace(is_loaded=lambda: True, load_opus=lambda *_args, **_kwargs: None)
|
||||
discord_mod.FFmpegPCMAudio = MagicMock
|
||||
discord_mod.PCMVolumeTransformer = MagicMock
|
||||
discord_mod.http = SimpleNamespace(Route=MagicMock)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
|
||||
sys.modules.setdefault("discord", discord_mod)
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
from gateway.platforms.base import MessageEvent, MessageType, SessionSource
|
||||
|
||||
|
||||
|
|
@ -65,7 +106,7 @@ class TestHandleVoiceCommand:
|
|||
event = _make_event("/voice off")
|
||||
result = await runner._handle_voice_command(event)
|
||||
assert "disabled" in result.lower()
|
||||
assert "123" not in runner._voice_mode
|
||||
assert runner._voice_mode["123"] == "off"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_tts(self, runner):
|
||||
|
|
@ -100,7 +141,7 @@ class TestHandleVoiceCommand:
|
|||
event = _make_event("/voice")
|
||||
result = await runner._handle_voice_command(event)
|
||||
assert "disabled" in result.lower()
|
||||
assert "123" not in runner._voice_mode
|
||||
assert runner._voice_mode["123"] == "off"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persistence_saved(self, runner):
|
||||
|
|
@ -116,6 +157,33 @@ class TestHandleVoiceCommand:
|
|||
loaded = runner._load_voice_modes()
|
||||
assert loaded == {"456": "all"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persistence_saved_for_off(self, runner):
|
||||
event = _make_event("/voice off")
|
||||
await runner._handle_voice_command(event)
|
||||
data = json.loads(runner._VOICE_MODE_PATH.read_text())
|
||||
assert data["123"] == "off"
|
||||
|
||||
def test_sync_voice_mode_state_to_adapter_restores_off_chats(self, runner):
|
||||
runner._voice_mode = {"123": "off", "456": "all"}
|
||||
adapter = SimpleNamespace(_auto_tts_disabled_chats=set())
|
||||
|
||||
runner._sync_voice_mode_state_to_adapter(adapter)
|
||||
|
||||
assert adapter._auto_tts_disabled_chats == {"123"}
|
||||
|
||||
def test_restart_restores_voice_off_state(self, runner, tmp_path):
|
||||
runner._VOICE_MODE_PATH.write_text(json.dumps({"123": "off"}))
|
||||
|
||||
restored_runner = _make_runner(tmp_path)
|
||||
restored_runner._voice_mode = restored_runner._load_voice_modes()
|
||||
adapter = SimpleNamespace(_auto_tts_disabled_chats=set())
|
||||
|
||||
restored_runner._sync_voice_mode_state_to_adapter(adapter)
|
||||
|
||||
assert restored_runner._voice_mode["123"] == "off"
|
||||
assert adapter._auto_tts_disabled_chats == {"123"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_per_chat_isolation(self, runner):
|
||||
e1 = _make_event("/voice on", chat_id="aaa")
|
||||
|
|
@ -693,7 +761,7 @@ class TestVoiceChannelCommands:
|
|||
runner._voice_mode["123"] = "all"
|
||||
result = await runner._handle_voice_channel_leave(event)
|
||||
assert "left" in result.lower()
|
||||
assert "123" not in runner._voice_mode
|
||||
assert runner._voice_mode["123"] == "off"
|
||||
mock_adapter.leave_voice_channel.assert_called_once_with(111)
|
||||
|
||||
# -- _handle_voice_channel_input --
|
||||
|
|
@ -1163,7 +1231,7 @@ class TestLeaveExceptionHandling:
|
|||
|
||||
result = await runner._handle_voice_channel_leave(event)
|
||||
assert "left" in result.lower()
|
||||
assert "123" not in runner._voice_mode
|
||||
assert runner._voice_mode["123"] == "off"
|
||||
assert mock_adapter._voice_input_callback is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -1626,8 +1694,8 @@ class TestVoiceTimeoutCleansRunnerState:
|
|||
|
||||
runner._handle_voice_timeout_cleanup("999")
|
||||
|
||||
assert "999" not in runner._voice_mode, \
|
||||
"voice_mode must be removed after timeout cleanup"
|
||||
assert runner._voice_mode["999"] == "off", \
|
||||
"voice_mode must persist explicit off state after timeout cleanup"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_without_callback_does_not_crash(self, adapter):
|
||||
|
|
|
|||
|
|
@ -2383,6 +2383,41 @@ class TestStreamCallbackNonStreamingProvider:
|
|||
assert received == ["Hello from Claude"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bugfix: API-only user message prefixes must not persist
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPersistUserMessageOverride:
|
||||
"""Synthetic API-only user prefixes should never leak into transcripts."""
|
||||
|
||||
def test_persist_session_rewrites_current_turn_user_message(self, agent):
|
||||
agent._session_db = MagicMock()
|
||||
agent.session_id = "session-123"
|
||||
agent._last_flushed_db_idx = 0
|
||||
agent._persist_user_message_idx = 0
|
||||
agent._persist_user_message_override = "Hello there"
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"[Voice input — respond concisely and conversationally, "
|
||||
"2-3 sentences max. No code blocks or markdown.] Hello there"
|
||||
),
|
||||
},
|
||||
{"role": "assistant", "content": "Hi!"},
|
||||
]
|
||||
|
||||
with patch.object(agent, "_save_session_log") as mock_save:
|
||||
agent._persist_session(messages, [])
|
||||
|
||||
assert messages[0]["content"] == "Hello there"
|
||||
saved_messages = mock_save.call_args.args[0]
|
||||
assert saved_messages[0]["content"] == "Hello there"
|
||||
first_db_write = agent._session_db.append_message.call_args_list[0].kwargs
|
||||
assert first_db_write["content"] == "Hello there"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bugfix: _vprint force=True on error messages during TTS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue