fix: persist clean voice transcripts and /voice off state
- keep CLI voice prefixes API-local while storing the original user text - persist explicit gateway off state and restore adapter auto-TTS suppression on restart - add regression coverage for both behaviors
This commit is contained in:
parent
523a1b6faf
commit
7b10881b9e
5 changed files with 192 additions and 29 deletions
13
cli.py
13
cli.py
|
|
@ -4218,9 +4218,8 @@ class HermesCLI:
|
||||||
text_queue.put(delta)
|
text_queue.put(delta)
|
||||||
|
|
||||||
# When voice mode is active, prepend a brief instruction so the
|
# When voice mode is active, prepend a brief instruction so the
|
||||||
# model responds concisely. The prefix is API-call-local only —
|
# model responds concisely. The prefix is API-call-local only —
|
||||||
# we strip it from the returned history so it never persists to
|
# run_conversation persists the original clean user message.
|
||||||
# session DB or resumed sessions.
|
|
||||||
_voice_prefix = ""
|
_voice_prefix = ""
|
||||||
if self._voice_mode and isinstance(message, str):
|
if self._voice_mode and isinstance(message, str):
|
||||||
_voice_prefix = (
|
_voice_prefix = (
|
||||||
|
|
@ -4236,6 +4235,7 @@ class HermesCLI:
|
||||||
conversation_history=self.conversation_history[:-1], # Exclude the message we just added
|
conversation_history=self.conversation_history[:-1], # Exclude the message we just added
|
||||||
stream_callback=stream_callback,
|
stream_callback=stream_callback,
|
||||||
task_id=self.session_id,
|
task_id=self.session_id,
|
||||||
|
persist_user_message=message if _voice_prefix else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Start agent in background thread
|
# Start agent in background thread
|
||||||
|
|
@ -4302,13 +4302,6 @@ class HermesCLI:
|
||||||
# Update history with full conversation
|
# Update history with full conversation
|
||||||
self.conversation_history = result.get("messages", self.conversation_history) if result else self.conversation_history
|
self.conversation_history = result.get("messages", self.conversation_history) if result else self.conversation_history
|
||||||
|
|
||||||
# Strip voice prefix from history so it never persists
|
|
||||||
if _voice_prefix and self.conversation_history:
|
|
||||||
for msg in self.conversation_history:
|
|
||||||
if msg.get("role") == "user" and isinstance(msg.get("content"), str):
|
|
||||||
if msg["content"].startswith(_voice_prefix):
|
|
||||||
msg["content"] = msg["content"][len(_voice_prefix):]
|
|
||||||
|
|
||||||
# Get the final response
|
# Get the final response
|
||||||
response = result.get("final_response", "") if result else ""
|
response = result.get("final_response", "") if result else ""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -348,10 +348,20 @@ class GatewayRunner:
|
||||||
|
|
||||||
def _load_voice_modes(self) -> Dict[str, str]:
|
def _load_voice_modes(self) -> Dict[str, str]:
|
||||||
try:
|
try:
|
||||||
return json.loads(self._VOICE_MODE_PATH.read_text())
|
data = json.loads(self._VOICE_MODE_PATH.read_text())
|
||||||
except (FileNotFoundError, json.JSONDecodeError, OSError):
|
except (FileNotFoundError, json.JSONDecodeError, OSError):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
valid_modes = {"off", "voice_only", "all"}
|
||||||
|
return {
|
||||||
|
str(chat_id): mode
|
||||||
|
for chat_id, mode in data.items()
|
||||||
|
if mode in valid_modes
|
||||||
|
}
|
||||||
|
|
||||||
def _save_voice_modes(self) -> None:
|
def _save_voice_modes(self) -> None:
|
||||||
try:
|
try:
|
||||||
self._VOICE_MODE_PATH.parent.mkdir(parents=True, exist_ok=True)
|
self._VOICE_MODE_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
@ -361,6 +371,26 @@ class GatewayRunner:
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
logger.warning("Failed to save voice modes: %s", e)
|
logger.warning("Failed to save voice modes: %s", e)
|
||||||
|
|
||||||
|
def _set_adapter_auto_tts_disabled(self, adapter, chat_id: str, disabled: bool) -> None:
|
||||||
|
"""Update an adapter's in-memory auto-TTS suppression set if present."""
|
||||||
|
disabled_chats = getattr(adapter, "_auto_tts_disabled_chats", None)
|
||||||
|
if not isinstance(disabled_chats, set):
|
||||||
|
return
|
||||||
|
if disabled:
|
||||||
|
disabled_chats.add(chat_id)
|
||||||
|
else:
|
||||||
|
disabled_chats.discard(chat_id)
|
||||||
|
|
||||||
|
def _sync_voice_mode_state_to_adapter(self, adapter) -> None:
|
||||||
|
"""Restore persisted /voice off state into a live platform adapter."""
|
||||||
|
disabled_chats = getattr(adapter, "_auto_tts_disabled_chats", None)
|
||||||
|
if not isinstance(disabled_chats, set):
|
||||||
|
return
|
||||||
|
disabled_chats.clear()
|
||||||
|
disabled_chats.update(
|
||||||
|
chat_id for chat_id, mode in self._voice_mode.items() if mode == "off"
|
||||||
|
)
|
||||||
|
|
||||||
# -----------------------------------------------------------------
|
# -----------------------------------------------------------------
|
||||||
|
|
||||||
def _flush_memories_for_session(self, old_session_id: str):
|
def _flush_memories_for_session(self, old_session_id: str):
|
||||||
|
|
@ -666,6 +696,7 @@ class GatewayRunner:
|
||||||
success = await adapter.connect()
|
success = await adapter.connect()
|
||||||
if success:
|
if success:
|
||||||
self.adapters[platform] = adapter
|
self.adapters[platform] = adapter
|
||||||
|
self._sync_voice_mode_state_to_adapter(adapter)
|
||||||
connected_count += 1
|
connected_count += 1
|
||||||
logger.info("✓ %s connected", platform.value)
|
logger.info("✓ %s connected", platform.value)
|
||||||
else:
|
else:
|
||||||
|
|
@ -2140,23 +2171,23 @@ class GatewayRunner:
|
||||||
self._voice_mode[chat_id] = "voice_only"
|
self._voice_mode[chat_id] = "voice_only"
|
||||||
self._save_voice_modes()
|
self._save_voice_modes()
|
||||||
if adapter:
|
if adapter:
|
||||||
adapter._auto_tts_disabled_chats.discard(chat_id)
|
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
|
||||||
return (
|
return (
|
||||||
"Voice mode enabled.\n"
|
"Voice mode enabled.\n"
|
||||||
"I'll reply with voice when you send voice messages.\n"
|
"I'll reply with voice when you send voice messages.\n"
|
||||||
"Use /voice tts to get voice replies for all messages."
|
"Use /voice tts to get voice replies for all messages."
|
||||||
)
|
)
|
||||||
elif args in ("off", "disable"):
|
elif args in ("off", "disable"):
|
||||||
self._voice_mode.pop(chat_id, None)
|
self._voice_mode[chat_id] = "off"
|
||||||
self._save_voice_modes()
|
self._save_voice_modes()
|
||||||
if adapter:
|
if adapter:
|
||||||
adapter._auto_tts_disabled_chats.add(chat_id)
|
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
|
||||||
return "Voice mode disabled. Text-only replies."
|
return "Voice mode disabled. Text-only replies."
|
||||||
elif args == "tts":
|
elif args == "tts":
|
||||||
self._voice_mode[chat_id] = "all"
|
self._voice_mode[chat_id] = "all"
|
||||||
self._save_voice_modes()
|
self._save_voice_modes()
|
||||||
if adapter:
|
if adapter:
|
||||||
adapter._auto_tts_disabled_chats.discard(chat_id)
|
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
|
||||||
return (
|
return (
|
||||||
"Auto-TTS enabled.\n"
|
"Auto-TTS enabled.\n"
|
||||||
"All replies will include a voice message."
|
"All replies will include a voice message."
|
||||||
|
|
@ -2195,13 +2226,13 @@ class GatewayRunner:
|
||||||
self._voice_mode[chat_id] = "voice_only"
|
self._voice_mode[chat_id] = "voice_only"
|
||||||
self._save_voice_modes()
|
self._save_voice_modes()
|
||||||
if adapter:
|
if adapter:
|
||||||
adapter._auto_tts_disabled_chats.discard(chat_id)
|
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
|
||||||
return "Voice mode enabled."
|
return "Voice mode enabled."
|
||||||
else:
|
else:
|
||||||
self._voice_mode.pop(chat_id, None)
|
self._voice_mode[chat_id] = "off"
|
||||||
self._save_voice_modes()
|
self._save_voice_modes()
|
||||||
if adapter:
|
if adapter:
|
||||||
adapter._auto_tts_disabled_chats.add(chat_id)
|
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
|
||||||
return "Voice mode disabled."
|
return "Voice mode disabled."
|
||||||
|
|
||||||
async def _handle_voice_channel_join(self, event: MessageEvent) -> str:
|
async def _handle_voice_channel_join(self, event: MessageEvent) -> str:
|
||||||
|
|
@ -2238,7 +2269,7 @@ class GatewayRunner:
|
||||||
adapter._voice_text_channels[guild_id] = int(event.source.chat_id)
|
adapter._voice_text_channels[guild_id] = int(event.source.chat_id)
|
||||||
self._voice_mode[event.source.chat_id] = "all"
|
self._voice_mode[event.source.chat_id] = "all"
|
||||||
self._save_voice_modes()
|
self._save_voice_modes()
|
||||||
adapter._auto_tts_disabled_chats.discard(event.source.chat_id)
|
self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=False)
|
||||||
return (
|
return (
|
||||||
f"Joined voice channel **{voice_channel.name}**.\n"
|
f"Joined voice channel **{voice_channel.name}**.\n"
|
||||||
f"I'll speak my replies and listen to you. Use /voice leave to disconnect."
|
f"I'll speak my replies and listen to you. Use /voice leave to disconnect."
|
||||||
|
|
@ -2263,8 +2294,9 @@ class GatewayRunner:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Error leaving voice channel: %s", e)
|
logger.warning("Error leaving voice channel: %s", e)
|
||||||
# Always clean up state even if leave raised an exception
|
# Always clean up state even if leave raised an exception
|
||||||
self._voice_mode.pop(event.source.chat_id, None)
|
self._voice_mode[event.source.chat_id] = "off"
|
||||||
self._save_voice_modes()
|
self._save_voice_modes()
|
||||||
|
self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=True)
|
||||||
if hasattr(adapter, "_voice_input_callback"):
|
if hasattr(adapter, "_voice_input_callback"):
|
||||||
adapter._voice_input_callback = None
|
adapter._voice_input_callback = None
|
||||||
return "Left voice channel."
|
return "Left voice channel."
|
||||||
|
|
@ -2274,8 +2306,10 @@ class GatewayRunner:
|
||||||
|
|
||||||
Cleans up runner-side voice_mode state that the adapter cannot reach.
|
Cleans up runner-side voice_mode state that the adapter cannot reach.
|
||||||
"""
|
"""
|
||||||
self._voice_mode.pop(chat_id, None)
|
self._voice_mode[chat_id] = "off"
|
||||||
self._save_voice_modes()
|
self._save_voice_modes()
|
||||||
|
adapter = self.adapters.get(Platform.DISCORD)
|
||||||
|
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
|
||||||
|
|
||||||
async def _handle_voice_channel_input(
|
async def _handle_voice_channel_input(
|
||||||
self, guild_id: int, user_id: int, transcript: str
|
self, guild_id: int, user_id: int, transcript: str
|
||||||
|
|
|
||||||
37
run_agent.py
37
run_agent.py
|
|
@ -497,6 +497,12 @@ class AIAgent:
|
||||||
# Initialized here so _vprint can reference it before run_conversation.
|
# Initialized here so _vprint can reference it before run_conversation.
|
||||||
self._stream_callback = None
|
self._stream_callback = None
|
||||||
|
|
||||||
|
# Optional current-turn user-message override used when the API-facing
|
||||||
|
# user message intentionally differs from the persisted transcript
|
||||||
|
# (e.g. CLI voice mode adds a temporary prefix for the live call only).
|
||||||
|
self._persist_user_message_idx = None
|
||||||
|
self._persist_user_message_override = None
|
||||||
|
|
||||||
# Initialize LLM client via centralized provider router.
|
# Initialize LLM client via centralized provider router.
|
||||||
# The router handles auth resolution, base URL, headers, and
|
# The router handles auth resolution, base URL, headers, and
|
||||||
# Codex/Anthropic wrapping for all known providers.
|
# Codex/Anthropic wrapping for all known providers.
|
||||||
|
|
@ -998,11 +1004,30 @@ class AIAgent:
|
||||||
if self.verbose_logging:
|
if self.verbose_logging:
|
||||||
logging.warning(f"Failed to cleanup browser for task {task_id}: {e}")
|
logging.warning(f"Failed to cleanup browser for task {task_id}: {e}")
|
||||||
|
|
||||||
|
def _apply_persist_user_message_override(self, messages: List[Dict]) -> None:
|
||||||
|
"""Rewrite the current-turn user message before persistence/return.
|
||||||
|
|
||||||
|
Some call paths need an API-only user-message variant without letting
|
||||||
|
that synthetic text leak into persisted transcripts or resumed session
|
||||||
|
history. When an override is configured for the active turn, mutate the
|
||||||
|
in-memory messages list in place so both persistence and returned
|
||||||
|
history stay clean.
|
||||||
|
"""
|
||||||
|
idx = getattr(self, "_persist_user_message_idx", None)
|
||||||
|
override = getattr(self, "_persist_user_message_override", None)
|
||||||
|
if override is None or idx is None:
|
||||||
|
return
|
||||||
|
if 0 <= idx < len(messages):
|
||||||
|
msg = messages[idx]
|
||||||
|
if isinstance(msg, dict) and msg.get("role") == "user":
|
||||||
|
msg["content"] = override
|
||||||
|
|
||||||
def _persist_session(self, messages: List[Dict], conversation_history: List[Dict] = None):
|
def _persist_session(self, messages: List[Dict], conversation_history: List[Dict] = None):
|
||||||
"""Save session state to both JSON log and SQLite on any exit path.
|
"""Save session state to both JSON log and SQLite on any exit path.
|
||||||
|
|
||||||
Ensures conversations are never lost, even on errors or early returns.
|
Ensures conversations are never lost, even on errors or early returns.
|
||||||
"""
|
"""
|
||||||
|
self._apply_persist_user_message_override(messages)
|
||||||
self._session_messages = messages
|
self._session_messages = messages
|
||||||
self._save_session_log(messages)
|
self._save_session_log(messages)
|
||||||
self._flush_messages_to_session_db(messages, conversation_history)
|
self._flush_messages_to_session_db(messages, conversation_history)
|
||||||
|
|
@ -1016,6 +1041,7 @@ class AIAgent:
|
||||||
"""
|
"""
|
||||||
if not self._session_db:
|
if not self._session_db:
|
||||||
return
|
return
|
||||||
|
self._apply_persist_user_message_override(messages)
|
||||||
try:
|
try:
|
||||||
start_idx = len(conversation_history) if conversation_history else 0
|
start_idx = len(conversation_history) if conversation_history else 0
|
||||||
flush_from = max(start_idx, self._last_flushed_db_idx)
|
flush_from = max(start_idx, self._last_flushed_db_idx)
|
||||||
|
|
@ -4065,6 +4091,7 @@ class AIAgent:
|
||||||
conversation_history: List[Dict[str, Any]] = None,
|
conversation_history: List[Dict[str, Any]] = None,
|
||||||
task_id: str = None,
|
task_id: str = None,
|
||||||
stream_callback: Optional[callable] = None,
|
stream_callback: Optional[callable] = None,
|
||||||
|
persist_user_message: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Run a complete conversation with tool calling until completion.
|
Run a complete conversation with tool calling until completion.
|
||||||
|
|
@ -4077,6 +4104,9 @@ class AIAgent:
|
||||||
stream_callback: Optional callback invoked with each text delta during streaming.
|
stream_callback: Optional callback invoked with each text delta during streaming.
|
||||||
Used by the TTS pipeline to start audio generation before the full response.
|
Used by the TTS pipeline to start audio generation before the full response.
|
||||||
When None (default), API calls use the standard non-streaming path.
|
When None (default), API calls use the standard non-streaming path.
|
||||||
|
persist_user_message: Optional clean user message to store in
|
||||||
|
transcripts/history when user_message contains API-only
|
||||||
|
synthetic prefixes.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: Complete conversation result with final response and message history
|
Dict: Complete conversation result with final response and message history
|
||||||
|
|
@ -4087,6 +4117,8 @@ class AIAgent:
|
||||||
|
|
||||||
# Store stream callback for _interruptible_api_call to pick up
|
# Store stream callback for _interruptible_api_call to pick up
|
||||||
self._stream_callback = stream_callback
|
self._stream_callback = stream_callback
|
||||||
|
self._persist_user_message_idx = None
|
||||||
|
self._persist_user_message_override = persist_user_message
|
||||||
# Generate unique task_id if not provided to isolate VMs between concurrent tasks
|
# Generate unique task_id if not provided to isolate VMs between concurrent tasks
|
||||||
effective_task_id = task_id or str(uuid.uuid4())
|
effective_task_id = task_id or str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
@ -4121,7 +4153,7 @@ class AIAgent:
|
||||||
|
|
||||||
# Preserve the original user message before nudge injection.
|
# Preserve the original user message before nudge injection.
|
||||||
# Honcho should receive the actual user input, not system nudges.
|
# Honcho should receive the actual user input, not system nudges.
|
||||||
original_user_message = user_message
|
original_user_message = persist_user_message if persist_user_message is not None else user_message
|
||||||
|
|
||||||
# Periodic memory nudge: remind the model to consider saving memories.
|
# Periodic memory nudge: remind the model to consider saving memories.
|
||||||
# Counter resets whenever the memory tool is actually used.
|
# Counter resets whenever the memory tool is actually used.
|
||||||
|
|
@ -4159,7 +4191,7 @@ class AIAgent:
|
||||||
_recall_mode = (self._honcho_config.recall_mode if self._honcho_config else "hybrid")
|
_recall_mode = (self._honcho_config.recall_mode if self._honcho_config else "hybrid")
|
||||||
if self._honcho and self._honcho_session_key and _recall_mode != "tools":
|
if self._honcho and self._honcho_session_key and _recall_mode != "tools":
|
||||||
try:
|
try:
|
||||||
prefetched_context = self._honcho_prefetch(user_message)
|
prefetched_context = self._honcho_prefetch(original_user_message)
|
||||||
if prefetched_context:
|
if prefetched_context:
|
||||||
if not conversation_history:
|
if not conversation_history:
|
||||||
self._honcho_context = prefetched_context
|
self._honcho_context = prefetched_context
|
||||||
|
|
@ -4172,6 +4204,7 @@ class AIAgent:
|
||||||
user_msg = {"role": "user", "content": user_message}
|
user_msg = {"role": "user", "content": user_message}
|
||||||
messages.append(user_msg)
|
messages.append(user_msg)
|
||||||
current_turn_user_idx = len(messages) - 1
|
current_turn_user_idx = len(messages) - 1
|
||||||
|
self._persist_user_message_idx = current_turn_user_idx
|
||||||
|
|
||||||
if not self.quiet_mode:
|
if not self.quiet_mode:
|
||||||
print(f"💬 Starting conversation: '{user_message[:60]}{'...' if len(user_message) > 60 else ''}'")
|
print(f"💬 Starting conversation: '{user_message[:60]}{'...' if len(user_message) > 60 else ''}'")
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,53 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import queue
|
import queue
|
||||||
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import pytest
|
import pytest
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_discord_mock():
|
||||||
|
"""Install a lightweight discord mock when discord.py isn't available."""
|
||||||
|
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||||
|
return
|
||||||
|
|
||||||
|
discord_mod = MagicMock()
|
||||||
|
discord_mod.Intents.default.return_value = MagicMock()
|
||||||
|
discord_mod.Client = MagicMock
|
||||||
|
discord_mod.File = MagicMock
|
||||||
|
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||||
|
discord_mod.Thread = type("Thread", (), {})
|
||||||
|
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||||
|
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
|
||||||
|
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3)
|
||||||
|
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4)
|
||||||
|
discord_mod.Interaction = object
|
||||||
|
discord_mod.Embed = MagicMock
|
||||||
|
discord_mod.app_commands = SimpleNamespace(
|
||||||
|
describe=lambda **kwargs: (lambda fn: fn),
|
||||||
|
choices=lambda **kwargs: (lambda fn: fn),
|
||||||
|
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||||
|
)
|
||||||
|
discord_mod.opus = SimpleNamespace(is_loaded=lambda: True, load_opus=lambda *_args, **_kwargs: None)
|
||||||
|
discord_mod.FFmpegPCMAudio = MagicMock
|
||||||
|
discord_mod.PCMVolumeTransformer = MagicMock
|
||||||
|
discord_mod.http = SimpleNamespace(Route=MagicMock)
|
||||||
|
|
||||||
|
ext_mod = MagicMock()
|
||||||
|
commands_mod = MagicMock()
|
||||||
|
commands_mod.Bot = MagicMock
|
||||||
|
ext_mod.commands = commands_mod
|
||||||
|
|
||||||
|
sys.modules.setdefault("discord", discord_mod)
|
||||||
|
sys.modules.setdefault("discord.ext", ext_mod)
|
||||||
|
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||||
|
|
||||||
|
|
||||||
|
_ensure_discord_mock()
|
||||||
|
|
||||||
from gateway.platforms.base import MessageEvent, MessageType, SessionSource
|
from gateway.platforms.base import MessageEvent, MessageType, SessionSource
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -65,7 +106,7 @@ class TestHandleVoiceCommand:
|
||||||
event = _make_event("/voice off")
|
event = _make_event("/voice off")
|
||||||
result = await runner._handle_voice_command(event)
|
result = await runner._handle_voice_command(event)
|
||||||
assert "disabled" in result.lower()
|
assert "disabled" in result.lower()
|
||||||
assert "123" not in runner._voice_mode
|
assert runner._voice_mode["123"] == "off"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_voice_tts(self, runner):
|
async def test_voice_tts(self, runner):
|
||||||
|
|
@ -100,7 +141,7 @@ class TestHandleVoiceCommand:
|
||||||
event = _make_event("/voice")
|
event = _make_event("/voice")
|
||||||
result = await runner._handle_voice_command(event)
|
result = await runner._handle_voice_command(event)
|
||||||
assert "disabled" in result.lower()
|
assert "disabled" in result.lower()
|
||||||
assert "123" not in runner._voice_mode
|
assert runner._voice_mode["123"] == "off"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_persistence_saved(self, runner):
|
async def test_persistence_saved(self, runner):
|
||||||
|
|
@ -116,6 +157,33 @@ class TestHandleVoiceCommand:
|
||||||
loaded = runner._load_voice_modes()
|
loaded = runner._load_voice_modes()
|
||||||
assert loaded == {"456": "all"}
|
assert loaded == {"456": "all"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_persistence_saved_for_off(self, runner):
|
||||||
|
event = _make_event("/voice off")
|
||||||
|
await runner._handle_voice_command(event)
|
||||||
|
data = json.loads(runner._VOICE_MODE_PATH.read_text())
|
||||||
|
assert data["123"] == "off"
|
||||||
|
|
||||||
|
def test_sync_voice_mode_state_to_adapter_restores_off_chats(self, runner):
|
||||||
|
runner._voice_mode = {"123": "off", "456": "all"}
|
||||||
|
adapter = SimpleNamespace(_auto_tts_disabled_chats=set())
|
||||||
|
|
||||||
|
runner._sync_voice_mode_state_to_adapter(adapter)
|
||||||
|
|
||||||
|
assert adapter._auto_tts_disabled_chats == {"123"}
|
||||||
|
|
||||||
|
def test_restart_restores_voice_off_state(self, runner, tmp_path):
|
||||||
|
runner._VOICE_MODE_PATH.write_text(json.dumps({"123": "off"}))
|
||||||
|
|
||||||
|
restored_runner = _make_runner(tmp_path)
|
||||||
|
restored_runner._voice_mode = restored_runner._load_voice_modes()
|
||||||
|
adapter = SimpleNamespace(_auto_tts_disabled_chats=set())
|
||||||
|
|
||||||
|
restored_runner._sync_voice_mode_state_to_adapter(adapter)
|
||||||
|
|
||||||
|
assert restored_runner._voice_mode["123"] == "off"
|
||||||
|
assert adapter._auto_tts_disabled_chats == {"123"}
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_per_chat_isolation(self, runner):
|
async def test_per_chat_isolation(self, runner):
|
||||||
e1 = _make_event("/voice on", chat_id="aaa")
|
e1 = _make_event("/voice on", chat_id="aaa")
|
||||||
|
|
@ -693,7 +761,7 @@ class TestVoiceChannelCommands:
|
||||||
runner._voice_mode["123"] = "all"
|
runner._voice_mode["123"] = "all"
|
||||||
result = await runner._handle_voice_channel_leave(event)
|
result = await runner._handle_voice_channel_leave(event)
|
||||||
assert "left" in result.lower()
|
assert "left" in result.lower()
|
||||||
assert "123" not in runner._voice_mode
|
assert runner._voice_mode["123"] == "off"
|
||||||
mock_adapter.leave_voice_channel.assert_called_once_with(111)
|
mock_adapter.leave_voice_channel.assert_called_once_with(111)
|
||||||
|
|
||||||
# -- _handle_voice_channel_input --
|
# -- _handle_voice_channel_input --
|
||||||
|
|
@ -1163,7 +1231,7 @@ class TestLeaveExceptionHandling:
|
||||||
|
|
||||||
result = await runner._handle_voice_channel_leave(event)
|
result = await runner._handle_voice_channel_leave(event)
|
||||||
assert "left" in result.lower()
|
assert "left" in result.lower()
|
||||||
assert "123" not in runner._voice_mode
|
assert runner._voice_mode["123"] == "off"
|
||||||
assert mock_adapter._voice_input_callback is None
|
assert mock_adapter._voice_input_callback is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -1626,8 +1694,8 @@ class TestVoiceTimeoutCleansRunnerState:
|
||||||
|
|
||||||
runner._handle_voice_timeout_cleanup("999")
|
runner._handle_voice_timeout_cleanup("999")
|
||||||
|
|
||||||
assert "999" not in runner._voice_mode, \
|
assert runner._voice_mode["999"] == "off", \
|
||||||
"voice_mode must be removed after timeout cleanup"
|
"voice_mode must persist explicit off state after timeout cleanup"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_timeout_without_callback_does_not_crash(self, adapter):
|
async def test_timeout_without_callback_does_not_crash(self, adapter):
|
||||||
|
|
|
||||||
|
|
@ -2383,6 +2383,41 @@ class TestStreamCallbackNonStreamingProvider:
|
||||||
assert received == ["Hello from Claude"]
|
assert received == ["Hello from Claude"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Bugfix: API-only user message prefixes must not persist
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPersistUserMessageOverride:
|
||||||
|
"""Synthetic API-only user prefixes should never leak into transcripts."""
|
||||||
|
|
||||||
|
def test_persist_session_rewrites_current_turn_user_message(self, agent):
|
||||||
|
agent._session_db = MagicMock()
|
||||||
|
agent.session_id = "session-123"
|
||||||
|
agent._last_flushed_db_idx = 0
|
||||||
|
agent._persist_user_message_idx = 0
|
||||||
|
agent._persist_user_message_override = "Hello there"
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
"[Voice input — respond concisely and conversationally, "
|
||||||
|
"2-3 sentences max. No code blocks or markdown.] Hello there"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{"role": "assistant", "content": "Hi!"},
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(agent, "_save_session_log") as mock_save:
|
||||||
|
agent._persist_session(messages, [])
|
||||||
|
|
||||||
|
assert messages[0]["content"] == "Hello there"
|
||||||
|
saved_messages = mock_save.call_args.args[0]
|
||||||
|
assert saved_messages[0]["content"] == "Hello there"
|
||||||
|
first_db_write = agent._session_db.append_message.call_args_list[0].kwargs
|
||||||
|
assert first_db_write["content"] == "Hello there"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Bugfix: _vprint force=True on error messages during TTS
|
# Bugfix: _vprint force=True on error messages during TTS
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue