test(voice): add comprehensive flow tests for voice channel fixes
Tests cover the actual code paths changed in voice fixes: _on_packet DAVE passthrough (8 tests): - Known SSRC + DAVE decrypt success → buffered - Unknown SSRC + DAVE → skip DAVE, passthrough to Opus - DAVE "Unencrypted" error → passthrough, not dropped - DAVE other error → packet dropped - No DAVE session → direct decode - Bot's own SSRC → ignored (echo prevention) - Multiple SSRCs → separate buffers SSRC auto-mapping (6 tests): - Single allowed user → auto-mapped - Multiple allowed users → no auto-map - No allowlist → sole non-bot member inferred - Unallowed user → rejected - Only bot in channel → no map - Auto-map persists across checks Buffer lifecycle (4 tests): - Known SSRC completed utterance - Short buffer ignored - Recent audio waits - Stale unknown buffer discarded TTS playback (10 tests): - play_tts calls play_in_voice_channel in VC - play_tts falls through when not in VC - play_tts wrong channel no match - Voice input dedup (runner skips) - Text + voice_mode combinations - Error/empty response skipped - Agent TTS tool dedup UDP keepalive (2 tests): - Interval within bounds - Silence frame actually sent via send_packet
This commit is contained in:
parent
1cacaccca6
commit
63f0ec96ec
1 changed files with 527 additions and 0 deletions
|
|
@ -2037,3 +2037,530 @@ class TestDisconnectVoiceCleanup:
|
||||||
assert len(adapter._voice_receivers) == 0
|
assert len(adapter._voice_receivers) == 0
|
||||||
assert len(adapter._voice_listen_tasks) == 0
|
assert len(adapter._voice_listen_tasks) == 0
|
||||||
assert len(adapter._voice_timeout_tasks) == 0
|
assert len(adapter._voice_timeout_tasks) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# =====================================================================
|
||||||
|
# Discord Voice Channel Flow Tests
|
||||||
|
# =====================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestVoiceReception:
|
||||||
|
"""Audio reception: SSRC mapping, DAVE passthrough, buffer lifecycle."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_receiver(allowed_ids=None, members=None, dave=False, bot_id=9999):
|
||||||
|
from gateway.platforms.discord import VoiceReceiver
|
||||||
|
vc = MagicMock()
|
||||||
|
vc._connection.secret_key = [0] * 32
|
||||||
|
vc._connection.dave_session = MagicMock() if dave else None
|
||||||
|
vc._connection.ssrc = bot_id
|
||||||
|
vc._connection.add_socket_listener = MagicMock()
|
||||||
|
vc._connection.remove_socket_listener = MagicMock()
|
||||||
|
vc._connection.hook = None
|
||||||
|
vc.user = SimpleNamespace(id=bot_id)
|
||||||
|
vc.channel = MagicMock()
|
||||||
|
vc.channel.members = members or []
|
||||||
|
receiver = VoiceReceiver(vc, allowed_user_ids=allowed_ids)
|
||||||
|
return receiver
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fill_buffer(receiver, ssrc, duration_s=1.0, age_s=3.0):
|
||||||
|
"""Add PCM data to buffer. 48kHz stereo 16-bit = 192000 bytes/sec."""
|
||||||
|
size = int(192000 * duration_s)
|
||||||
|
receiver._buffers[ssrc] = bytearray(b"\x00" * size)
|
||||||
|
receiver._last_packet_time[ssrc] = time.monotonic() - age_s
|
||||||
|
|
||||||
|
# -- Known SSRC (normal flow) --
|
||||||
|
|
||||||
|
def test_known_ssrc_returns_completed(self):
|
||||||
|
receiver = self._make_receiver()
|
||||||
|
receiver.start()
|
||||||
|
receiver.map_ssrc(100, 42)
|
||||||
|
self._fill_buffer(receiver, 100)
|
||||||
|
completed = receiver.check_silence()
|
||||||
|
assert len(completed) == 1
|
||||||
|
assert completed[0][0] == 42
|
||||||
|
assert len(receiver._buffers[100]) == 0 # cleared
|
||||||
|
|
||||||
|
def test_known_ssrc_short_buffer_ignored(self):
|
||||||
|
receiver = self._make_receiver()
|
||||||
|
receiver.start()
|
||||||
|
receiver.map_ssrc(100, 42)
|
||||||
|
self._fill_buffer(receiver, 100, duration_s=0.1) # too short
|
||||||
|
completed = receiver.check_silence()
|
||||||
|
assert len(completed) == 0
|
||||||
|
|
||||||
|
def test_known_ssrc_recent_audio_waits(self):
|
||||||
|
receiver = self._make_receiver()
|
||||||
|
receiver.start()
|
||||||
|
receiver.map_ssrc(100, 42)
|
||||||
|
self._fill_buffer(receiver, 100, age_s=0.0) # just arrived
|
||||||
|
completed = receiver.check_silence()
|
||||||
|
assert len(completed) == 0
|
||||||
|
|
||||||
|
# -- Unknown SSRC + DAVE passthrough --
|
||||||
|
|
||||||
|
def test_unknown_ssrc_no_automap_no_completed(self):
|
||||||
|
"""Unknown SSRC, no members to infer — buffer cleared, not returned."""
|
||||||
|
receiver = self._make_receiver(dave=True, members=[])
|
||||||
|
receiver.start()
|
||||||
|
self._fill_buffer(receiver, 100)
|
||||||
|
completed = receiver.check_silence()
|
||||||
|
assert len(completed) == 0
|
||||||
|
assert len(receiver._buffers[100]) == 0
|
||||||
|
|
||||||
|
def test_unknown_ssrc_late_speaking_event(self):
|
||||||
|
"""Audio buffered before SPEAKING → SPEAKING maps → next check returns it."""
|
||||||
|
receiver = self._make_receiver(dave=True)
|
||||||
|
receiver.start()
|
||||||
|
self._fill_buffer(receiver, 100, age_s=0.0) # still receiving
|
||||||
|
# No user yet
|
||||||
|
assert receiver.check_silence() == []
|
||||||
|
# SPEAKING event arrives
|
||||||
|
receiver.map_ssrc(100, 42)
|
||||||
|
# Silence kicks in
|
||||||
|
receiver._last_packet_time[100] = time.monotonic() - 3.0
|
||||||
|
completed = receiver.check_silence()
|
||||||
|
assert len(completed) == 1
|
||||||
|
assert completed[0][0] == 42
|
||||||
|
|
||||||
|
# -- SSRC auto-mapping --
|
||||||
|
|
||||||
|
def test_automap_single_allowed_user(self):
|
||||||
|
members = [
|
||||||
|
SimpleNamespace(id=9999, name="Bot"),
|
||||||
|
SimpleNamespace(id=42, name="Alice"),
|
||||||
|
]
|
||||||
|
receiver = self._make_receiver(allowed_ids={"42"}, members=members)
|
||||||
|
receiver.start()
|
||||||
|
self._fill_buffer(receiver, 100)
|
||||||
|
completed = receiver.check_silence()
|
||||||
|
assert len(completed) == 1
|
||||||
|
assert completed[0][0] == 42
|
||||||
|
assert receiver._ssrc_to_user[100] == 42
|
||||||
|
|
||||||
|
def test_automap_multiple_allowed_users_no_map(self):
|
||||||
|
members = [
|
||||||
|
SimpleNamespace(id=9999, name="Bot"),
|
||||||
|
SimpleNamespace(id=42, name="Alice"),
|
||||||
|
SimpleNamespace(id=43, name="Bob"),
|
||||||
|
]
|
||||||
|
receiver = self._make_receiver(allowed_ids={"42", "43"}, members=members)
|
||||||
|
receiver.start()
|
||||||
|
self._fill_buffer(receiver, 100)
|
||||||
|
completed = receiver.check_silence()
|
||||||
|
assert len(completed) == 0
|
||||||
|
|
||||||
|
def test_automap_no_allowlist_single_member(self):
|
||||||
|
"""No allowed_user_ids → sole non-bot member inferred."""
|
||||||
|
members = [
|
||||||
|
SimpleNamespace(id=9999, name="Bot"),
|
||||||
|
SimpleNamespace(id=42, name="Alice"),
|
||||||
|
]
|
||||||
|
receiver = self._make_receiver(allowed_ids=None, members=members)
|
||||||
|
receiver.start()
|
||||||
|
self._fill_buffer(receiver, 100)
|
||||||
|
completed = receiver.check_silence()
|
||||||
|
assert len(completed) == 1
|
||||||
|
assert completed[0][0] == 42
|
||||||
|
|
||||||
|
def test_automap_unallowed_user_rejected(self):
|
||||||
|
"""User in channel but not in allowed list — not mapped."""
|
||||||
|
members = [
|
||||||
|
SimpleNamespace(id=9999, name="Bot"),
|
||||||
|
SimpleNamespace(id=42, name="Alice"),
|
||||||
|
]
|
||||||
|
receiver = self._make_receiver(allowed_ids={"99"}, members=members)
|
||||||
|
receiver.start()
|
||||||
|
self._fill_buffer(receiver, 100)
|
||||||
|
completed = receiver.check_silence()
|
||||||
|
assert len(completed) == 0
|
||||||
|
|
||||||
|
def test_automap_only_bot_in_channel(self):
|
||||||
|
"""Only bot in channel — no one to map to."""
|
||||||
|
members = [SimpleNamespace(id=9999, name="Bot")]
|
||||||
|
receiver = self._make_receiver(allowed_ids=None, members=members)
|
||||||
|
receiver.start()
|
||||||
|
self._fill_buffer(receiver, 100)
|
||||||
|
completed = receiver.check_silence()
|
||||||
|
assert len(completed) == 0
|
||||||
|
|
||||||
|
def test_automap_persists_across_calls(self):
|
||||||
|
"""Auto-mapped SSRC stays mapped for subsequent checks."""
|
||||||
|
members = [
|
||||||
|
SimpleNamespace(id=9999, name="Bot"),
|
||||||
|
SimpleNamespace(id=42, name="Alice"),
|
||||||
|
]
|
||||||
|
receiver = self._make_receiver(allowed_ids={"42"}, members=members)
|
||||||
|
receiver.start()
|
||||||
|
self._fill_buffer(receiver, 100)
|
||||||
|
receiver.check_silence()
|
||||||
|
assert receiver._ssrc_to_user[100] == 42
|
||||||
|
# Second utterance — should use cached mapping
|
||||||
|
self._fill_buffer(receiver, 100)
|
||||||
|
completed = receiver.check_silence()
|
||||||
|
assert len(completed) == 1
|
||||||
|
assert completed[0][0] == 42
|
||||||
|
|
||||||
|
# -- Stale buffer cleanup --
|
||||||
|
|
||||||
|
def test_stale_unknown_buffer_discarded(self):
|
||||||
|
"""Buffer with no user and very old timestamp is discarded."""
|
||||||
|
receiver = self._make_receiver()
|
||||||
|
receiver.start()
|
||||||
|
receiver._buffers[200] = bytearray(b"\x00" * 100)
|
||||||
|
receiver._last_packet_time[200] = time.monotonic() - 10.0
|
||||||
|
receiver.check_silence()
|
||||||
|
assert 200 not in receiver._buffers
|
||||||
|
|
||||||
|
# -- Pause / resume (echo prevention) --
|
||||||
|
|
||||||
|
def test_paused_receiver_ignores_packets(self):
|
||||||
|
receiver = self._make_receiver()
|
||||||
|
receiver.start()
|
||||||
|
receiver.pause()
|
||||||
|
receiver._on_packet(b"\x00" * 100)
|
||||||
|
assert len(receiver._buffers) == 0
|
||||||
|
|
||||||
|
def test_resumed_receiver_accepts_packets(self):
|
||||||
|
receiver = self._make_receiver()
|
||||||
|
receiver.start()
|
||||||
|
receiver.pause()
|
||||||
|
receiver.resume()
|
||||||
|
assert receiver._paused is False
|
||||||
|
|
||||||
|
# -- _on_packet DAVE passthrough behavior --
|
||||||
|
|
||||||
|
def _make_receiver_with_nacl(self, dave_session=None, mapped_ssrcs=None):
|
||||||
|
"""Create a receiver that can process _on_packet with mocked NaCl + Opus."""
|
||||||
|
from gateway.platforms.discord import VoiceReceiver
|
||||||
|
vc = MagicMock()
|
||||||
|
vc._connection.secret_key = [0] * 32
|
||||||
|
vc._connection.dave_session = dave_session
|
||||||
|
vc._connection.ssrc = 9999
|
||||||
|
vc._connection.add_socket_listener = MagicMock()
|
||||||
|
vc._connection.remove_socket_listener = MagicMock()
|
||||||
|
vc._connection.hook = None
|
||||||
|
vc.user = SimpleNamespace(id=9999)
|
||||||
|
vc.channel = MagicMock()
|
||||||
|
vc.channel.members = []
|
||||||
|
receiver = VoiceReceiver(vc)
|
||||||
|
receiver.start()
|
||||||
|
# Pre-map SSRCs if provided
|
||||||
|
if mapped_ssrcs:
|
||||||
|
for ssrc, uid in mapped_ssrcs.items():
|
||||||
|
receiver.map_ssrc(ssrc, uid)
|
||||||
|
return receiver
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_rtp_packet(ssrc=100, seq=1, timestamp=960):
|
||||||
|
"""Build a minimal valid RTP packet for _on_packet.
|
||||||
|
|
||||||
|
We need: RTP header (12 bytes) + encrypted payload + 4-byte nonce.
|
||||||
|
NaCl decrypt is mocked so payload content doesn't matter.
|
||||||
|
"""
|
||||||
|
import struct
|
||||||
|
# RTP header: version=2, payload_type=0x78, no extension, no CSRC
|
||||||
|
header = struct.pack(">BBHII", 0x80, 0x78, seq, timestamp, ssrc)
|
||||||
|
# Fake encrypted payload (NaCl will be mocked) + 4 byte nonce
|
||||||
|
payload = b"\x00" * 20 + b"\x00\x00\x00\x01"
|
||||||
|
return header + payload
|
||||||
|
|
||||||
|
def _inject_mock_decoder(self, receiver, ssrc):
|
||||||
|
"""Pre-inject a mock Opus decoder for the given SSRC."""
|
||||||
|
mock_decoder = MagicMock()
|
||||||
|
mock_decoder.decode.return_value = b"\x00" * 3840
|
||||||
|
receiver._decoders[ssrc] = mock_decoder
|
||||||
|
return mock_decoder
|
||||||
|
|
||||||
|
def test_on_packet_dave_known_user_decrypt_ok(self):
|
||||||
|
"""Known SSRC + DAVE decrypt success → audio buffered."""
|
||||||
|
dave = MagicMock()
|
||||||
|
dave.decrypt.return_value = b"\xf8\xff\xfe"
|
||||||
|
receiver = self._make_receiver_with_nacl(
|
||||||
|
dave_session=dave, mapped_ssrcs={100: 42}
|
||||||
|
)
|
||||||
|
self._inject_mock_decoder(receiver, 100)
|
||||||
|
|
||||||
|
with patch("nacl.secret.Aead") as mock_aead:
|
||||||
|
mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
|
||||||
|
receiver._on_packet(self._build_rtp_packet(ssrc=100))
|
||||||
|
|
||||||
|
assert 100 in receiver._buffers
|
||||||
|
assert len(receiver._buffers[100]) > 0
|
||||||
|
dave.decrypt.assert_called_once()
|
||||||
|
|
||||||
|
def test_on_packet_dave_unknown_ssrc_passthrough(self):
|
||||||
|
"""Unknown SSRC + DAVE → skip DAVE, attempt Opus decode (passthrough)."""
|
||||||
|
dave = MagicMock()
|
||||||
|
receiver = self._make_receiver_with_nacl(dave_session=dave)
|
||||||
|
self._inject_mock_decoder(receiver, 100)
|
||||||
|
|
||||||
|
with patch("nacl.secret.Aead") as mock_aead:
|
||||||
|
mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
|
||||||
|
receiver._on_packet(self._build_rtp_packet(ssrc=100))
|
||||||
|
|
||||||
|
dave.decrypt.assert_not_called()
|
||||||
|
assert 100 in receiver._buffers
|
||||||
|
assert len(receiver._buffers[100]) > 0
|
||||||
|
|
||||||
|
def test_on_packet_dave_unencrypted_error_passthrough(self):
|
||||||
|
"""DAVE decrypt 'Unencrypted' error → use data as-is, don't drop."""
|
||||||
|
dave = MagicMock()
|
||||||
|
dave.decrypt.side_effect = Exception(
|
||||||
|
"Failed to decrypt: DecryptionFailed(UnencryptedWhenPassthroughDisabled)"
|
||||||
|
)
|
||||||
|
receiver = self._make_receiver_with_nacl(
|
||||||
|
dave_session=dave, mapped_ssrcs={100: 42}
|
||||||
|
)
|
||||||
|
self._inject_mock_decoder(receiver, 100)
|
||||||
|
|
||||||
|
with patch("nacl.secret.Aead") as mock_aead:
|
||||||
|
mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
|
||||||
|
receiver._on_packet(self._build_rtp_packet(ssrc=100))
|
||||||
|
|
||||||
|
assert 100 in receiver._buffers
|
||||||
|
assert len(receiver._buffers[100]) > 0
|
||||||
|
|
||||||
|
def test_on_packet_dave_other_error_drops(self):
|
||||||
|
"""DAVE decrypt non-Unencrypted error → packet dropped."""
|
||||||
|
dave = MagicMock()
|
||||||
|
dave.decrypt.side_effect = Exception("KeyRotationFailed")
|
||||||
|
receiver = self._make_receiver_with_nacl(
|
||||||
|
dave_session=dave, mapped_ssrcs={100: 42}
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("nacl.secret.Aead") as mock_aead:
|
||||||
|
mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
|
||||||
|
receiver._on_packet(self._build_rtp_packet(ssrc=100))
|
||||||
|
|
||||||
|
assert len(receiver._buffers.get(100, b"")) == 0
|
||||||
|
|
||||||
|
def test_on_packet_no_dave_direct_decode(self):
|
||||||
|
"""No DAVE session → decode directly."""
|
||||||
|
receiver = self._make_receiver_with_nacl(dave_session=None)
|
||||||
|
self._inject_mock_decoder(receiver, 100)
|
||||||
|
|
||||||
|
with patch("nacl.secret.Aead") as mock_aead:
|
||||||
|
mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
|
||||||
|
receiver._on_packet(self._build_rtp_packet(ssrc=100))
|
||||||
|
|
||||||
|
assert 100 in receiver._buffers
|
||||||
|
assert len(receiver._buffers[100]) > 0
|
||||||
|
|
||||||
|
def test_on_packet_bot_own_ssrc_ignored(self):
|
||||||
|
"""Bot's own SSRC → dropped (echo prevention)."""
|
||||||
|
receiver = self._make_receiver_with_nacl()
|
||||||
|
with patch("nacl.secret.Aead"):
|
||||||
|
receiver._on_packet(self._build_rtp_packet(ssrc=9999))
|
||||||
|
assert len(receiver._buffers) == 0
|
||||||
|
|
||||||
|
def test_on_packet_multiple_ssrcs_separate_buffers(self):
|
||||||
|
"""Different SSRCs → separate buffers."""
|
||||||
|
receiver = self._make_receiver_with_nacl(dave_session=None)
|
||||||
|
self._inject_mock_decoder(receiver, 100)
|
||||||
|
self._inject_mock_decoder(receiver, 200)
|
||||||
|
|
||||||
|
with patch("nacl.secret.Aead") as mock_aead:
|
||||||
|
mock_aead.return_value.decrypt.return_value = b"\xf8\xff\xfe"
|
||||||
|
receiver._on_packet(self._build_rtp_packet(ssrc=100))
|
||||||
|
receiver._on_packet(self._build_rtp_packet(ssrc=200))
|
||||||
|
|
||||||
|
assert 100 in receiver._buffers
|
||||||
|
assert 200 in receiver._buffers
|
||||||
|
|
||||||
|
|
||||||
|
class TestVoiceTTSPlayback:
|
||||||
|
"""TTS playback: play_tts in VC, dedup, fallback."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_discord_adapter():
|
||||||
|
from gateway.platforms.discord import DiscordAdapter
|
||||||
|
from gateway.config import PlatformConfig, Platform
|
||||||
|
config = PlatformConfig(enabled=True, extra={})
|
||||||
|
config.token = "fake-token"
|
||||||
|
adapter = object.__new__(DiscordAdapter)
|
||||||
|
adapter.platform = Platform.DISCORD
|
||||||
|
adapter.config = config
|
||||||
|
adapter._voice_clients = {}
|
||||||
|
adapter._voice_text_channels = {}
|
||||||
|
adapter._voice_receivers = {}
|
||||||
|
return adapter
|
||||||
|
|
||||||
|
# -- play_tts behavior --
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_play_tts_plays_in_vc(self):
|
||||||
|
"""play_tts calls play_in_voice_channel when bot is in VC."""
|
||||||
|
adapter = self._make_discord_adapter()
|
||||||
|
mock_vc = MagicMock()
|
||||||
|
mock_vc.is_connected.return_value = True
|
||||||
|
adapter._voice_clients[111] = mock_vc
|
||||||
|
adapter._voice_text_channels[111] = 123
|
||||||
|
|
||||||
|
played = []
|
||||||
|
async def fake_play(gid, path):
|
||||||
|
played.append((gid, path))
|
||||||
|
return True
|
||||||
|
adapter.play_in_voice_channel = fake_play
|
||||||
|
|
||||||
|
result = await adapter.play_tts(chat_id="123", audio_path="/tmp/tts.ogg")
|
||||||
|
assert result.success is True
|
||||||
|
assert played == [(111, "/tmp/tts.ogg")]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_play_tts_fallback_when_not_in_vc(self):
|
||||||
|
"""play_tts sends as file attachment when bot is not in VC."""
|
||||||
|
adapter = self._make_discord_adapter()
|
||||||
|
from gateway.platforms.base import SendResult
|
||||||
|
adapter.send_voice = AsyncMock(return_value=SendResult(success=False, error="no client"))
|
||||||
|
result = await adapter.play_tts(chat_id="123", audio_path="/tmp/tts.ogg")
|
||||||
|
assert result.success is False
|
||||||
|
adapter.send_voice.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_play_tts_wrong_channel_no_match(self):
|
||||||
|
"""play_tts doesn't match if chat_id is for a different channel."""
|
||||||
|
adapter = self._make_discord_adapter()
|
||||||
|
mock_vc = MagicMock()
|
||||||
|
mock_vc.is_connected.return_value = True
|
||||||
|
adapter._voice_clients[111] = mock_vc
|
||||||
|
adapter._voice_text_channels[111] = 123
|
||||||
|
|
||||||
|
from gateway.platforms.base import SendResult
|
||||||
|
adapter.send_voice = AsyncMock(return_value=SendResult(success=True))
|
||||||
|
# Different chat_id — shouldn't match VC
|
||||||
|
result = await adapter.play_tts(chat_id="999", audio_path="/tmp/tts.ogg")
|
||||||
|
adapter.send_voice.assert_called_once()
|
||||||
|
|
||||||
|
# -- Runner dedup --
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_runner():
|
||||||
|
from gateway.run import GatewayRunner
|
||||||
|
runner = object.__new__(GatewayRunner)
|
||||||
|
runner._voice_mode = {}
|
||||||
|
runner.adapters = {}
|
||||||
|
return runner
|
||||||
|
|
||||||
|
def _call_should_reply(self, runner, voice_mode, msg_type, response="Hello", agent_msgs=None):
|
||||||
|
from gateway.platforms.base import MessageType, MessageEvent, SessionSource
|
||||||
|
from gateway.config import Platform
|
||||||
|
runner._voice_mode["ch1"] = voice_mode
|
||||||
|
source = SessionSource(
|
||||||
|
platform=Platform.DISCORD, chat_id="ch1",
|
||||||
|
user_id="1", user_name="test", chat_type="channel",
|
||||||
|
)
|
||||||
|
event = MessageEvent(source=source, text="test", message_type=msg_type)
|
||||||
|
return runner._should_send_voice_reply(event, response, agent_msgs or [])
|
||||||
|
|
||||||
|
def test_voice_input_runner_skips(self):
|
||||||
|
"""Voice input: runner skips — base adapter handles via play_tts."""
|
||||||
|
from gateway.platforms.base import MessageType
|
||||||
|
runner = self._make_runner()
|
||||||
|
assert self._call_should_reply(runner, "all", MessageType.VOICE) is False
|
||||||
|
|
||||||
|
def test_text_input_voice_all_runner_fires(self):
|
||||||
|
"""Text input + voice_mode=all: runner generates TTS."""
|
||||||
|
from gateway.platforms.base import MessageType
|
||||||
|
runner = self._make_runner()
|
||||||
|
assert self._call_should_reply(runner, "all", MessageType.TEXT) is True
|
||||||
|
|
||||||
|
def test_text_input_voice_off_no_tts(self):
|
||||||
|
"""Text input + voice_mode=off: no TTS."""
|
||||||
|
from gateway.platforms.base import MessageType
|
||||||
|
runner = self._make_runner()
|
||||||
|
assert self._call_should_reply(runner, "off", MessageType.TEXT) is False
|
||||||
|
|
||||||
|
def test_text_input_voice_only_no_tts(self):
|
||||||
|
"""Text input + voice_mode=voice_only: no TTS for text."""
|
||||||
|
from gateway.platforms.base import MessageType
|
||||||
|
runner = self._make_runner()
|
||||||
|
assert self._call_should_reply(runner, "voice_only", MessageType.TEXT) is False
|
||||||
|
|
||||||
|
def test_error_response_no_tts(self):
|
||||||
|
"""Error response: no TTS regardless of voice_mode."""
|
||||||
|
from gateway.platforms.base import MessageType
|
||||||
|
runner = self._make_runner()
|
||||||
|
assert self._call_should_reply(runner, "all", MessageType.TEXT, response="Error: boom") is False
|
||||||
|
|
||||||
|
def test_empty_response_no_tts(self):
|
||||||
|
"""Empty response: no TTS."""
|
||||||
|
from gateway.platforms.base import MessageType
|
||||||
|
runner = self._make_runner()
|
||||||
|
assert self._call_should_reply(runner, "all", MessageType.TEXT, response="") is False
|
||||||
|
|
||||||
|
def test_agent_tts_tool_dedup(self):
|
||||||
|
"""Agent already called text_to_speech tool: runner skips."""
|
||||||
|
from gateway.platforms.base import MessageType
|
||||||
|
runner = self._make_runner()
|
||||||
|
agent_msgs = [{"role": "assistant", "tool_calls": [
|
||||||
|
{"id": "1", "type": "function", "function": {"name": "text_to_speech", "arguments": "{}"}}
|
||||||
|
]}]
|
||||||
|
assert self._call_should_reply(runner, "all", MessageType.TEXT, agent_msgs=agent_msgs) is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestUDPKeepalive:
|
||||||
|
"""UDP keepalive prevents Discord from dropping the voice session."""
|
||||||
|
|
||||||
|
def test_keepalive_interval_is_reasonable(self):
|
||||||
|
from gateway.platforms.discord import DiscordAdapter
|
||||||
|
interval = DiscordAdapter._KEEPALIVE_INTERVAL
|
||||||
|
assert 5 <= interval <= 30, f"Keepalive interval {interval}s should be between 5-30s"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_keepalive_sends_silence_frame(self):
|
||||||
|
"""Listen loop sends silence frame via send_packet after interval."""
|
||||||
|
from gateway.platforms.discord import DiscordAdapter
|
||||||
|
from gateway.config import PlatformConfig, Platform
|
||||||
|
|
||||||
|
config = PlatformConfig(enabled=True, extra={})
|
||||||
|
config.token = "fake"
|
||||||
|
adapter = object.__new__(DiscordAdapter)
|
||||||
|
adapter.platform = Platform.DISCORD
|
||||||
|
adapter.config = config
|
||||||
|
adapter._voice_clients = {}
|
||||||
|
adapter._voice_text_channels = {}
|
||||||
|
adapter._voice_receivers = {}
|
||||||
|
adapter._voice_listen_tasks = {}
|
||||||
|
|
||||||
|
# Mock VC and receiver
|
||||||
|
mock_vc = MagicMock()
|
||||||
|
mock_vc.is_connected.return_value = True
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
adapter._voice_clients[111] = mock_vc
|
||||||
|
mock_vc._connection = mock_conn
|
||||||
|
|
||||||
|
from gateway.platforms.discord import VoiceReceiver
|
||||||
|
mock_receiver_vc = MagicMock()
|
||||||
|
mock_receiver_vc._connection.secret_key = [0] * 32
|
||||||
|
mock_receiver_vc._connection.dave_session = None
|
||||||
|
mock_receiver_vc._connection.ssrc = 9999
|
||||||
|
mock_receiver_vc._connection.add_socket_listener = MagicMock()
|
||||||
|
mock_receiver_vc._connection.remove_socket_listener = MagicMock()
|
||||||
|
mock_receiver_vc._connection.hook = None
|
||||||
|
receiver = VoiceReceiver(mock_receiver_vc)
|
||||||
|
receiver.start()
|
||||||
|
adapter._voice_receivers[111] = receiver
|
||||||
|
|
||||||
|
# Set keepalive interval very short for test
|
||||||
|
original_interval = DiscordAdapter._KEEPALIVE_INTERVAL
|
||||||
|
DiscordAdapter._KEEPALIVE_INTERVAL = 0.1
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Run listen loop briefly
|
||||||
|
import asyncio
|
||||||
|
loop_task = asyncio.create_task(adapter._voice_listen_loop(111))
|
||||||
|
await asyncio.sleep(0.3)
|
||||||
|
receiver._running = False # stop loop
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
loop_task.cancel()
|
||||||
|
try:
|
||||||
|
await loop_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# send_packet should have been called with silence frame
|
||||||
|
mock_conn.send_packet.assert_called_with(b'\xf8\xff\xfe')
|
||||||
|
finally:
|
||||||
|
DiscordAdapter._KEEPALIVE_INTERVAL = original_interval
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue