fix: persist clean voice transcripts and /voice off state

- keep CLI voice prefixes API-local while storing the original user text
- persist explicit gateway off state and restore adapter auto-TTS suppression on restart
- add regression coverage for both behaviors
This commit is contained in:
teknium1 2026-03-14 06:14:22 -07:00
parent 523a1b6faf
commit 7b10881b9e
5 changed files with 192 additions and 29 deletions

View file

@ -348,10 +348,20 @@ class GatewayRunner:
def _load_voice_modes(self) -> Dict[str, str]:
try:
return json.loads(self._VOICE_MODE_PATH.read_text())
data = json.loads(self._VOICE_MODE_PATH.read_text())
except (FileNotFoundError, json.JSONDecodeError, OSError):
return {}
if not isinstance(data, dict):
return {}
valid_modes = {"off", "voice_only", "all"}
return {
str(chat_id): mode
for chat_id, mode in data.items()
if mode in valid_modes
}
def _save_voice_modes(self) -> None:
try:
self._VOICE_MODE_PATH.parent.mkdir(parents=True, exist_ok=True)
@ -361,6 +371,26 @@ class GatewayRunner:
except OSError as e:
logger.warning("Failed to save voice modes: %s", e)
def _set_adapter_auto_tts_disabled(self, adapter, chat_id: str, disabled: bool) -> None:
"""Update an adapter's in-memory auto-TTS suppression set if present."""
disabled_chats = getattr(adapter, "_auto_tts_disabled_chats", None)
if not isinstance(disabled_chats, set):
return
if disabled:
disabled_chats.add(chat_id)
else:
disabled_chats.discard(chat_id)
def _sync_voice_mode_state_to_adapter(self, adapter) -> None:
"""Restore persisted /voice off state into a live platform adapter."""
disabled_chats = getattr(adapter, "_auto_tts_disabled_chats", None)
if not isinstance(disabled_chats, set):
return
disabled_chats.clear()
disabled_chats.update(
chat_id for chat_id, mode in self._voice_mode.items() if mode == "off"
)
# -----------------------------------------------------------------
def _flush_memories_for_session(self, old_session_id: str):
@ -666,6 +696,7 @@ class GatewayRunner:
success = await adapter.connect()
if success:
self.adapters[platform] = adapter
self._sync_voice_mode_state_to_adapter(adapter)
connected_count += 1
logger.info("%s connected", platform.value)
else:
@ -2140,23 +2171,23 @@ class GatewayRunner:
self._voice_mode[chat_id] = "voice_only"
self._save_voice_modes()
if adapter:
adapter._auto_tts_disabled_chats.discard(chat_id)
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
return (
"Voice mode enabled.\n"
"I'll reply with voice when you send voice messages.\n"
"Use /voice tts to get voice replies for all messages."
)
elif args in ("off", "disable"):
self._voice_mode.pop(chat_id, None)
self._voice_mode[chat_id] = "off"
self._save_voice_modes()
if adapter:
adapter._auto_tts_disabled_chats.add(chat_id)
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
return "Voice mode disabled. Text-only replies."
elif args == "tts":
self._voice_mode[chat_id] = "all"
self._save_voice_modes()
if adapter:
adapter._auto_tts_disabled_chats.discard(chat_id)
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
return (
"Auto-TTS enabled.\n"
"All replies will include a voice message."
@ -2195,13 +2226,13 @@ class GatewayRunner:
self._voice_mode[chat_id] = "voice_only"
self._save_voice_modes()
if adapter:
adapter._auto_tts_disabled_chats.discard(chat_id)
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
return "Voice mode enabled."
else:
self._voice_mode.pop(chat_id, None)
self._voice_mode[chat_id] = "off"
self._save_voice_modes()
if adapter:
adapter._auto_tts_disabled_chats.add(chat_id)
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
return "Voice mode disabled."
async def _handle_voice_channel_join(self, event: MessageEvent) -> str:
@ -2238,7 +2269,7 @@ class GatewayRunner:
adapter._voice_text_channels[guild_id] = int(event.source.chat_id)
self._voice_mode[event.source.chat_id] = "all"
self._save_voice_modes()
adapter._auto_tts_disabled_chats.discard(event.source.chat_id)
self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=False)
return (
f"Joined voice channel **{voice_channel.name}**.\n"
f"I'll speak my replies and listen to you. Use /voice leave to disconnect."
@ -2263,8 +2294,9 @@ class GatewayRunner:
except Exception as e:
logger.warning("Error leaving voice channel: %s", e)
# Always clean up state even if leave raised an exception
self._voice_mode.pop(event.source.chat_id, None)
self._voice_mode[event.source.chat_id] = "off"
self._save_voice_modes()
self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=True)
if hasattr(adapter, "_voice_input_callback"):
adapter._voice_input_callback = None
return "Left voice channel."
@ -2274,8 +2306,10 @@ class GatewayRunner:
Cleans up runner-side voice_mode state that the adapter cannot reach.
"""
self._voice_mode.pop(chat_id, None)
self._voice_mode[chat_id] = "off"
self._save_voice_modes()
adapter = self.adapters.get(Platform.DISCORD)
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
async def _handle_voice_channel_input(
self, guild_id: int, user_id: int, transcript: str