fix: persist clean voice transcripts and /voice off state

- keep CLI voice prefixes API-local while storing the original user text
- persist explicit gateway off state and restore adapter auto-TTS suppression on restart
- add regression coverage for both behaviors
This commit is contained in:
teknium1 2026-03-14 06:14:22 -07:00
parent 523a1b6faf
commit 7b10881b9e
5 changed files with 192 additions and 29 deletions

13
cli.py
View file

@ -4218,9 +4218,8 @@ class HermesCLI:
text_queue.put(delta) text_queue.put(delta)
# When voice mode is active, prepend a brief instruction so the # When voice mode is active, prepend a brief instruction so the
# model responds concisely. The prefix is API-call-local only — # model responds concisely. The prefix is API-call-local only —
# we strip it from the returned history so it never persists to # run_conversation persists the original clean user message.
# session DB or resumed sessions.
_voice_prefix = "" _voice_prefix = ""
if self._voice_mode and isinstance(message, str): if self._voice_mode and isinstance(message, str):
_voice_prefix = ( _voice_prefix = (
@ -4236,6 +4235,7 @@ class HermesCLI:
conversation_history=self.conversation_history[:-1], # Exclude the message we just added conversation_history=self.conversation_history[:-1], # Exclude the message we just added
stream_callback=stream_callback, stream_callback=stream_callback,
task_id=self.session_id, task_id=self.session_id,
persist_user_message=message if _voice_prefix else None,
) )
# Start agent in background thread # Start agent in background thread
@ -4302,13 +4302,6 @@ class HermesCLI:
# Update history with full conversation # Update history with full conversation
self.conversation_history = result.get("messages", self.conversation_history) if result else self.conversation_history self.conversation_history = result.get("messages", self.conversation_history) if result else self.conversation_history
# Strip voice prefix from history so it never persists
if _voice_prefix and self.conversation_history:
for msg in self.conversation_history:
if msg.get("role") == "user" and isinstance(msg.get("content"), str):
if msg["content"].startswith(_voice_prefix):
msg["content"] = msg["content"][len(_voice_prefix):]
# Get the final response # Get the final response
response = result.get("final_response", "") if result else "" response = result.get("final_response", "") if result else ""

View file

@ -348,10 +348,20 @@ class GatewayRunner:
def _load_voice_modes(self) -> Dict[str, str]:
    """Load the persisted per-chat voice modes from ``_VOICE_MODE_PATH``.

    Returns:
        A mapping of chat id (always ``str``) to one of ``"off"``,
        ``"voice_only"`` or ``"all"``. A missing, unreadable, undecodable or
        malformed file yields an empty mapping, and individual entries whose
        mode is unknown or not a string are dropped.
    """
    try:
        data = json.loads(self._VOICE_MODE_PATH.read_text())
    except (OSError, ValueError):
        # OSError covers a missing/unreadable file (FileNotFoundError is a
        # subclass). ValueError covers json.JSONDecodeError as well as the
        # UnicodeDecodeError that read_text() raises on non-text bytes.
        return {}
    if not isinstance(data, dict):
        return {}
    valid_modes = {"off", "voice_only", "all"}
    return {
        str(chat_id): mode
        for chat_id, mode in data.items()
        # The isinstance guard must run first: a nested list/dict value is
        # unhashable and would raise TypeError in the set membership test.
        if isinstance(mode, str) and mode in valid_modes
    }
def _save_voice_modes(self) -> None: def _save_voice_modes(self) -> None:
try: try:
self._VOICE_MODE_PATH.parent.mkdir(parents=True, exist_ok=True) self._VOICE_MODE_PATH.parent.mkdir(parents=True, exist_ok=True)
@ -361,6 +371,26 @@ class GatewayRunner:
except OSError as e: except OSError as e:
logger.warning("Failed to save voice modes: %s", e) logger.warning("Failed to save voice modes: %s", e)
def _set_adapter_auto_tts_disabled(self, adapter, chat_id: str, disabled: bool) -> None:
    """Add or remove ``chat_id`` from the adapter's auto-TTS suppression set.

    Does nothing when ``adapter`` has no ``_auto_tts_disabled_chats``
    attribute or when that attribute is not a ``set``.
    """
    suppressed = getattr(adapter, "_auto_tts_disabled_chats", None)
    if isinstance(suppressed, set):
        # Pick the set operation once, then apply it to the chat id.
        toggle = suppressed.add if disabled else suppressed.discard
        toggle(chat_id)
def _sync_voice_mode_state_to_adapter(self, adapter) -> None:
    """Rebuild the adapter's auto-TTS suppression set from ``self._voice_mode``.

    The adapter's ``_auto_tts_disabled_chats`` set is emptied in place and
    refilled with every chat id whose stored mode is ``"off"``. Adapters
    without that attribute, or whose attribute is not a ``set``, are left
    untouched.
    """
    suppressed = getattr(adapter, "_auto_tts_disabled_chats", None)
    if isinstance(suppressed, set):
        # Mutate the existing set so other references to it see the result.
        suppressed.clear()
        for chat_id, mode in self._voice_mode.items():
            if mode == "off":
                suppressed.add(chat_id)
# ----------------------------------------------------------------- # -----------------------------------------------------------------
def _flush_memories_for_session(self, old_session_id: str): def _flush_memories_for_session(self, old_session_id: str):
@ -666,6 +696,7 @@ class GatewayRunner:
success = await adapter.connect() success = await adapter.connect()
if success: if success:
self.adapters[platform] = adapter self.adapters[platform] = adapter
self._sync_voice_mode_state_to_adapter(adapter)
connected_count += 1 connected_count += 1
logger.info("%s connected", platform.value) logger.info("%s connected", platform.value)
else: else:
@ -2140,23 +2171,23 @@ class GatewayRunner:
self._voice_mode[chat_id] = "voice_only" self._voice_mode[chat_id] = "voice_only"
self._save_voice_modes() self._save_voice_modes()
if adapter: if adapter:
adapter._auto_tts_disabled_chats.discard(chat_id) self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
return ( return (
"Voice mode enabled.\n" "Voice mode enabled.\n"
"I'll reply with voice when you send voice messages.\n" "I'll reply with voice when you send voice messages.\n"
"Use /voice tts to get voice replies for all messages." "Use /voice tts to get voice replies for all messages."
) )
elif args in ("off", "disable"): elif args in ("off", "disable"):
self._voice_mode.pop(chat_id, None) self._voice_mode[chat_id] = "off"
self._save_voice_modes() self._save_voice_modes()
if adapter: if adapter:
adapter._auto_tts_disabled_chats.add(chat_id) self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
return "Voice mode disabled. Text-only replies." return "Voice mode disabled. Text-only replies."
elif args == "tts": elif args == "tts":
self._voice_mode[chat_id] = "all" self._voice_mode[chat_id] = "all"
self._save_voice_modes() self._save_voice_modes()
if adapter: if adapter:
adapter._auto_tts_disabled_chats.discard(chat_id) self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
return ( return (
"Auto-TTS enabled.\n" "Auto-TTS enabled.\n"
"All replies will include a voice message." "All replies will include a voice message."
@ -2195,13 +2226,13 @@ class GatewayRunner:
self._voice_mode[chat_id] = "voice_only" self._voice_mode[chat_id] = "voice_only"
self._save_voice_modes() self._save_voice_modes()
if adapter: if adapter:
adapter._auto_tts_disabled_chats.discard(chat_id) self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
return "Voice mode enabled." return "Voice mode enabled."
else: else:
self._voice_mode.pop(chat_id, None) self._voice_mode[chat_id] = "off"
self._save_voice_modes() self._save_voice_modes()
if adapter: if adapter:
adapter._auto_tts_disabled_chats.add(chat_id) self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
return "Voice mode disabled." return "Voice mode disabled."
async def _handle_voice_channel_join(self, event: MessageEvent) -> str: async def _handle_voice_channel_join(self, event: MessageEvent) -> str:
@ -2238,7 +2269,7 @@ class GatewayRunner:
adapter._voice_text_channels[guild_id] = int(event.source.chat_id) adapter._voice_text_channels[guild_id] = int(event.source.chat_id)
self._voice_mode[event.source.chat_id] = "all" self._voice_mode[event.source.chat_id] = "all"
self._save_voice_modes() self._save_voice_modes()
adapter._auto_tts_disabled_chats.discard(event.source.chat_id) self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=False)
return ( return (
f"Joined voice channel **{voice_channel.name}**.\n" f"Joined voice channel **{voice_channel.name}**.\n"
f"I'll speak my replies and listen to you. Use /voice leave to disconnect." f"I'll speak my replies and listen to you. Use /voice leave to disconnect."
@ -2263,8 +2294,9 @@ class GatewayRunner:
except Exception as e: except Exception as e:
logger.warning("Error leaving voice channel: %s", e) logger.warning("Error leaving voice channel: %s", e)
# Always clean up state even if leave raised an exception # Always clean up state even if leave raised an exception
self._voice_mode.pop(event.source.chat_id, None) self._voice_mode[event.source.chat_id] = "off"
self._save_voice_modes() self._save_voice_modes()
self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=True)
if hasattr(adapter, "_voice_input_callback"): if hasattr(adapter, "_voice_input_callback"):
adapter._voice_input_callback = None adapter._voice_input_callback = None
return "Left voice channel." return "Left voice channel."
@ -2274,8 +2306,10 @@ class GatewayRunner:
Cleans up runner-side voice_mode state that the adapter cannot reach. Cleans up runner-side voice_mode state that the adapter cannot reach.
""" """
self._voice_mode.pop(chat_id, None) self._voice_mode[chat_id] = "off"
self._save_voice_modes() self._save_voice_modes()
adapter = self.adapters.get(Platform.DISCORD)
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
async def _handle_voice_channel_input( async def _handle_voice_channel_input(
self, guild_id: int, user_id: int, transcript: str self, guild_id: int, user_id: int, transcript: str

View file

@ -497,6 +497,12 @@ class AIAgent:
# Initialized here so _vprint can reference it before run_conversation. # Initialized here so _vprint can reference it before run_conversation.
self._stream_callback = None self._stream_callback = None
# Optional current-turn user-message override used when the API-facing
# user message intentionally differs from the persisted transcript
# (e.g. CLI voice mode adds a temporary prefix for the live call only).
self._persist_user_message_idx = None
self._persist_user_message_override = None
# Initialize LLM client via centralized provider router. # Initialize LLM client via centralized provider router.
# The router handles auth resolution, base URL, headers, and # The router handles auth resolution, base URL, headers, and
# Codex/Anthropic wrapping for all known providers. # Codex/Anthropic wrapping for all known providers.
@ -998,11 +1004,30 @@ class AIAgent:
if self.verbose_logging: if self.verbose_logging:
logging.warning(f"Failed to cleanup browser for task {task_id}: {e}") logging.warning(f"Failed to cleanup browser for task {task_id}: {e}")
def _apply_persist_user_message_override(self, messages: List[Dict]) -> None:
    """Swap the current-turn user message for its configured clean variant.

    Reads ``_persist_user_message_idx`` and ``_persist_user_message_override``
    from the instance (both default to ``None`` when absent). When both are
    set, the index is within ``messages``, and the entry at that index is a
    dict with role ``"user"``, its ``"content"`` is replaced in place with the
    override. In every other case ``messages`` is left unchanged.
    """
    position = getattr(self, "_persist_user_message_idx", None)
    clean_text = getattr(self, "_persist_user_message_override", None)
    if clean_text is None or position is None:
        return
    if position < 0 or position >= len(messages):
        return
    entry = messages[position]
    if not isinstance(entry, dict):
        return
    if entry.get("role") != "user":
        return
    # In-place edit: callers holding this list see the clean text too.
    entry["content"] = clean_text
def _persist_session(self, messages: List[Dict], conversation_history: List[Dict] = None): def _persist_session(self, messages: List[Dict], conversation_history: List[Dict] = None):
"""Save session state to both JSON log and SQLite on any exit path. """Save session state to both JSON log and SQLite on any exit path.
Ensures conversations are never lost, even on errors or early returns. Ensures conversations are never lost, even on errors or early returns.
""" """
self._apply_persist_user_message_override(messages)
self._session_messages = messages self._session_messages = messages
self._save_session_log(messages) self._save_session_log(messages)
self._flush_messages_to_session_db(messages, conversation_history) self._flush_messages_to_session_db(messages, conversation_history)
@ -1016,6 +1041,7 @@ class AIAgent:
""" """
if not self._session_db: if not self._session_db:
return return
self._apply_persist_user_message_override(messages)
try: try:
start_idx = len(conversation_history) if conversation_history else 0 start_idx = len(conversation_history) if conversation_history else 0
flush_from = max(start_idx, self._last_flushed_db_idx) flush_from = max(start_idx, self._last_flushed_db_idx)
@ -4065,6 +4091,7 @@ class AIAgent:
conversation_history: List[Dict[str, Any]] = None, conversation_history: List[Dict[str, Any]] = None,
task_id: str = None, task_id: str = None,
stream_callback: Optional[callable] = None, stream_callback: Optional[callable] = None,
persist_user_message: Optional[str] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Run a complete conversation with tool calling until completion. Run a complete conversation with tool calling until completion.
@ -4077,6 +4104,9 @@ class AIAgent:
stream_callback: Optional callback invoked with each text delta during streaming. stream_callback: Optional callback invoked with each text delta during streaming.
Used by the TTS pipeline to start audio generation before the full response. Used by the TTS pipeline to start audio generation before the full response.
When None (default), API calls use the standard non-streaming path. When None (default), API calls use the standard non-streaming path.
persist_user_message: Optional clean user message to store in
transcripts/history when user_message contains API-only
synthetic prefixes.
Returns: Returns:
Dict: Complete conversation result with final response and message history Dict: Complete conversation result with final response and message history
@ -4087,6 +4117,8 @@ class AIAgent:
# Store stream callback for _interruptible_api_call to pick up # Store stream callback for _interruptible_api_call to pick up
self._stream_callback = stream_callback self._stream_callback = stream_callback
self._persist_user_message_idx = None
self._persist_user_message_override = persist_user_message
# Generate unique task_id if not provided to isolate VMs between concurrent tasks # Generate unique task_id if not provided to isolate VMs between concurrent tasks
effective_task_id = task_id or str(uuid.uuid4()) effective_task_id = task_id or str(uuid.uuid4())
@ -4121,7 +4153,7 @@ class AIAgent:
# Preserve the original user message before nudge injection. # Preserve the original user message before nudge injection.
# Honcho should receive the actual user input, not system nudges. # Honcho should receive the actual user input, not system nudges.
original_user_message = user_message original_user_message = persist_user_message if persist_user_message is not None else user_message
# Periodic memory nudge: remind the model to consider saving memories. # Periodic memory nudge: remind the model to consider saving memories.
# Counter resets whenever the memory tool is actually used. # Counter resets whenever the memory tool is actually used.
@ -4159,7 +4191,7 @@ class AIAgent:
_recall_mode = (self._honcho_config.recall_mode if self._honcho_config else "hybrid") _recall_mode = (self._honcho_config.recall_mode if self._honcho_config else "hybrid")
if self._honcho and self._honcho_session_key and _recall_mode != "tools": if self._honcho and self._honcho_session_key and _recall_mode != "tools":
try: try:
prefetched_context = self._honcho_prefetch(user_message) prefetched_context = self._honcho_prefetch(original_user_message)
if prefetched_context: if prefetched_context:
if not conversation_history: if not conversation_history:
self._honcho_context = prefetched_context self._honcho_context = prefetched_context
@ -4172,6 +4204,7 @@ class AIAgent:
user_msg = {"role": "user", "content": user_message} user_msg = {"role": "user", "content": user_message}
messages.append(user_msg) messages.append(user_msg)
current_turn_user_idx = len(messages) - 1 current_turn_user_idx = len(messages) - 1
self._persist_user_message_idx = current_turn_user_idx
if not self.quiet_mode: if not self.quiet_mode:
print(f"💬 Starting conversation: '{user_message[:60]}{'...' if len(user_message) > 60 else ''}'") print(f"💬 Starting conversation: '{user_message[:60]}{'...' if len(user_message) > 60 else ''}'")

View file

@ -3,12 +3,53 @@
import json import json
import os import os
import queue import queue
import sys
import threading import threading
import time import time
import pytest import pytest
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
def _ensure_discord_mock():
    """Register stand-in ``discord`` modules in ``sys.modules`` when needed.

    Returns immediately if ``sys.modules`` already holds a ``discord`` entry
    that exposes ``__file__``. Otherwise builds MagicMock-based stand-ins and
    registers them under ``discord``, ``discord.ext`` and
    ``discord.ext.commands`` via ``setdefault``, so entries that are already
    present are never replaced.
    """
    if hasattr(sys.modules.get("discord"), "__file__"):
        return

    def _identity(fn):
        # Shared no-op decorator handed back by the stubbed decorator factories.
        return fn

    stub = MagicMock()
    stub.Intents.default.return_value = MagicMock()

    # Attributes that simply point at the MagicMock class.
    for attr in ("Client", "File", "Embed", "FFmpegPCMAudio", "PCMVolumeTransformer"):
        setattr(stub, attr, MagicMock)
    # Distinct empty classes, each named after the attribute it is stored under.
    for attr in ("DMChannel", "Thread", "ForumChannel"):
        setattr(stub, attr, type(attr, (), {}))

    stub.Interaction = object
    stub.ui = SimpleNamespace(View=object, button=lambda *a, **k: _identity, Button=object)
    stub.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3)
    stub.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4)
    stub.app_commands = SimpleNamespace(
        describe=lambda **kwargs: _identity,
        choices=lambda **kwargs: _identity,
        Choice=lambda **kwargs: SimpleNamespace(**kwargs),
    )
    stub.opus = SimpleNamespace(is_loaded=lambda: True, load_opus=lambda *_args, **_kwargs: None)
    stub.http = SimpleNamespace(Route=MagicMock)

    commands_stub = MagicMock()
    commands_stub.Bot = MagicMock
    ext_stub = MagicMock()
    ext_stub.commands = commands_stub

    registrations = (
        ("discord", stub),
        ("discord.ext", ext_stub),
        ("discord.ext.commands", commands_stub),
    )
    for module_name, module in registrations:
        sys.modules.setdefault(module_name, module)
_ensure_discord_mock()
from gateway.platforms.base import MessageEvent, MessageType, SessionSource from gateway.platforms.base import MessageEvent, MessageType, SessionSource
@ -65,7 +106,7 @@ class TestHandleVoiceCommand:
event = _make_event("/voice off") event = _make_event("/voice off")
result = await runner._handle_voice_command(event) result = await runner._handle_voice_command(event)
assert "disabled" in result.lower() assert "disabled" in result.lower()
assert "123" not in runner._voice_mode assert runner._voice_mode["123"] == "off"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_voice_tts(self, runner): async def test_voice_tts(self, runner):
@ -100,7 +141,7 @@ class TestHandleVoiceCommand:
event = _make_event("/voice") event = _make_event("/voice")
result = await runner._handle_voice_command(event) result = await runner._handle_voice_command(event)
assert "disabled" in result.lower() assert "disabled" in result.lower()
assert "123" not in runner._voice_mode assert runner._voice_mode["123"] == "off"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_persistence_saved(self, runner): async def test_persistence_saved(self, runner):
@ -116,6 +157,33 @@ class TestHandleVoiceCommand:
loaded = runner._load_voice_modes() loaded = runner._load_voice_modes()
assert loaded == {"456": "all"} assert loaded == {"456": "all"}
@pytest.mark.asyncio
async def test_persistence_saved_for_off(self, runner):
    """``/voice off`` must be written to the voice-mode file as an explicit "off"."""
    await runner._handle_voice_command(_make_event("/voice off"))
    persisted = json.loads(runner._VOICE_MODE_PATH.read_text())
    assert persisted["123"] == "off"
def test_sync_voice_mode_state_to_adapter_restores_off_chats(self, runner):
    """Only chats whose stored mode is "off" land in the adapter's suppression set."""
    fake_adapter = SimpleNamespace(_auto_tts_disabled_chats=set())
    runner._voice_mode = {"123": "off", "456": "all"}

    runner._sync_voice_mode_state_to_adapter(fake_adapter)

    assert fake_adapter._auto_tts_disabled_chats == {"123"}
def test_restart_restores_voice_off_state(self, runner, tmp_path):
    """A second runner built over the same path reloads "off" and re-applies it."""
    saved_state = json.dumps({"123": "off"})
    runner._VOICE_MODE_PATH.write_text(saved_state)

    # Build a second runner and feed it what was written to disk.
    fresh_runner = _make_runner(tmp_path)
    fresh_runner._voice_mode = fresh_runner._load_voice_modes()
    fake_adapter = SimpleNamespace(_auto_tts_disabled_chats=set())
    fresh_runner._sync_voice_mode_state_to_adapter(fake_adapter)

    assert fresh_runner._voice_mode["123"] == "off"
    assert fake_adapter._auto_tts_disabled_chats == {"123"}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_per_chat_isolation(self, runner): async def test_per_chat_isolation(self, runner):
e1 = _make_event("/voice on", chat_id="aaa") e1 = _make_event("/voice on", chat_id="aaa")
@ -693,7 +761,7 @@ class TestVoiceChannelCommands:
runner._voice_mode["123"] = "all" runner._voice_mode["123"] = "all"
result = await runner._handle_voice_channel_leave(event) result = await runner._handle_voice_channel_leave(event)
assert "left" in result.lower() assert "left" in result.lower()
assert "123" not in runner._voice_mode assert runner._voice_mode["123"] == "off"
mock_adapter.leave_voice_channel.assert_called_once_with(111) mock_adapter.leave_voice_channel.assert_called_once_with(111)
# -- _handle_voice_channel_input -- # -- _handle_voice_channel_input --
@ -1163,7 +1231,7 @@ class TestLeaveExceptionHandling:
result = await runner._handle_voice_channel_leave(event) result = await runner._handle_voice_channel_leave(event)
assert "left" in result.lower() assert "left" in result.lower()
assert "123" not in runner._voice_mode assert runner._voice_mode["123"] == "off"
assert mock_adapter._voice_input_callback is None assert mock_adapter._voice_input_callback is None
@pytest.mark.asyncio @pytest.mark.asyncio
@ -1626,8 +1694,8 @@ class TestVoiceTimeoutCleansRunnerState:
runner._handle_voice_timeout_cleanup("999") runner._handle_voice_timeout_cleanup("999")
assert "999" not in runner._voice_mode, \ assert runner._voice_mode["999"] == "off", \
"voice_mode must be removed after timeout cleanup" "voice_mode must persist explicit off state after timeout cleanup"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_timeout_without_callback_does_not_crash(self, adapter): async def test_timeout_without_callback_does_not_crash(self, adapter):

View file

@ -2383,6 +2383,41 @@ class TestStreamCallbackNonStreamingProvider:
assert received == ["Hello from Claude"] assert received == ["Hello from Claude"]
# ---------------------------------------------------------------------------
# Bugfix: API-only user message prefixes must not persist
# ---------------------------------------------------------------------------
class TestPersistUserMessageOverride:
    """Synthetic API-only user prefixes should never leak into transcripts."""

    def test_persist_session_rewrites_current_turn_user_message(self, agent):
        """``_persist_session`` stores the override, not the prefixed API text."""
        prefixed_text = (
            "[Voice input — respond concisely and conversationally, "
            "2-3 sentences max. No code blocks or markdown.] Hello there"
        )
        messages = [
            {"role": "user", "content": prefixed_text},
            {"role": "assistant", "content": "Hi!"},
        ]

        agent._session_db = MagicMock()
        agent.session_id = "session-123"
        agent._last_flushed_db_idx = 0
        # Point the override at the user message for this turn.
        agent._persist_user_message_idx = 0
        agent._persist_user_message_override = "Hello there"

        with patch.object(agent, "_save_session_log") as mock_save:
            agent._persist_session(messages, [])

        # The caller's list is rewritten in place.
        assert messages[0]["content"] == "Hello there"

        # The session log received the clean text.
        logged_messages = mock_save.call_args.args[0]
        assert logged_messages[0]["content"] == "Hello there"

        # The first row written to the session DB is clean as well.
        first_db_call = agent._session_db.append_message.call_args_list[0]
        assert first_db_call.kwargs["content"] == "Hello there"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Bugfix: _vprint force=True on error messages during TTS # Bugfix: _vprint force=True on error messages during TTS
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------