fix: persist clean voice transcripts and /voice off state

- keep CLI voice prefixes API-local while storing the original user text
- persist explicit gateway off state and restore adapter auto-TTS suppression on restart
- add regression coverage for both behaviors
This commit is contained in:
teknium1 2026-03-14 06:14:22 -07:00
parent 523a1b6faf
commit 7b10881b9e
5 changed files with 192 additions and 29 deletions

13
cli.py
View file

@ -4218,9 +4218,8 @@ class HermesCLI:
text_queue.put(delta)
# When voice mode is active, prepend a brief instruction so the
# model responds concisely. The prefix is API-call-local only —
# we strip it from the returned history so it never persists to
# session DB or resumed sessions.
# model responds concisely. The prefix is API-call-local only —
# run_conversation persists the original clean user message.
_voice_prefix = ""
if self._voice_mode and isinstance(message, str):
_voice_prefix = (
@ -4236,6 +4235,7 @@ class HermesCLI:
conversation_history=self.conversation_history[:-1], # Exclude the message we just added
stream_callback=stream_callback,
task_id=self.session_id,
persist_user_message=message if _voice_prefix else None,
)
# Start agent in background thread
@ -4302,13 +4302,6 @@ class HermesCLI:
# Update history with full conversation
self.conversation_history = result.get("messages", self.conversation_history) if result else self.conversation_history
# Strip voice prefix from history so it never persists
if _voice_prefix and self.conversation_history:
for msg in self.conversation_history:
if msg.get("role") == "user" and isinstance(msg.get("content"), str):
if msg["content"].startswith(_voice_prefix):
msg["content"] = msg["content"][len(_voice_prefix):]
# Get the final response
response = result.get("final_response", "") if result else ""

View file

@ -348,10 +348,20 @@ class GatewayRunner:
def _load_voice_modes(self) -> Dict[str, str]:
try:
return json.loads(self._VOICE_MODE_PATH.read_text())
data = json.loads(self._VOICE_MODE_PATH.read_text())
except (FileNotFoundError, json.JSONDecodeError, OSError):
return {}
if not isinstance(data, dict):
return {}
valid_modes = {"off", "voice_only", "all"}
return {
str(chat_id): mode
for chat_id, mode in data.items()
if mode in valid_modes
}
def _save_voice_modes(self) -> None:
try:
self._VOICE_MODE_PATH.parent.mkdir(parents=True, exist_ok=True)
@ -361,6 +371,26 @@ class GatewayRunner:
except OSError as e:
logger.warning("Failed to save voice modes: %s", e)
def _set_adapter_auto_tts_disabled(self, adapter, chat_id: str, disabled: bool) -> None:
"""Update an adapter's in-memory auto-TTS suppression set if present."""
disabled_chats = getattr(adapter, "_auto_tts_disabled_chats", None)
if not isinstance(disabled_chats, set):
return
if disabled:
disabled_chats.add(chat_id)
else:
disabled_chats.discard(chat_id)
def _sync_voice_mode_state_to_adapter(self, adapter) -> None:
"""Restore persisted /voice off state into a live platform adapter."""
disabled_chats = getattr(adapter, "_auto_tts_disabled_chats", None)
if not isinstance(disabled_chats, set):
return
disabled_chats.clear()
disabled_chats.update(
chat_id for chat_id, mode in self._voice_mode.items() if mode == "off"
)
# -----------------------------------------------------------------
def _flush_memories_for_session(self, old_session_id: str):
@ -666,6 +696,7 @@ class GatewayRunner:
success = await adapter.connect()
if success:
self.adapters[platform] = adapter
self._sync_voice_mode_state_to_adapter(adapter)
connected_count += 1
logger.info("%s connected", platform.value)
else:
@ -2140,23 +2171,23 @@ class GatewayRunner:
self._voice_mode[chat_id] = "voice_only"
self._save_voice_modes()
if adapter:
adapter._auto_tts_disabled_chats.discard(chat_id)
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
return (
"Voice mode enabled.\n"
"I'll reply with voice when you send voice messages.\n"
"Use /voice tts to get voice replies for all messages."
)
elif args in ("off", "disable"):
self._voice_mode.pop(chat_id, None)
self._voice_mode[chat_id] = "off"
self._save_voice_modes()
if adapter:
adapter._auto_tts_disabled_chats.add(chat_id)
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
return "Voice mode disabled. Text-only replies."
elif args == "tts":
self._voice_mode[chat_id] = "all"
self._save_voice_modes()
if adapter:
adapter._auto_tts_disabled_chats.discard(chat_id)
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
return (
"Auto-TTS enabled.\n"
"All replies will include a voice message."
@ -2195,13 +2226,13 @@ class GatewayRunner:
self._voice_mode[chat_id] = "voice_only"
self._save_voice_modes()
if adapter:
adapter._auto_tts_disabled_chats.discard(chat_id)
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
return "Voice mode enabled."
else:
self._voice_mode.pop(chat_id, None)
self._voice_mode[chat_id] = "off"
self._save_voice_modes()
if adapter:
adapter._auto_tts_disabled_chats.add(chat_id)
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
return "Voice mode disabled."
async def _handle_voice_channel_join(self, event: MessageEvent) -> str:
@ -2238,7 +2269,7 @@ class GatewayRunner:
adapter._voice_text_channels[guild_id] = int(event.source.chat_id)
self._voice_mode[event.source.chat_id] = "all"
self._save_voice_modes()
adapter._auto_tts_disabled_chats.discard(event.source.chat_id)
self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=False)
return (
f"Joined voice channel **{voice_channel.name}**.\n"
f"I'll speak my replies and listen to you. Use /voice leave to disconnect."
@ -2263,8 +2294,9 @@ class GatewayRunner:
except Exception as e:
logger.warning("Error leaving voice channel: %s", e)
# Always clean up state even if leave raised an exception
self._voice_mode.pop(event.source.chat_id, None)
self._voice_mode[event.source.chat_id] = "off"
self._save_voice_modes()
self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=True)
if hasattr(adapter, "_voice_input_callback"):
adapter._voice_input_callback = None
return "Left voice channel."
@ -2274,8 +2306,10 @@ class GatewayRunner:
Cleans up runner-side voice_mode state that the adapter cannot reach.
"""
self._voice_mode.pop(chat_id, None)
self._voice_mode[chat_id] = "off"
self._save_voice_modes()
adapter = self.adapters.get(Platform.DISCORD)
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
async def _handle_voice_channel_input(
self, guild_id: int, user_id: int, transcript: str

View file

@ -497,6 +497,12 @@ class AIAgent:
# Initialized here so _vprint can reference it before run_conversation.
self._stream_callback = None
# Optional current-turn user-message override used when the API-facing
# user message intentionally differs from the persisted transcript
# (e.g. CLI voice mode adds a temporary prefix for the live call only).
self._persist_user_message_idx = None
self._persist_user_message_override = None
# Initialize LLM client via centralized provider router.
# The router handles auth resolution, base URL, headers, and
# Codex/Anthropic wrapping for all known providers.
@ -998,11 +1004,30 @@ class AIAgent:
if self.verbose_logging:
logging.warning(f"Failed to cleanup browser for task {task_id}: {e}")
def _apply_persist_user_message_override(self, messages: List[Dict]) -> None:
"""Rewrite the current-turn user message before persistence/return.
Some call paths need an API-only user-message variant without letting
that synthetic text leak into persisted transcripts or resumed session
history. When an override is configured for the active turn, mutate the
in-memory messages list in place so both persistence and returned
history stay clean.
"""
idx = getattr(self, "_persist_user_message_idx", None)
override = getattr(self, "_persist_user_message_override", None)
if override is None or idx is None:
return
if 0 <= idx < len(messages):
msg = messages[idx]
if isinstance(msg, dict) and msg.get("role") == "user":
msg["content"] = override
def _persist_session(self, messages: List[Dict], conversation_history: List[Dict] = None):
"""Save session state to both JSON log and SQLite on any exit path.
Ensures conversations are never lost, even on errors or early returns.
"""
self._apply_persist_user_message_override(messages)
self._session_messages = messages
self._save_session_log(messages)
self._flush_messages_to_session_db(messages, conversation_history)
@ -1016,6 +1041,7 @@ class AIAgent:
"""
if not self._session_db:
return
self._apply_persist_user_message_override(messages)
try:
start_idx = len(conversation_history) if conversation_history else 0
flush_from = max(start_idx, self._last_flushed_db_idx)
@ -4065,6 +4091,7 @@ class AIAgent:
conversation_history: List[Dict[str, Any]] = None,
task_id: str = None,
stream_callback: Optional[callable] = None,
persist_user_message: Optional[str] = None,
) -> Dict[str, Any]:
"""
Run a complete conversation with tool calling until completion.
@ -4077,6 +4104,9 @@ class AIAgent:
stream_callback: Optional callback invoked with each text delta during streaming.
Used by the TTS pipeline to start audio generation before the full response.
When None (default), API calls use the standard non-streaming path.
persist_user_message: Optional clean user message to store in
transcripts/history when user_message contains API-only
synthetic prefixes.
Returns:
Dict: Complete conversation result with final response and message history
@ -4087,6 +4117,8 @@ class AIAgent:
# Store stream callback for _interruptible_api_call to pick up
self._stream_callback = stream_callback
self._persist_user_message_idx = None
self._persist_user_message_override = persist_user_message
# Generate unique task_id if not provided to isolate VMs between concurrent tasks
effective_task_id = task_id or str(uuid.uuid4())
@ -4121,7 +4153,7 @@ class AIAgent:
# Preserve the original user message before nudge injection.
# Honcho should receive the actual user input, not system nudges.
original_user_message = user_message
original_user_message = persist_user_message if persist_user_message is not None else user_message
# Periodic memory nudge: remind the model to consider saving memories.
# Counter resets whenever the memory tool is actually used.
@ -4159,7 +4191,7 @@ class AIAgent:
_recall_mode = (self._honcho_config.recall_mode if self._honcho_config else "hybrid")
if self._honcho and self._honcho_session_key and _recall_mode != "tools":
try:
prefetched_context = self._honcho_prefetch(user_message)
prefetched_context = self._honcho_prefetch(original_user_message)
if prefetched_context:
if not conversation_history:
self._honcho_context = prefetched_context
@ -4172,6 +4204,7 @@ class AIAgent:
user_msg = {"role": "user", "content": user_message}
messages.append(user_msg)
current_turn_user_idx = len(messages) - 1
self._persist_user_message_idx = current_turn_user_idx
if not self.quiet_mode:
print(f"💬 Starting conversation: '{user_message[:60]}{'...' if len(user_message) > 60 else ''}'")

View file

@ -3,12 +3,53 @@
import json
import os
import queue
import sys
import threading
import time
import pytest
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
def _ensure_discord_mock():
"""Install a lightweight discord mock when discord.py isn't available."""
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
return
discord_mod = MagicMock()
discord_mod.Intents.default.return_value = MagicMock()
discord_mod.Client = MagicMock
discord_mod.File = MagicMock
discord_mod.DMChannel = type("DMChannel", (), {})
discord_mod.Thread = type("Thread", (), {})
discord_mod.ForumChannel = type("ForumChannel", (), {})
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3)
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4)
discord_mod.Interaction = object
discord_mod.Embed = MagicMock
discord_mod.app_commands = SimpleNamespace(
describe=lambda **kwargs: (lambda fn: fn),
choices=lambda **kwargs: (lambda fn: fn),
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
)
discord_mod.opus = SimpleNamespace(is_loaded=lambda: True, load_opus=lambda *_args, **_kwargs: None)
discord_mod.FFmpegPCMAudio = MagicMock
discord_mod.PCMVolumeTransformer = MagicMock
discord_mod.http = SimpleNamespace(Route=MagicMock)
ext_mod = MagicMock()
commands_mod = MagicMock()
commands_mod.Bot = MagicMock
ext_mod.commands = commands_mod
sys.modules.setdefault("discord", discord_mod)
sys.modules.setdefault("discord.ext", ext_mod)
sys.modules.setdefault("discord.ext.commands", commands_mod)
_ensure_discord_mock()
from gateway.platforms.base import MessageEvent, MessageType, SessionSource
@ -65,7 +106,7 @@ class TestHandleVoiceCommand:
event = _make_event("/voice off")
result = await runner._handle_voice_command(event)
assert "disabled" in result.lower()
assert "123" not in runner._voice_mode
assert runner._voice_mode["123"] == "off"
@pytest.mark.asyncio
async def test_voice_tts(self, runner):
@ -100,7 +141,7 @@ class TestHandleVoiceCommand:
event = _make_event("/voice")
result = await runner._handle_voice_command(event)
assert "disabled" in result.lower()
assert "123" not in runner._voice_mode
assert runner._voice_mode["123"] == "off"
@pytest.mark.asyncio
async def test_persistence_saved(self, runner):
@ -116,6 +157,33 @@ class TestHandleVoiceCommand:
loaded = runner._load_voice_modes()
assert loaded == {"456": "all"}
@pytest.mark.asyncio
async def test_persistence_saved_for_off(self, runner):
event = _make_event("/voice off")
await runner._handle_voice_command(event)
data = json.loads(runner._VOICE_MODE_PATH.read_text())
assert data["123"] == "off"
def test_sync_voice_mode_state_to_adapter_restores_off_chats(self, runner):
runner._voice_mode = {"123": "off", "456": "all"}
adapter = SimpleNamespace(_auto_tts_disabled_chats=set())
runner._sync_voice_mode_state_to_adapter(adapter)
assert adapter._auto_tts_disabled_chats == {"123"}
def test_restart_restores_voice_off_state(self, runner, tmp_path):
runner._VOICE_MODE_PATH.write_text(json.dumps({"123": "off"}))
restored_runner = _make_runner(tmp_path)
restored_runner._voice_mode = restored_runner._load_voice_modes()
adapter = SimpleNamespace(_auto_tts_disabled_chats=set())
restored_runner._sync_voice_mode_state_to_adapter(adapter)
assert restored_runner._voice_mode["123"] == "off"
assert adapter._auto_tts_disabled_chats == {"123"}
@pytest.mark.asyncio
async def test_per_chat_isolation(self, runner):
e1 = _make_event("/voice on", chat_id="aaa")
@ -693,7 +761,7 @@ class TestVoiceChannelCommands:
runner._voice_mode["123"] = "all"
result = await runner._handle_voice_channel_leave(event)
assert "left" in result.lower()
assert "123" not in runner._voice_mode
assert runner._voice_mode["123"] == "off"
mock_adapter.leave_voice_channel.assert_called_once_with(111)
# -- _handle_voice_channel_input --
@ -1163,7 +1231,7 @@ class TestLeaveExceptionHandling:
result = await runner._handle_voice_channel_leave(event)
assert "left" in result.lower()
assert "123" not in runner._voice_mode
assert runner._voice_mode["123"] == "off"
assert mock_adapter._voice_input_callback is None
@pytest.mark.asyncio
@ -1626,8 +1694,8 @@ class TestVoiceTimeoutCleansRunnerState:
runner._handle_voice_timeout_cleanup("999")
assert "999" not in runner._voice_mode, \
"voice_mode must be removed after timeout cleanup"
assert runner._voice_mode["999"] == "off", \
"voice_mode must persist explicit off state after timeout cleanup"
@pytest.mark.asyncio
async def test_timeout_without_callback_does_not_crash(self, adapter):

View file

@ -2383,6 +2383,41 @@ class TestStreamCallbackNonStreamingProvider:
assert received == ["Hello from Claude"]
# ---------------------------------------------------------------------------
# Bugfix: API-only user message prefixes must not persist
# ---------------------------------------------------------------------------
class TestPersistUserMessageOverride:
"""Synthetic API-only user prefixes should never leak into transcripts."""
def test_persist_session_rewrites_current_turn_user_message(self, agent):
agent._session_db = MagicMock()
agent.session_id = "session-123"
agent._last_flushed_db_idx = 0
agent._persist_user_message_idx = 0
agent._persist_user_message_override = "Hello there"
messages = [
{
"role": "user",
"content": (
"[Voice input — respond concisely and conversationally, "
"2-3 sentences max. No code blocks or markdown.] Hello there"
),
},
{"role": "assistant", "content": "Hi!"},
]
with patch.object(agent, "_save_session_log") as mock_save:
agent._persist_session(messages, [])
assert messages[0]["content"] == "Hello there"
saved_messages = mock_save.call_args.args[0]
assert saved_messages[0]["content"] == "Hello there"
first_db_write = agent._session_db.append_message.call_args_list[0].kwargs
assert first_db_write["content"] == "Hello there"
# ---------------------------------------------------------------------------
# Bugfix: _vprint force=True on error messages during TTS
# ---------------------------------------------------------------------------