fix: extract voice reply logic and add comprehensive tests

- Fix tempfile.mktemp() TOCTOU race in Discord voice input (use NamedTemporaryFile)
- Extract voice reply decision from _handle_message into _should_send_voice_reply()
- Rewrite TestAutoVoiceReply to call real method instead of testing a copy
- Add 59 new tests: VoiceReceiver, VC commands, adapter methods, streaming TTS
This commit is contained in:
0xbyt4 2026-03-11 23:18:49 +03:00
parent 0d56b79685
commit 86ddaaee9c
3 changed files with 845 additions and 97 deletions

View file

@ -851,7 +851,9 @@ class DiscordAdapter(BasePlatformAdapter):
"""Convert PCM -> WAV -> STT -> callback.""" """Convert PCM -> WAV -> STT -> callback."""
from tools.voice_mode import is_whisper_hallucination from tools.voice_mode import is_whisper_hallucination
wav_path = tempfile.mktemp(suffix=".wav", prefix="vc_listen_") tmp_f = tempfile.NamedTemporaryFile(suffix=".wav", prefix="vc_listen_", delete=False)
wav_path = tmp_f.name
tmp_f.close()
try: try:
await asyncio.to_thread(VoiceReceiver.pcm_to_wav, pcm_data, wav_path) await asyncio.to_thread(VoiceReceiver.pcm_to_wav, pcm_data, wav_path)

View file

@ -1616,42 +1616,8 @@ class GatewayRunner:
) )
# Auto voice reply: send TTS audio before the text response # Auto voice reply: send TTS audio before the text response
chat_id = source.chat_id if self._should_send_voice_reply(event, response, agent_messages):
voice_mode = self._voice_mode.get(chat_id, "off") await self._send_voice_reply(event, response)
is_voice_input = (event.message_type == MessageType.VOICE)
should_voice_reply = (
(voice_mode == "all")
or (voice_mode == "voice_only" and is_voice_input)
)
logger.info("Voice reply check: chat_id=%s, voice_mode=%s, is_voice=%s, should_reply=%s, has_response=%s",
chat_id, voice_mode, is_voice_input, should_voice_reply, bool(response))
if should_voice_reply and response and not response.startswith("Error:"):
# Skip if agent already called TTS tool (avoid double voice)
has_agent_tts = any(
msg.get("role") == "assistant"
and any(
tc.get("function", {}).get("name") == "text_to_speech"
for tc in (msg.get("tool_calls") or [])
)
for msg in agent_messages
)
# Skip if voice input — base adapter auto-TTS in
# _process_message_background already sent audio for voice
# messages, so sending another would be double.
# Exception: Discord voice channel — the Discord play_tts
# override also skips (no-op), so the runner MUST handle it
# via play_in_voice_channel.
skip_double = is_voice_input
if skip_double:
adapter = self.adapters.get(source.platform)
guild_id = self._get_guild_id(event)
if (guild_id and adapter
and hasattr(adapter, "is_in_voice_channel")
and adapter.is_in_voice_channel(guild_id)):
skip_double = False
logger.info("Voice reply: has_agent_tts=%s, skip_double=%s, calling _send_voice_reply", has_agent_tts, skip_double)
if not has_agent_tts and not skip_double:
await self._send_voice_reply(event, response)
return response return response
@ -2302,6 +2268,64 @@ class GatewayRunner:
await adapter.handle_message(event) await adapter.handle_message(event)
def _should_send_voice_reply(
self,
event: MessageEvent,
response: str,
agent_messages: list,
) -> bool:
"""Decide whether the runner should send a TTS voice reply.
Returns False when:
- voice_mode is off for this chat
- response is empty or an error
- agent already called text_to_speech tool (dedup)
- voice input and base adapter auto-TTS already handled it (skip_double)
Exception: Discord voice channel base play_tts is a no-op there,
so the runner must handle VC playback.
"""
if not response or response.startswith("Error:"):
return False
chat_id = event.source.chat_id
voice_mode = self._voice_mode.get(chat_id, "off")
is_voice_input = (event.message_type == MessageType.VOICE)
should = (
(voice_mode == "all")
or (voice_mode == "voice_only" and is_voice_input)
)
if not should:
return False
# Dedup: agent already called TTS tool
has_agent_tts = any(
msg.get("role") == "assistant"
and any(
tc.get("function", {}).get("name") == "text_to_speech"
for tc in (msg.get("tool_calls") or [])
)
for msg in agent_messages
)
if has_agent_tts:
return False
# Dedup: base adapter auto-TTS already handles voice input.
# Exception: Discord voice channel — play_tts override is a no-op,
# so the runner must handle VC playback.
skip_double = is_voice_input
if skip_double:
adapter = self.adapters.get(event.source.platform)
guild_id = self._get_guild_id(event)
if (guild_id and adapter
and hasattr(adapter, "is_in_voice_channel")
and adapter.is_in_voice_channel(guild_id)):
skip_double = False
if skip_double:
return False
return True
async def _send_voice_reply(self, event: MessageEvent, text: str) -> None: async def _send_voice_reply(self, event: MessageEvent, text: str) -> None:
"""Generate TTS audio and send as a voice message before the text reply.""" """Generate TTS audio and send as a voice message before the text reply."""
try: try:

View file

@ -2,7 +2,11 @@
import json import json
import os import os
import queue
import threading
import time
import pytest import pytest
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from gateway.platforms.base import MessageEvent, MessageType, SessionSource from gateway.platforms.base import MessageEvent, MessageType, SessionSource
@ -126,7 +130,7 @@ class TestHandleVoiceCommand:
# ===================================================================== # =====================================================================
class TestAutoVoiceReply: class TestAutoVoiceReply:
"""Test the should_voice_reply decision logic (extracted from _handle_message). """Test the real _should_send_voice_reply method on GatewayRunner.
The gateway has two TTS paths: The gateway has two TTS paths:
1. base adapter auto-TTS: fires for voice input in _process_message_background 1. base adapter auto-TTS: fires for voice input in _process_message_background
@ -138,43 +142,30 @@ class TestAutoVoiceReply:
override skip, so the runner must handle it via play_in_voice_channel. override skip, so the runner must handle it via play_in_voice_channel.
""" """
def _should_reply(self, voice_mode, message_type, agent_messages=None, @pytest.fixture
response="Hello!", in_voice_channel=False): def runner(self, tmp_path):
"""Replicate the auto voice reply decision from _handle_message.""" return _make_runner(tmp_path)
if not response or response.startswith("Error:"):
return False
is_voice_input = (message_type == MessageType.VOICE) def _call(self, runner, voice_mode, message_type, agent_messages=None,
should = ( response="Hello!", in_voice_channel=False):
(voice_mode == "all") """Call real _should_send_voice_reply on a GatewayRunner instance."""
or (voice_mode == "voice_only" and is_voice_input) chat_id = "123"
if voice_mode != "off":
runner._voice_mode[chat_id] = voice_mode
else:
runner._voice_mode.pop(chat_id, None)
event = _make_event(message_type=message_type)
if in_voice_channel:
mock_adapter = MagicMock()
mock_adapter.is_in_voice_channel = MagicMock(return_value=True)
event.raw_message = SimpleNamespace(guild_id=111, guild=None)
runner.adapters[event.source.platform] = mock_adapter
return runner._should_send_voice_reply(
event, response, agent_messages or []
) )
if not should:
return False
# Dedup: agent already called TTS tool
if agent_messages:
has_agent_tts = any(
msg.get("role") == "assistant"
and any(
tc.get("function", {}).get("name") == "text_to_speech"
for tc in (msg.get("tool_calls") or [])
)
for msg in agent_messages
)
if has_agent_tts:
return False
# Dedup: base adapter auto-TTS already handles voice input.
# Exception: in voice channel, Discord play_tts also skips,
# so the runner must handle VC playback.
skip_double = is_voice_input
if skip_double and in_voice_channel:
skip_double = False
if skip_double:
return False
return True
# -- Full platform x input x mode matrix -------------------------------- # -- Full platform x input x mode matrix --------------------------------
# #
@ -204,52 +195,52 @@ class TestAutoVoiceReply:
# -- Telegram/Slack/Web: voice input, base handles --------------------- # -- Telegram/Slack/Web: voice input, base handles ---------------------
def test_voice_input_voice_only_skipped(self): def test_voice_input_voice_only_skipped(self, runner):
"""voice_only + voice input: base auto-TTS handles it, runner skips.""" """voice_only + voice input: base auto-TTS handles it, runner skips."""
assert self._should_reply("voice_only", MessageType.VOICE) is False assert self._call(runner, "voice_only", MessageType.VOICE) is False
def test_voice_input_all_mode_skipped(self): def test_voice_input_all_mode_skipped(self, runner):
"""all + voice input: base auto-TTS handles it, runner skips.""" """all + voice input: base auto-TTS handles it, runner skips."""
assert self._should_reply("all", MessageType.VOICE) is False assert self._call(runner, "all", MessageType.VOICE) is False
# -- Text input: only runner handles ----------------------------------- # -- Text input: only runner handles -----------------------------------
def test_text_input_all_mode_runner_fires(self): def test_text_input_all_mode_runner_fires(self, runner):
"""all + text input: only runner fires (base auto-TTS only for voice).""" """all + text input: only runner fires (base auto-TTS only for voice)."""
assert self._should_reply("all", MessageType.TEXT) is True assert self._call(runner, "all", MessageType.TEXT) is True
def test_text_input_voice_only_no_reply(self): def test_text_input_voice_only_no_reply(self, runner):
"""voice_only + text input: neither fires.""" """voice_only + text input: neither fires."""
assert self._should_reply("voice_only", MessageType.TEXT) is False assert self._call(runner, "voice_only", MessageType.TEXT) is False
# -- Mode off: nothing fires ------------------------------------------- # -- Mode off: nothing fires -------------------------------------------
def test_off_mode_voice(self): def test_off_mode_voice(self, runner):
assert self._should_reply("off", MessageType.VOICE) is False assert self._call(runner, "off", MessageType.VOICE) is False
def test_off_mode_text(self): def test_off_mode_text(self, runner):
assert self._should_reply("off", MessageType.TEXT) is False assert self._call(runner, "off", MessageType.TEXT) is False
# -- Discord VC exception: runner must handle -------------------------- # -- Discord VC exception: runner must handle --------------------------
def test_discord_vc_voice_input_runner_fires(self): def test_discord_vc_voice_input_runner_fires(self, runner):
"""Discord VC + voice input: base play_tts skips (VC override), """Discord VC + voice input: base play_tts skips (VC override),
so runner must handle via play_in_voice_channel.""" so runner must handle via play_in_voice_channel."""
assert self._should_reply("all", MessageType.VOICE, in_voice_channel=True) is True assert self._call(runner, "all", MessageType.VOICE, in_voice_channel=True) is True
def test_discord_vc_voice_only_runner_fires(self): def test_discord_vc_voice_only_runner_fires(self, runner):
"""Discord VC + voice_only + voice: runner must handle.""" """Discord VC + voice_only + voice: runner must handle."""
assert self._should_reply("voice_only", MessageType.VOICE, in_voice_channel=True) is True assert self._call(runner, "voice_only", MessageType.VOICE, in_voice_channel=True) is True
# -- Edge cases -------------------------------------------------------- # -- Edge cases --------------------------------------------------------
def test_error_response_skipped(self): def test_error_response_skipped(self, runner):
assert self._should_reply("all", MessageType.TEXT, response="Error: boom") is False assert self._call(runner, "all", MessageType.TEXT, response="Error: boom") is False
def test_empty_response_skipped(self): def test_empty_response_skipped(self, runner):
assert self._should_reply("all", MessageType.TEXT, response="") is False assert self._call(runner, "all", MessageType.TEXT, response="") is False
def test_dedup_skips_when_agent_called_tts(self): def test_dedup_skips_when_agent_called_tts(self, runner):
messages = [{ messages = [{
"role": "assistant", "role": "assistant",
"tool_calls": [{ "tool_calls": [{
@ -258,9 +249,9 @@ class TestAutoVoiceReply:
"function": {"name": "text_to_speech", "arguments": "{}"}, "function": {"name": "text_to_speech", "arguments": "{}"},
}], }],
}] }]
assert self._should_reply("all", MessageType.TEXT, agent_messages=messages) is False assert self._call(runner, "all", MessageType.TEXT, agent_messages=messages) is False
def test_no_dedup_for_other_tools(self): def test_no_dedup_for_other_tools(self, runner):
messages = [{ messages = [{
"role": "assistant", "role": "assistant",
"tool_calls": [{ "tool_calls": [{
@ -269,7 +260,7 @@ class TestAutoVoiceReply:
"function": {"name": "web_search", "arguments": "{}"}, "function": {"name": "web_search", "arguments": "{}"},
}], }],
}] }]
assert self._should_reply("all", MessageType.TEXT, agent_messages=messages) is True assert self._call(runner, "all", MessageType.TEXT, agent_messages=messages) is True
# ===================================================================== # =====================================================================
@ -443,3 +434,734 @@ class TestVoiceInHelp:
import inspect import inspect
source = inspect.getsource(GatewayRunner._handle_message) source = inspect.getsource(GatewayRunner._handle_message)
assert '"voice"' in source assert '"voice"' in source
# =====================================================================
# VoiceReceiver unit tests
# =====================================================================
class TestVoiceReceiver:
"""Test VoiceReceiver silence detection, SSRC mapping, and lifecycle."""
def _make_receiver(self):
from gateway.platforms.discord import VoiceReceiver
mock_vc = MagicMock()
mock_vc._connection.secret_key = [0] * 32
mock_vc._connection.dave_session = None
mock_vc._connection.ssrc = 9999
mock_vc._connection.add_socket_listener = MagicMock()
mock_vc._connection.remove_socket_listener = MagicMock()
mock_vc._connection.hook = None
receiver = VoiceReceiver(mock_vc)
return receiver
def test_initial_state(self):
receiver = self._make_receiver()
assert receiver._running is False
assert receiver._paused is False
assert len(receiver._buffers) == 0
assert len(receiver._ssrc_to_user) == 0
def test_start_sets_running(self):
receiver = self._make_receiver()
receiver.start()
assert receiver._running is True
def test_stop_clears_state(self):
receiver = self._make_receiver()
receiver.start()
receiver.map_ssrc(100, 42)
receiver._buffers[100] = bytearray(b"\x00" * 1000)
receiver._last_packet_time[100] = time.monotonic()
receiver.stop()
assert receiver._running is False
assert len(receiver._buffers) == 0
assert len(receiver._ssrc_to_user) == 0
assert len(receiver._last_packet_time) == 0
def test_map_ssrc(self):
receiver = self._make_receiver()
receiver.map_ssrc(100, 42)
assert receiver._ssrc_to_user[100] == 42
def test_map_ssrc_overwrites(self):
receiver = self._make_receiver()
receiver.map_ssrc(100, 42)
receiver.map_ssrc(100, 99)
assert receiver._ssrc_to_user[100] == 99
def test_pause_resume(self):
receiver = self._make_receiver()
assert receiver._paused is False
receiver.pause()
assert receiver._paused is True
receiver.resume()
assert receiver._paused is False
def test_check_silence_empty(self):
receiver = self._make_receiver()
assert receiver.check_silence() == []
def test_check_silence_returns_completed_utterance(self):
receiver = self._make_receiver()
receiver.map_ssrc(100, 42)
# 48kHz, stereo, 16-bit = 192000 bytes/sec
# MIN_SPEECH_DURATION = 0.5s → need 96000 bytes
pcm_data = bytearray(b"\x00" * 96000)
receiver._buffers[100] = pcm_data
# Set last_packet_time far enough in the past to exceed SILENCE_THRESHOLD
receiver._last_packet_time[100] = time.monotonic() - 3.0
completed = receiver.check_silence()
assert len(completed) == 1
user_id, data = completed[0]
assert user_id == 42
assert len(data) == 96000
# Buffer should be cleared after extraction
assert len(receiver._buffers[100]) == 0
def test_check_silence_ignores_short_buffer(self):
receiver = self._make_receiver()
receiver.map_ssrc(100, 42)
# Too short to meet MIN_SPEECH_DURATION
receiver._buffers[100] = bytearray(b"\x00" * 100)
receiver._last_packet_time[100] = time.monotonic() - 3.0
completed = receiver.check_silence()
assert len(completed) == 0
def test_check_silence_ignores_recent_audio(self):
receiver = self._make_receiver()
receiver.map_ssrc(100, 42)
receiver._buffers[100] = bytearray(b"\x00" * 96000)
receiver._last_packet_time[100] = time.monotonic() # just now
completed = receiver.check_silence()
assert len(completed) == 0
def test_check_silence_unknown_user_discarded(self):
receiver = self._make_receiver()
# No SSRC mapping — user_id will be 0
receiver._buffers[100] = bytearray(b"\x00" * 96000)
receiver._last_packet_time[100] = time.monotonic() - 3.0
completed = receiver.check_silence()
assert len(completed) == 0
def test_stale_buffer_discarded(self):
receiver = self._make_receiver()
# Buffer with no user mapping and very old timestamp
receiver._buffers[200] = bytearray(b"\x00" * 100)
receiver._last_packet_time[200] = time.monotonic() - 10.0
receiver.check_silence()
# Stale buffer (> 2x threshold) should be discarded
assert 200 not in receiver._buffers
def test_on_packet_skips_when_not_running(self):
receiver = self._make_receiver()
# Not started — _running is False
receiver._on_packet(b"\x00" * 100)
assert len(receiver._buffers) == 0
def test_on_packet_skips_when_paused(self):
receiver = self._make_receiver()
receiver.start()
receiver.pause()
receiver._on_packet(b"\x00" * 100)
# Paused — should not process
assert len(receiver._buffers) == 0
def test_on_packet_skips_short_data(self):
receiver = self._make_receiver()
receiver.start()
receiver._on_packet(b"\x00" * 10)
assert len(receiver._buffers) == 0
def test_on_packet_skips_non_rtp(self):
receiver = self._make_receiver()
receiver.start()
# Valid length but wrong RTP version
data = bytearray(b"\x00" * 20)
data[0] = 0x00 # version 0, not 2
receiver._on_packet(bytes(data))
assert len(receiver._buffers) == 0
# =====================================================================
# Gateway voice channel commands (join / leave / input)
# =====================================================================
class TestVoiceChannelCommands:
"""Test _handle_voice_channel_join, _handle_voice_channel_leave,
_handle_voice_channel_input on the GatewayRunner."""
@pytest.fixture
def runner(self, tmp_path):
return _make_runner(tmp_path)
def _make_discord_event(self, text="/voice channel", chat_id="123",
guild_id=111, user_id="user1"):
"""Create event with raw_message carrying guild info."""
source = SessionSource(
chat_id=chat_id,
user_id=user_id,
platform=MagicMock(),
)
source.platform.value = "discord"
source.thread_id = None
event = MessageEvent(text=text, message_type=MessageType.TEXT, source=source)
event.message_id = "msg42"
event.raw_message = SimpleNamespace(guild_id=guild_id, guild=None)
return event
# -- _handle_voice_channel_join --
@pytest.mark.asyncio
async def test_join_unsupported_platform(self, runner):
"""Platform without join_voice_channel returns unsupported message."""
mock_adapter = AsyncMock(spec=[]) # no join_voice_channel
event = self._make_discord_event()
runner.adapters[event.source.platform] = mock_adapter
result = await runner._handle_voice_channel_join(event)
assert "not supported" in result.lower()
@pytest.mark.asyncio
async def test_join_no_guild_id(self, runner):
"""DM context (no guild_id) returns error."""
mock_adapter = AsyncMock()
mock_adapter.join_voice_channel = AsyncMock()
event = self._make_discord_event()
event.raw_message = None # no guild info
runner.adapters[event.source.platform] = mock_adapter
result = await runner._handle_voice_channel_join(event)
assert "discord server" in result.lower()
@pytest.mark.asyncio
async def test_join_user_not_in_vc(self, runner):
"""User not in any voice channel."""
mock_adapter = AsyncMock()
mock_adapter.join_voice_channel = AsyncMock()
mock_adapter.get_user_voice_channel = AsyncMock(return_value=None)
event = self._make_discord_event()
runner.adapters[event.source.platform] = mock_adapter
result = await runner._handle_voice_channel_join(event)
assert "need to be in a voice channel" in result.lower()
@pytest.mark.asyncio
async def test_join_success(self, runner):
"""Successful join sets voice_mode and returns confirmation."""
mock_channel = MagicMock()
mock_channel.name = "General"
mock_adapter = AsyncMock()
mock_adapter.join_voice_channel = AsyncMock(return_value=True)
mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel)
mock_adapter._voice_text_channels = {}
mock_adapter._voice_input_callback = None
event = self._make_discord_event()
runner.adapters[event.source.platform] = mock_adapter
result = await runner._handle_voice_channel_join(event)
assert "joined" in result.lower()
assert "General" in result
assert runner._voice_mode["123"] == "all"
@pytest.mark.asyncio
async def test_join_failure(self, runner):
"""Failed join returns permissions error."""
mock_channel = MagicMock()
mock_channel.name = "General"
mock_adapter = AsyncMock()
mock_adapter.join_voice_channel = AsyncMock(return_value=False)
mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel)
event = self._make_discord_event()
runner.adapters[event.source.platform] = mock_adapter
result = await runner._handle_voice_channel_join(event)
assert "failed" in result.lower()
@pytest.mark.asyncio
async def test_join_exception(self, runner):
"""Exception during join is caught and reported."""
mock_channel = MagicMock()
mock_channel.name = "General"
mock_adapter = AsyncMock()
mock_adapter.join_voice_channel = AsyncMock(side_effect=RuntimeError("No permission"))
mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel)
event = self._make_discord_event()
runner.adapters[event.source.platform] = mock_adapter
result = await runner._handle_voice_channel_join(event)
assert "failed" in result.lower()
# -- _handle_voice_channel_leave --
@pytest.mark.asyncio
async def test_leave_not_in_vc(self, runner):
"""Leave when not in VC returns appropriate message."""
mock_adapter = AsyncMock()
mock_adapter.is_in_voice_channel = MagicMock(return_value=False)
event = self._make_discord_event("/voice leave")
runner.adapters[event.source.platform] = mock_adapter
result = await runner._handle_voice_channel_leave(event)
assert "not in" in result.lower()
@pytest.mark.asyncio
async def test_leave_no_guild(self, runner):
"""Leave from DM returns not in voice channel."""
mock_adapter = AsyncMock()
event = self._make_discord_event("/voice leave")
event.raw_message = None
runner.adapters[event.source.platform] = mock_adapter
result = await runner._handle_voice_channel_leave(event)
assert "not in" in result.lower()
@pytest.mark.asyncio
async def test_leave_success(self, runner):
"""Successful leave disconnects and clears voice mode."""
mock_adapter = AsyncMock()
mock_adapter.is_in_voice_channel = MagicMock(return_value=True)
mock_adapter.leave_voice_channel = AsyncMock()
event = self._make_discord_event("/voice leave")
runner.adapters[event.source.platform] = mock_adapter
runner._voice_mode["123"] = "all"
result = await runner._handle_voice_channel_leave(event)
assert "left" in result.lower()
assert "123" not in runner._voice_mode
mock_adapter.leave_voice_channel.assert_called_once_with(111)
# -- _handle_voice_channel_input --
@pytest.mark.asyncio
async def test_input_no_adapter(self, runner):
"""No Discord adapter — early return, no crash."""
from gateway.config import Platform
# No adapters set
await runner._handle_voice_channel_input(111, 42, "Hello")
@pytest.mark.asyncio
async def test_input_no_text_channel(self, runner):
"""No text channel mapped for guild — early return."""
from gateway.config import Platform
mock_adapter = AsyncMock()
mock_adapter._voice_text_channels = {}
mock_adapter._client = MagicMock()
runner.adapters[Platform.DISCORD] = mock_adapter
await runner._handle_voice_channel_input(111, 42, "Hello")
@pytest.mark.asyncio
async def test_input_creates_event_and_dispatches(self, runner):
"""Voice input creates synthetic event and calls handle_message."""
from gateway.config import Platform
mock_adapter = AsyncMock()
mock_adapter._voice_text_channels = {111: 123}
mock_channel = AsyncMock()
mock_adapter._client = MagicMock()
mock_adapter._client.get_channel = MagicMock(return_value=mock_channel)
mock_adapter.handle_message = AsyncMock()
runner.adapters[Platform.DISCORD] = mock_adapter
await runner._handle_voice_channel_input(111, 42, "Hello from VC")
mock_adapter.handle_message.assert_called_once()
event = mock_adapter.handle_message.call_args[0][0]
assert event.text == "Hello from VC"
assert event.message_type == MessageType.VOICE
assert event.source.chat_id == "123"
@pytest.mark.asyncio
async def test_input_posts_transcript_in_text_channel(self, runner):
"""Voice input sends transcript message to text channel."""
from gateway.config import Platform
mock_adapter = AsyncMock()
mock_adapter._voice_text_channels = {111: 123}
mock_channel = AsyncMock()
mock_adapter._client = MagicMock()
mock_adapter._client.get_channel = MagicMock(return_value=mock_channel)
mock_adapter.handle_message = AsyncMock()
runner.adapters[Platform.DISCORD] = mock_adapter
await runner._handle_voice_channel_input(111, 42, "Test transcript")
mock_channel.send.assert_called_once()
msg = mock_channel.send.call_args[0][0]
assert "Test transcript" in msg
assert "42" in msg # user_id in mention
# -- _get_guild_id --
def test_get_guild_id_from_guild(self, runner):
event = _make_event()
mock_guild = MagicMock()
mock_guild.id = 555
event.raw_message = SimpleNamespace(guild_id=None, guild=mock_guild)
result = runner._get_guild_id(event)
assert result == 555
def test_get_guild_id_from_interaction(self, runner):
event = _make_event()
event.raw_message = SimpleNamespace(guild_id=777, guild=None)
result = runner._get_guild_id(event)
assert result == 777
def test_get_guild_id_none(self, runner):
event = _make_event()
event.raw_message = None
result = runner._get_guild_id(event)
assert result is None
def test_get_guild_id_dm(self, runner):
event = _make_event()
event.raw_message = SimpleNamespace(guild_id=None, guild=None)
result = runner._get_guild_id(event)
assert result is None
# =====================================================================
# Discord adapter voice channel methods
# =====================================================================
class TestDiscordVoiceChannelMethods:
"""Test DiscordAdapter voice channel methods (join, leave, play, etc.)."""
def _make_adapter(self):
from gateway.platforms.discord import DiscordAdapter
from gateway.config import Platform, PlatformConfig
config = PlatformConfig(enabled=True, extra={})
config.token = "fake-token"
adapter = object.__new__(DiscordAdapter)
adapter.platform = Platform.DISCORD
adapter.config = config
adapter._client = MagicMock()
adapter._voice_clients = {}
adapter._voice_text_channels = {}
adapter._voice_timeout_tasks = {}
adapter._voice_receivers = {}
adapter._voice_listen_tasks = {}
adapter._voice_input_callback = None
adapter._allowed_user_ids = set()
adapter._running = True
return adapter
def test_is_in_voice_channel_true(self):
adapter = self._make_adapter()
mock_vc = MagicMock()
mock_vc.is_connected.return_value = True
adapter._voice_clients[111] = mock_vc
assert adapter.is_in_voice_channel(111) is True
def test_is_in_voice_channel_false_no_client(self):
adapter = self._make_adapter()
assert adapter.is_in_voice_channel(111) is False
def test_is_in_voice_channel_false_disconnected(self):
adapter = self._make_adapter()
mock_vc = MagicMock()
mock_vc.is_connected.return_value = False
adapter._voice_clients[111] = mock_vc
assert adapter.is_in_voice_channel(111) is False
@pytest.mark.asyncio
async def test_leave_voice_channel_cleans_up(self):
adapter = self._make_adapter()
mock_vc = MagicMock()
mock_vc.is_connected.return_value = True
mock_vc.disconnect = AsyncMock()
adapter._voice_clients[111] = mock_vc
adapter._voice_text_channels[111] = 123
mock_receiver = MagicMock()
adapter._voice_receivers[111] = mock_receiver
mock_task = MagicMock()
adapter._voice_listen_tasks[111] = mock_task
mock_timeout = MagicMock()
adapter._voice_timeout_tasks[111] = mock_timeout
await adapter.leave_voice_channel(111)
mock_receiver.stop.assert_called_once()
mock_task.cancel.assert_called_once()
mock_vc.disconnect.assert_called_once()
mock_timeout.cancel.assert_called_once()
assert 111 not in adapter._voice_clients
assert 111 not in adapter._voice_text_channels
assert 111 not in adapter._voice_receivers
@pytest.mark.asyncio
async def test_leave_voice_channel_no_connection(self):
"""Leave when not connected — no crash."""
adapter = self._make_adapter()
await adapter.leave_voice_channel(111) # should not raise
@pytest.mark.asyncio
async def test_get_user_voice_channel_no_client(self):
adapter = self._make_adapter()
adapter._client = None
result = await adapter.get_user_voice_channel(111, "42")
assert result is None
@pytest.mark.asyncio
async def test_get_user_voice_channel_no_guild(self):
adapter = self._make_adapter()
adapter._client.get_guild = MagicMock(return_value=None)
result = await adapter.get_user_voice_channel(111, "42")
assert result is None
@pytest.mark.asyncio
async def test_get_user_voice_channel_user_not_in_vc(self):
adapter = self._make_adapter()
mock_guild = MagicMock()
mock_member = MagicMock()
mock_member.voice = None
mock_guild.get_member = MagicMock(return_value=mock_member)
adapter._client.get_guild = MagicMock(return_value=mock_guild)
result = await adapter.get_user_voice_channel(111, "42")
assert result is None
@pytest.mark.asyncio
async def test_get_user_voice_channel_success(self):
adapter = self._make_adapter()
mock_vc = MagicMock()
mock_guild = MagicMock()
mock_member = MagicMock()
mock_member.voice = MagicMock()
mock_member.voice.channel = mock_vc
mock_guild.get_member = MagicMock(return_value=mock_member)
adapter._client.get_guild = MagicMock(return_value=mock_guild)
result = await adapter.get_user_voice_channel(111, "42")
assert result is mock_vc
@pytest.mark.asyncio
async def test_play_in_voice_channel_not_connected(self):
adapter = self._make_adapter()
result = await adapter.play_in_voice_channel(111, "/tmp/test.ogg")
assert result is False
def test_is_allowed_user_empty_list(self):
adapter = self._make_adapter()
assert adapter._is_allowed_user("42") is True
def test_is_allowed_user_in_list(self):
adapter = self._make_adapter()
adapter._allowed_user_ids = {"42", "99"}
assert adapter._is_allowed_user("42") is True
def test_is_allowed_user_not_in_list(self):
adapter = self._make_adapter()
adapter._allowed_user_ids = {"99"}
assert adapter._is_allowed_user("42") is False
@pytest.mark.asyncio
async def test_process_voice_input_success(self):
"""Successful voice input: PCM->WAV->STT->callback."""
adapter = self._make_adapter()
callback = AsyncMock()
adapter._voice_input_callback = callback
adapter._allowed_user_ids = set()
pcm_data = b"\x00" * 96000
with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
patch("tools.transcription_tools.transcribe_audio",
return_value={"success": True, "transcript": "Hello"}), \
patch("tools.voice_mode.is_whisper_hallucination", return_value=False):
await adapter._process_voice_input(111, 42, pcm_data)
callback.assert_called_once_with(guild_id=111, user_id=42, transcript="Hello")
@pytest.mark.asyncio
async def test_process_voice_input_hallucination_filtered(self):
"""Whisper hallucination is filtered out."""
adapter = self._make_adapter()
callback = AsyncMock()
adapter._voice_input_callback = callback
with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
patch("tools.transcription_tools.transcribe_audio",
return_value={"success": True, "transcript": "Thank you."}), \
patch("tools.voice_mode.is_whisper_hallucination", return_value=True):
await adapter._process_voice_input(111, 42, b"\x00" * 96000)
callback.assert_not_called()
@pytest.mark.asyncio
async def test_process_voice_input_stt_failure(self):
"""STT failure — callback not called."""
adapter = self._make_adapter()
callback = AsyncMock()
adapter._voice_input_callback = callback
with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
patch("tools.transcription_tools.transcribe_audio",
return_value={"success": False, "error": "API error"}):
await adapter._process_voice_input(111, 42, b"\x00" * 96000)
callback.assert_not_called()
@pytest.mark.asyncio
async def test_process_voice_input_exception_caught(self):
"""Exception during processing is caught, no crash."""
adapter = self._make_adapter()
adapter._voice_input_callback = AsyncMock()
with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav",
side_effect=RuntimeError("ffmpeg not found")):
await adapter._process_voice_input(111, 42, b"\x00" * 96000)
# Should not raise
# =====================================================================
# stream_tts_to_speaker functional tests
# =====================================================================
class TestStreamTtsToSpeaker:
"""Functional tests for the streaming TTS pipeline."""
def test_none_sentinel_flushes_buffer(self):
"""None sentinel causes remaining buffer to be spoken."""
from tools.tts_tool import stream_tts_to_speaker
text_q = queue.Queue()
stop_evt = threading.Event()
done_evt = threading.Event()
spoken = []
def display(text):
spoken.append(text)
text_q.put("Hello world.")
text_q.put(None)
stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=display)
assert done_evt.is_set()
assert any("Hello" in s for s in spoken)
def test_stop_event_aborts_early(self):
"""Setting stop_event causes early exit."""
from tools.tts_tool import stream_tts_to_speaker
text_q = queue.Queue()
stop_evt = threading.Event()
done_evt = threading.Event()
spoken = []
stop_evt.set()
text_q.put("Should not be spoken.")
text_q.put(None)
stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t))
assert done_evt.is_set()
assert len(spoken) == 0
def test_done_event_set_on_exception(self):
"""tts_done_event is set even when an exception occurs."""
from tools.tts_tool import stream_tts_to_speaker
text_q = queue.Queue()
stop_evt = threading.Event()
done_evt = threading.Event()
# Put a non-string that will cause concatenation to fail
text_q.put(12345)
text_q.put(None)
stream_tts_to_speaker(text_q, stop_evt, done_evt)
assert done_evt.is_set()
def test_think_blocks_stripped(self):
"""<think>...</think> content is not spoken."""
from tools.tts_tool import stream_tts_to_speaker
text_q = queue.Queue()
stop_evt = threading.Event()
done_evt = threading.Event()
spoken = []
text_q.put("<think>internal reasoning</think>")
text_q.put("Visible response. ")
text_q.put(None)
stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t))
assert done_evt.is_set()
joined = " ".join(spoken)
assert "internal reasoning" not in joined
assert "Visible" in joined
def test_sentence_splitting(self):
"""Sentences are split at boundaries and spoken individually."""
from tools.tts_tool import stream_tts_to_speaker
text_q = queue.Queue()
stop_evt = threading.Event()
done_evt = threading.Event()
spoken = []
# Two sentences long enough to exceed min_sentence_len (20)
text_q.put("This is the first sentence. ")
text_q.put("This is the second sentence. ")
text_q.put(None)
stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t))
assert done_evt.is_set()
assert len(spoken) >= 2
def test_markdown_stripped_in_speech(self):
"""Markdown formatting is removed before display/speech."""
from tools.tts_tool import stream_tts_to_speaker
text_q = queue.Queue()
stop_evt = threading.Event()
done_evt = threading.Event()
spoken = []
text_q.put("**Bold text** and `code`. ")
text_q.put(None)
stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t))
assert done_evt.is_set()
# Display callback gets raw text (before markdown stripping)
# But the actual TTS audio would be stripped — we verify pipeline doesn't crash
def test_duplicate_sentences_deduped(self):
"""Repeated sentences are spoken only once."""
from tools.tts_tool import stream_tts_to_speaker
text_q = queue.Queue()
stop_evt = threading.Event()
done_evt = threading.Event()
spoken = []
# Same sentence twice, each long enough
text_q.put("This is a repeated sentence. ")
text_q.put("This is a repeated sentence. ")
text_q.put(None)
stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t))
assert done_evt.is_set()
# First occurrence is spoken, second is deduped
assert len(spoken) == 1
def test_no_api_key_display_only(self):
"""Without ELEVENLABS_API_KEY, display callback still works."""
from tools.tts_tool import stream_tts_to_speaker
text_q = queue.Queue()
stop_evt = threading.Event()
done_evt = threading.Event()
spoken = []
text_q.put("Display only text. ")
text_q.put(None)
with patch.dict(os.environ, {"ELEVENLABS_API_KEY": ""}):
stream_tts_to_speaker(text_q, stop_evt, done_evt,
display_callback=lambda t: spoken.append(t))
assert done_evt.is_set()
assert len(spoken) >= 1
def test_long_buffer_flushed_on_timeout(self):
"""Buffer longer than long_flush_len is flushed on queue timeout."""
from tools.tts_tool import stream_tts_to_speaker
text_q = queue.Queue()
stop_evt = threading.Event()
done_evt = threading.Event()
spoken = []
# Put a long text without sentence boundary, then None after a delay
long_text = "a" * 150 # > long_flush_len (100)
text_q.put(long_text)
def delayed_sentinel():
time.sleep(1.0)
text_q.put(None)
t = threading.Thread(target=delayed_sentinel, daemon=True)
t.start()
stream_tts_to_speaker(text_q, stop_evt, done_evt,
display_callback=lambda t: spoken.append(t))
t.join(timeout=5)
assert done_evt.is_set()
assert len(spoken) >= 1