fix: 8 voice pipeline bugs with tests proving each fix
1. VoiceReceiver.stop() now acquires _lock before clearing shared state to prevent race with _on_packet on the socket reader thread 2. _packet_debug_count moved from class-level to instance-level to avoid cross-instance race condition in multi-guild setups 3. play_in_voice_channel uses asyncio.get_running_loop() instead of deprecated asyncio.get_event_loop() 4. _send_voice_reply uses uuid for filenames instead of time-based names that can collide when two replies happen in the same second 5. Voice timeout now notifies runner via _on_voice_disconnect callback so runner cleans up _voice_mode state (prevents orphaned TTS replies) 6. play_in_voice_channel adds PLAYBACK_TIMEOUT (120s) to prevent infinite blocking when FFmpeg callback is never called 7. _send_voice_reply moves temp file cleanup to finally block so files are always cleaned up even when send_voice/play raises 8. Base adapter auto-TTS wraps play_tts in try/finally with os.remove to clean up generated audio files after playback 18 new tests (120 total voice tests)
This commit is contained in:
parent
c925d2ee76
commit
9722bd8be0
4 changed files with 517 additions and 26 deletions
|
|
@ -752,11 +752,17 @@ class BasePlatformAdapter(ABC):
|
|||
|
||||
# Play TTS audio before text (voice-first experience)
|
||||
if _tts_path and Path(_tts_path).exists():
|
||||
await self.play_tts(
|
||||
chat_id=event.source.chat_id,
|
||||
audio_path=_tts_path,
|
||||
metadata=_thread_metadata,
|
||||
)
|
||||
try:
|
||||
await self.play_tts(
|
||||
chat_id=event.source.chat_id,
|
||||
audio_path=_tts_path,
|
||||
metadata=_thread_metadata,
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
os.remove(_tts_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# Send the text portion
|
||||
if text_content:
|
||||
|
|
|
|||
|
|
@ -108,6 +108,9 @@ class VoiceReceiver:
|
|||
# Pause flag: don't capture while bot is playing TTS
|
||||
self._paused = False
|
||||
|
||||
# Debug logging counter (instance-level to avoid cross-instance races)
|
||||
self._packet_debug_count = 0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -131,10 +134,11 @@ class VoiceReceiver:
|
|||
self._vc._connection.remove_socket_listener(self._on_packet)
|
||||
except Exception:
|
||||
pass
|
||||
self._buffers.clear()
|
||||
self._last_packet_time.clear()
|
||||
self._decoders.clear()
|
||||
self._ssrc_to_user.clear()
|
||||
with self._lock:
|
||||
self._buffers.clear()
|
||||
self._last_packet_time.clear()
|
||||
self._decoders.clear()
|
||||
self._ssrc_to_user.clear()
|
||||
logger.info("VoiceReceiver stopped")
|
||||
|
||||
def pause(self):
|
||||
|
|
@ -188,15 +192,13 @@ class VoiceReceiver:
|
|||
# Packet handler (called from SocketReader thread)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_packet_debug_count = 0 # class-level counter for debug logging
|
||||
|
||||
def _on_packet(self, data: bytes):
|
||||
if not self._running or self._paused:
|
||||
return
|
||||
|
||||
# Log first few raw packets for debugging
|
||||
VoiceReceiver._packet_debug_count += 1
|
||||
if VoiceReceiver._packet_debug_count <= 5:
|
||||
self._packet_debug_count += 1
|
||||
if self._packet_debug_count <= 5:
|
||||
logger.info(
|
||||
"Raw UDP packet: len=%d, first_bytes=%s",
|
||||
len(data), data[:4].hex() if len(data) >= 4 else "short",
|
||||
|
|
@ -209,7 +211,7 @@ class VoiceReceiver:
|
|||
# Lower bits may vary (padding, extension, CSRC count).
|
||||
# Payload type (byte 1 lower 7 bits) = 0x78 (120) for voice.
|
||||
if (data[0] >> 6) != 2 or (data[1] & 0x7F) != 0x78:
|
||||
if VoiceReceiver._packet_debug_count <= 5:
|
||||
if self._packet_debug_count <= 5:
|
||||
logger.info("Skipped non-RTP: byte0=0x%02x byte1=0x%02x", data[0], data[1])
|
||||
return
|
||||
|
||||
|
|
@ -235,7 +237,7 @@ class VoiceReceiver:
|
|||
ext_words = struct.unpack_from(">H", data, ext_preamble_offset + 2)[0]
|
||||
ext_data_len = ext_words * 4
|
||||
|
||||
if VoiceReceiver._packet_debug_count <= 10:
|
||||
if self._packet_debug_count <= 10:
|
||||
with self._lock:
|
||||
known_user = self._ssrc_to_user.get(ssrc, "unknown")
|
||||
logger.info(
|
||||
|
|
@ -258,7 +260,7 @@ class VoiceReceiver:
|
|||
box = nacl.secret.Aead(self._secret_key)
|
||||
decrypted = box.decrypt(encrypted, header, bytes(nonce))
|
||||
except Exception as e:
|
||||
if VoiceReceiver._packet_debug_count <= 10:
|
||||
if self._packet_debug_count <= 10:
|
||||
logger.warning("NaCl decrypt failed: %s (hdr=%d, enc=%d)", e, header_size, len(encrypted))
|
||||
return
|
||||
|
||||
|
|
@ -271,7 +273,7 @@ class VoiceReceiver:
|
|||
with self._lock:
|
||||
user_id = self._ssrc_to_user.get(ssrc, 0)
|
||||
if user_id == 0:
|
||||
if VoiceReceiver._packet_debug_count <= 10:
|
||||
if self._packet_debug_count <= 10:
|
||||
logger.warning("DAVE skip: unknown user for ssrc=%d", ssrc)
|
||||
return # unknown user, can't DAVE-decrypt
|
||||
try:
|
||||
|
|
@ -280,7 +282,7 @@ class VoiceReceiver:
|
|||
user_id, davey.MediaType.audio, decrypted
|
||||
)
|
||||
except Exception as e:
|
||||
if VoiceReceiver._packet_debug_count <= 10:
|
||||
if self._packet_debug_count <= 10:
|
||||
logger.warning("DAVE decrypt failed for ssrc=%d: %s", ssrc, e)
|
||||
return
|
||||
|
||||
|
|
@ -394,6 +396,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
self._voice_receivers: Dict[int, VoiceReceiver] = {} # guild_id -> VoiceReceiver
|
||||
self._voice_listen_tasks: Dict[int, asyncio.Task] = {} # guild_id -> listen loop
|
||||
self._voice_input_callback: Optional[Callable] = None # set by run.py
|
||||
self._on_voice_disconnect: Optional[Callable] = None # set by run.py
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Discord and start receiving events."""
|
||||
|
|
@ -751,6 +754,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
task.cancel()
|
||||
self._voice_text_channels.pop(guild_id, None)
|
||||
|
||||
# Maximum seconds to wait for voice playback before giving up
|
||||
PLAYBACK_TIMEOUT = 120
|
||||
|
||||
async def play_in_voice_channel(self, guild_id: int, audio_path: str) -> bool:
|
||||
"""Play an audio file in the connected voice channel."""
|
||||
vc = self._voice_clients.get(guild_id)
|
||||
|
|
@ -763,12 +769,17 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
receiver.pause()
|
||||
|
||||
try:
|
||||
# Wait for current playback to finish
|
||||
# Wait for current playback to finish (with timeout)
|
||||
wait_start = time.monotonic()
|
||||
while vc.is_playing():
|
||||
if time.monotonic() - wait_start > self.PLAYBACK_TIMEOUT:
|
||||
logger.warning("Timed out waiting for previous playback to finish")
|
||||
vc.stop()
|
||||
break
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
done = asyncio.Event()
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
def _after(error):
|
||||
if error:
|
||||
|
|
@ -778,7 +789,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
source = discord.FFmpegPCMAudio(audio_path)
|
||||
source = discord.PCMVolumeTransformer(source, volume=1.0)
|
||||
vc.play(source, after=_after)
|
||||
await done.wait()
|
||||
try:
|
||||
await asyncio.wait_for(done.wait(), timeout=self.PLAYBACK_TIMEOUT)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Voice playback timed out after %ds", self.PLAYBACK_TIMEOUT)
|
||||
vc.stop()
|
||||
self._reset_voice_timeout(guild_id)
|
||||
return True
|
||||
finally:
|
||||
|
|
@ -814,6 +829,12 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
return
|
||||
text_ch_id = self._voice_text_channels.get(guild_id)
|
||||
await self.leave_voice_channel(guild_id)
|
||||
# Notify the runner so it can clean up voice_mode state
|
||||
if self._on_voice_disconnect and text_ch_id:
|
||||
try:
|
||||
self._on_voice_disconnect(str(text_ch_id))
|
||||
except Exception:
|
||||
pass
|
||||
if text_ch_id and self._client:
|
||||
ch = self._client.get_channel(text_ch_id)
|
||||
if ch:
|
||||
|
|
|
|||
|
|
@ -2190,10 +2190,12 @@ class GatewayRunner:
|
|||
if not voice_channel:
|
||||
return "You need to be in a voice channel first."
|
||||
|
||||
# Wire callback BEFORE join so voice input arriving immediately
|
||||
# Wire callbacks BEFORE join so voice input arriving immediately
|
||||
# after connection is not lost.
|
||||
if hasattr(adapter, "_voice_input_callback"):
|
||||
adapter._voice_input_callback = self._handle_voice_channel_input
|
||||
if hasattr(adapter, "_on_voice_disconnect"):
|
||||
adapter._on_voice_disconnect = self._handle_voice_timeout_cleanup
|
||||
|
||||
try:
|
||||
success = await adapter.join_voice_channel(voice_channel)
|
||||
|
|
@ -2236,6 +2238,14 @@ class GatewayRunner:
|
|||
adapter._voice_input_callback = None
|
||||
return "Left voice channel."
|
||||
|
||||
def _handle_voice_timeout_cleanup(self, chat_id: str) -> None:
|
||||
"""Called by the adapter when a voice channel times out.
|
||||
|
||||
Cleans up runner-side voice_mode state that the adapter cannot reach.
|
||||
"""
|
||||
self._voice_mode.pop(chat_id, None)
|
||||
self._save_voice_modes()
|
||||
|
||||
async def _handle_voice_channel_input(
|
||||
self, guild_id: int, user_id: int, transcript: str
|
||||
):
|
||||
|
|
@ -2339,6 +2349,9 @@ class GatewayRunner:
|
|||
|
||||
async def _send_voice_reply(self, event: MessageEvent, text: str) -> None:
|
||||
"""Generate TTS audio and send as a voice message before the text reply."""
|
||||
import uuid as _uuid
|
||||
audio_path = None
|
||||
actual_path = None
|
||||
try:
|
||||
from tools.tts_tool import text_to_speech_tool, _strip_markdown_for_tts
|
||||
|
||||
|
|
@ -2350,7 +2363,7 @@ class GatewayRunner:
|
|||
# The TTS tool may convert to .ogg — use file_path from result.
|
||||
audio_path = os.path.join(
|
||||
tempfile.gettempdir(), "hermes_voice",
|
||||
f"tts_reply_{int(time.time())}_{id(event) % 10000}.mp3",
|
||||
f"tts_reply_{_uuid.uuid4().hex[:12]}.mp3",
|
||||
)
|
||||
os.makedirs(os.path.dirname(audio_path), exist_ok=True)
|
||||
|
||||
|
|
@ -2387,13 +2400,14 @@ class GatewayRunner:
|
|||
if "metadata" not in sig.parameters:
|
||||
send_kwargs.pop("metadata", None)
|
||||
await adapter.send_voice(**send_kwargs)
|
||||
for p in {audio_path, actual_path}:
|
||||
except Exception as e:
|
||||
logger.warning("Auto voice reply failed: %s", e, exc_info=True)
|
||||
finally:
|
||||
for p in {audio_path, actual_path} - {None}:
|
||||
try:
|
||||
os.unlink(p)
|
||||
except OSError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning("Auto voice reply failed: %s", e, exc_info=True)
|
||||
|
||||
async def _handle_rollback_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /rollback command — list or restore filesystem checkpoints."""
|
||||
|
|
|
|||
|
|
@ -1402,3 +1402,453 @@ class TestStreamTtsToSpeaker:
|
|||
t.join(timeout=5)
|
||||
assert done_evt.is_set()
|
||||
assert len(spoken) >= 1
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Bug 1: VoiceReceiver.stop() must hold lock while clearing shared state
|
||||
# =====================================================================
|
||||
|
||||
class TestStopAcquiresLock:
|
||||
"""stop() must acquire _lock before clearing buffers/state."""
|
||||
|
||||
@staticmethod
|
||||
def _make_receiver():
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
vc = MagicMock()
|
||||
vc._connection.secret_key = [0] * 32
|
||||
vc._connection.dave_session = None
|
||||
vc._connection.ssrc = 1
|
||||
return VoiceReceiver(vc)
|
||||
|
||||
def test_stop_clears_under_lock(self):
|
||||
"""stop() acquires _lock before clearing buffers.
|
||||
|
||||
Verify by holding the lock from another thread and checking that
|
||||
stop() blocks until the lock is released.
|
||||
"""
|
||||
receiver = self._make_receiver()
|
||||
receiver.start()
|
||||
receiver._buffers[100] = bytearray(b"\x00" * 500)
|
||||
receiver._last_packet_time[100] = time.monotonic()
|
||||
receiver.map_ssrc(100, 42)
|
||||
|
||||
# Hold the lock from another thread
|
||||
lock_acquired = threading.Event()
|
||||
release_lock = threading.Event()
|
||||
|
||||
def hold_lock():
|
||||
with receiver._lock:
|
||||
lock_acquired.set()
|
||||
release_lock.wait(timeout=5)
|
||||
|
||||
holder = threading.Thread(target=hold_lock, daemon=True)
|
||||
holder.start()
|
||||
lock_acquired.wait(timeout=2)
|
||||
|
||||
# stop() in another thread — should block on the lock
|
||||
stop_done = threading.Event()
|
||||
|
||||
def do_stop():
|
||||
receiver.stop()
|
||||
stop_done.set()
|
||||
|
||||
stopper = threading.Thread(target=do_stop, daemon=True)
|
||||
stopper.start()
|
||||
|
||||
# stop should NOT complete while lock is held
|
||||
assert not stop_done.wait(timeout=0.3), \
|
||||
"stop() should block while _lock is held by another thread"
|
||||
|
||||
# Release the lock — stop should complete
|
||||
release_lock.set()
|
||||
assert stop_done.wait(timeout=2), \
|
||||
"stop() should complete after lock is released"
|
||||
|
||||
# State should be cleared
|
||||
assert len(receiver._buffers) == 0
|
||||
assert len(receiver._ssrc_to_user) == 0
|
||||
holder.join(timeout=2)
|
||||
stopper.join(timeout=2)
|
||||
|
||||
def test_stop_does_not_deadlock_with_on_packet(self):
|
||||
"""stop() during _on_packet should not deadlock."""
|
||||
receiver = self._make_receiver()
|
||||
receiver.start()
|
||||
|
||||
blocked = threading.Event()
|
||||
released = threading.Event()
|
||||
|
||||
def hold_lock():
|
||||
with receiver._lock:
|
||||
blocked.set()
|
||||
released.wait(timeout=2)
|
||||
|
||||
t = threading.Thread(target=hold_lock, daemon=True)
|
||||
t.start()
|
||||
blocked.wait(timeout=2)
|
||||
|
||||
stop_done = threading.Event()
|
||||
|
||||
def do_stop():
|
||||
receiver.stop()
|
||||
stop_done.set()
|
||||
|
||||
t2 = threading.Thread(target=do_stop, daemon=True)
|
||||
t2.start()
|
||||
|
||||
# stop should be blocked waiting for lock
|
||||
assert not stop_done.wait(timeout=0.2), \
|
||||
"stop() should wait for lock, not clear without it"
|
||||
|
||||
released.set()
|
||||
assert stop_done.wait(timeout=2), "stop() should complete after lock released"
|
||||
t.join(timeout=2)
|
||||
t2.join(timeout=2)
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Bug 2: _packet_debug_count must be instance-level, not class-level
|
||||
# =====================================================================
|
||||
|
||||
class TestPacketDebugCounterIsInstanceLevel:
|
||||
"""Each VoiceReceiver instance has its own debug counter."""
|
||||
|
||||
@staticmethod
|
||||
def _make_receiver():
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
vc = MagicMock()
|
||||
vc._connection.secret_key = [0] * 32
|
||||
vc._connection.dave_session = None
|
||||
vc._connection.ssrc = 1
|
||||
return VoiceReceiver(vc)
|
||||
|
||||
def test_counter_is_per_instance(self):
|
||||
"""Two receivers have independent counters."""
|
||||
r1 = self._make_receiver()
|
||||
r2 = self._make_receiver()
|
||||
|
||||
r1._packet_debug_count = 10
|
||||
assert r2._packet_debug_count == 0, \
|
||||
"_packet_debug_count must be instance-level, not shared across instances"
|
||||
|
||||
def test_counter_initialized_in_init(self):
|
||||
"""Counter is set in __init__, not as a class variable."""
|
||||
r = self._make_receiver()
|
||||
assert "_packet_debug_count" in r.__dict__, \
|
||||
"_packet_debug_count should be in instance __dict__, not class"
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Bug 3: play_in_voice_channel uses get_running_loop not get_event_loop
|
||||
# =====================================================================
|
||||
|
||||
class TestPlayInVoiceChannelUsesRunningLoop:
|
||||
"""play_in_voice_channel must use asyncio.get_running_loop()."""
|
||||
|
||||
def test_source_uses_get_running_loop(self):
|
||||
"""The method source code calls get_running_loop, not get_event_loop."""
|
||||
import inspect
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
source = inspect.getsource(DiscordAdapter.play_in_voice_channel)
|
||||
assert "get_running_loop" in source, \
|
||||
"play_in_voice_channel should use asyncio.get_running_loop()"
|
||||
assert "get_event_loop" not in source, \
|
||||
"play_in_voice_channel should NOT use deprecated asyncio.get_event_loop()"
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Bug 4: _send_voice_reply filename uses uuid (no collision)
|
||||
# =====================================================================
|
||||
|
||||
class TestSendVoiceReplyFilename:
|
||||
"""_send_voice_reply uses uuid for unique filenames."""
|
||||
|
||||
def test_filename_uses_uuid(self):
|
||||
"""The method uses uuid in the filename, not time-based."""
|
||||
import inspect
|
||||
from gateway.run import GatewayRunner
|
||||
source = inspect.getsource(GatewayRunner._send_voice_reply)
|
||||
assert "uuid" in source, \
|
||||
"_send_voice_reply should use uuid for unique filenames"
|
||||
assert "int(time.time())" not in source, \
|
||||
"_send_voice_reply should not use int(time.time()) — collision risk"
|
||||
|
||||
def test_filenames_are_unique(self):
|
||||
"""Two calls produce different filenames."""
|
||||
import uuid
|
||||
names = set()
|
||||
for _ in range(100):
|
||||
name = f"tts_reply_{uuid.uuid4().hex[:12]}.mp3"
|
||||
assert name not in names, f"Collision detected: {name}"
|
||||
names.add(name)
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Bug 5: Voice timeout cleans up runner voice_mode via callback
|
||||
# =====================================================================
|
||||
|
||||
class TestVoiceTimeoutCleansRunnerState:
|
||||
"""Timeout disconnect notifies runner to clean voice_mode."""
|
||||
|
||||
@staticmethod
|
||||
def _make_discord_adapter():
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from gateway.config import PlatformConfig, Platform
|
||||
config = PlatformConfig(enabled=True, extra={})
|
||||
config.token = "fake-token"
|
||||
adapter = object.__new__(DiscordAdapter)
|
||||
adapter.platform = Platform.DISCORD
|
||||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_timeout_tasks = {}
|
||||
adapter._voice_receivers = {}
|
||||
adapter._voice_listen_tasks = {}
|
||||
adapter._voice_input_callback = None
|
||||
adapter._on_voice_disconnect = None
|
||||
adapter._client = None
|
||||
adapter._broadcast = AsyncMock()
|
||||
adapter._allowed_user_ids = set()
|
||||
return adapter
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(self):
|
||||
return self._make_discord_adapter()
|
||||
|
||||
def test_adapter_has_on_voice_disconnect_attr(self, adapter):
|
||||
"""DiscordAdapter has _on_voice_disconnect callback attribute."""
|
||||
assert hasattr(adapter, "_on_voice_disconnect")
|
||||
assert adapter._on_voice_disconnect is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_calls_disconnect_callback(self, adapter):
|
||||
"""_voice_timeout_handler calls _on_voice_disconnect with chat_id."""
|
||||
callback_calls = []
|
||||
adapter._on_voice_disconnect = lambda chat_id: callback_calls.append(chat_id)
|
||||
|
||||
# Set up state as if we're in a voice channel
|
||||
mock_vc = MagicMock()
|
||||
mock_vc.is_connected.return_value = True
|
||||
mock_vc.disconnect = AsyncMock()
|
||||
adapter._voice_clients[111] = mock_vc
|
||||
adapter._voice_text_channels[111] = 999
|
||||
adapter._voice_timeout_tasks[111] = MagicMock()
|
||||
adapter._voice_receivers[111] = MagicMock()
|
||||
adapter._voice_listen_tasks[111] = MagicMock()
|
||||
|
||||
# Patch sleep to return immediately
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
await adapter._voice_timeout_handler(111)
|
||||
|
||||
assert "999" in callback_calls, \
|
||||
"_on_voice_disconnect must be called with chat_id on timeout"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_cleanup_method_removes_voice_mode(self, tmp_path):
|
||||
"""_handle_voice_timeout_cleanup removes voice_mode for chat."""
|
||||
runner = _make_runner(tmp_path)
|
||||
runner._voice_mode["999"] = "all"
|
||||
|
||||
runner._handle_voice_timeout_cleanup("999")
|
||||
|
||||
assert "999" not in runner._voice_mode, \
|
||||
"voice_mode must be removed after timeout cleanup"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_without_callback_does_not_crash(self, adapter):
|
||||
"""Timeout works even without _on_voice_disconnect set."""
|
||||
adapter._on_voice_disconnect = None
|
||||
|
||||
mock_vc = MagicMock()
|
||||
mock_vc.is_connected.return_value = True
|
||||
mock_vc.disconnect = AsyncMock()
|
||||
adapter._voice_clients[111] = mock_vc
|
||||
adapter._voice_text_channels[111] = 999
|
||||
adapter._voice_timeout_tasks[111] = MagicMock()
|
||||
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
await adapter._voice_timeout_handler(111)
|
||||
|
||||
assert 111 not in adapter._voice_clients
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Bug 6: play_in_voice_channel has playback timeout
|
||||
# =====================================================================
|
||||
|
||||
class TestPlaybackTimeout:
|
||||
"""play_in_voice_channel must time out instead of blocking forever."""
|
||||
|
||||
@staticmethod
|
||||
def _make_discord_adapter():
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from gateway.config import PlatformConfig, Platform
|
||||
config = PlatformConfig(enabled=True, extra={})
|
||||
config.token = "fake-token"
|
||||
adapter = object.__new__(DiscordAdapter)
|
||||
adapter.platform = Platform.DISCORD
|
||||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_timeout_tasks = {}
|
||||
adapter._voice_receivers = {}
|
||||
adapter._voice_listen_tasks = {}
|
||||
adapter._voice_input_callback = None
|
||||
adapter._on_voice_disconnect = None
|
||||
adapter._client = None
|
||||
adapter._broadcast = AsyncMock()
|
||||
adapter._allowed_user_ids = set()
|
||||
return adapter
|
||||
|
||||
def test_source_has_wait_for_timeout(self):
|
||||
"""The method uses asyncio.wait_for with timeout."""
|
||||
import inspect
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
source = inspect.getsource(DiscordAdapter.play_in_voice_channel)
|
||||
assert "wait_for" in source, \
|
||||
"play_in_voice_channel must use asyncio.wait_for for timeout"
|
||||
assert "PLAYBACK_TIMEOUT" in source, \
|
||||
"play_in_voice_channel must reference PLAYBACK_TIMEOUT constant"
|
||||
|
||||
def test_playback_timeout_constant_exists(self):
|
||||
"""PLAYBACK_TIMEOUT constant is defined on DiscordAdapter."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
assert hasattr(DiscordAdapter, "PLAYBACK_TIMEOUT")
|
||||
assert DiscordAdapter.PLAYBACK_TIMEOUT > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_playback_timeout_fires(self):
|
||||
"""When done event is never set, playback times out gracefully."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
adapter = self._make_discord_adapter()
|
||||
|
||||
mock_vc = MagicMock()
|
||||
mock_vc.is_connected.return_value = True
|
||||
mock_vc.is_playing.return_value = False
|
||||
# play() never calls the after callback -> done never set
|
||||
mock_vc.play = MagicMock()
|
||||
mock_vc.stop = MagicMock()
|
||||
adapter._voice_clients[111] = mock_vc
|
||||
adapter._voice_timeout_tasks[111] = MagicMock()
|
||||
|
||||
# Use a tiny timeout for test speed
|
||||
original_timeout = DiscordAdapter.PLAYBACK_TIMEOUT
|
||||
DiscordAdapter.PLAYBACK_TIMEOUT = 0.1
|
||||
try:
|
||||
with patch("discord.FFmpegPCMAudio"), \
|
||||
patch("discord.PCMVolumeTransformer", side_effect=lambda s, **kw: s):
|
||||
result = await adapter.play_in_voice_channel(111, "/tmp/test.mp3")
|
||||
assert result is True
|
||||
# vc.stop() should have been called due to timeout
|
||||
mock_vc.stop.assert_called()
|
||||
finally:
|
||||
DiscordAdapter.PLAYBACK_TIMEOUT = original_timeout
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_playing_wait_has_timeout(self):
|
||||
"""While loop waiting for previous playback has a timeout."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
adapter = self._make_discord_adapter()
|
||||
|
||||
mock_vc = MagicMock()
|
||||
mock_vc.is_connected.return_value = True
|
||||
# is_playing always returns True — would loop forever without timeout
|
||||
mock_vc.is_playing.return_value = True
|
||||
mock_vc.stop = MagicMock()
|
||||
mock_vc.play = MagicMock()
|
||||
adapter._voice_clients[111] = mock_vc
|
||||
adapter._voice_timeout_tasks[111] = MagicMock()
|
||||
|
||||
original_timeout = DiscordAdapter.PLAYBACK_TIMEOUT
|
||||
DiscordAdapter.PLAYBACK_TIMEOUT = 0.2
|
||||
try:
|
||||
with patch("discord.FFmpegPCMAudio"), \
|
||||
patch("discord.PCMVolumeTransformer", side_effect=lambda s, **kw: s):
|
||||
result = await adapter.play_in_voice_channel(111, "/tmp/test.mp3")
|
||||
assert result is True
|
||||
# stop() called to break out of the is_playing loop
|
||||
mock_vc.stop.assert_called()
|
||||
finally:
|
||||
DiscordAdapter.PLAYBACK_TIMEOUT = original_timeout
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Bug 7: _send_voice_reply cleanup in finally block
|
||||
# =====================================================================
|
||||
|
||||
class TestSendVoiceReplyCleanup:
|
||||
"""_send_voice_reply must clean up temp files even on exception."""
|
||||
|
||||
def test_cleanup_in_finally(self):
|
||||
"""The method has cleanup in a finally block, not inside try."""
|
||||
import inspect, textwrap, ast
|
||||
from gateway.run import GatewayRunner
|
||||
source = textwrap.dedent(inspect.getsource(GatewayRunner._send_voice_reply))
|
||||
tree = ast.parse(source)
|
||||
func = tree.body[0]
|
||||
|
||||
has_finally_unlink = False
|
||||
for node in ast.walk(func):
|
||||
if isinstance(node, ast.Try) and node.finalbody:
|
||||
finally_source = ast.dump(node.finalbody[0])
|
||||
if "unlink" in finally_source or "remove" in finally_source:
|
||||
has_finally_unlink = True
|
||||
break
|
||||
|
||||
assert has_finally_unlink, \
|
||||
"_send_voice_reply must have os.unlink in a finally block"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_files_cleaned_on_send_exception(self, tmp_path):
|
||||
"""Temp files are removed even when send_voice raises."""
|
||||
runner = _make_runner(tmp_path)
|
||||
adapter = MagicMock()
|
||||
adapter.send_voice = AsyncMock(side_effect=RuntimeError("send failed"))
|
||||
adapter.is_in_voice_channel = MagicMock(return_value=False)
|
||||
event = _make_event(message_type=MessageType.VOICE)
|
||||
runner.adapters[event.source.platform] = adapter
|
||||
runner._get_guild_id = MagicMock(return_value=None)
|
||||
|
||||
# Create a fake audio file that TTS would produce
|
||||
fake_audio = tmp_path / "hermes_voice"
|
||||
fake_audio.mkdir()
|
||||
audio_file = fake_audio / "test.mp3"
|
||||
audio_file.write_bytes(b"fake audio")
|
||||
|
||||
tts_result = json.dumps({
|
||||
"success": True,
|
||||
"file_path": str(audio_file),
|
||||
})
|
||||
|
||||
with patch("gateway.run.asyncio.to_thread", new_callable=AsyncMock, return_value=tts_result), \
|
||||
patch("tools.tts_tool._strip_markdown_for_tts", return_value="hello"), \
|
||||
patch("os.path.isfile", return_value=True), \
|
||||
patch("os.makedirs"):
|
||||
await runner._send_voice_reply(event, "Hello world")
|
||||
|
||||
# File should be cleaned up despite exception
|
||||
assert not audio_file.exists(), \
|
||||
"Temp audio file must be cleaned up even when send_voice raises"
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Bug 8: Base adapter auto-TTS cleans up temp file after play_tts
|
||||
# =====================================================================
|
||||
|
||||
class TestAutoTtsTempFileCleanup:
|
||||
"""Base adapter auto-TTS must clean up generated audio file."""
|
||||
|
||||
def test_source_has_finally_remove(self):
|
||||
"""play_tts call is wrapped in try/finally with os.remove."""
|
||||
import inspect
|
||||
from gateway.platforms.base import BasePlatformAdapter
|
||||
source = inspect.getsource(BasePlatformAdapter._process_message_background)
|
||||
# Find the play_tts section and verify cleanup
|
||||
play_tts_idx = source.find("play_tts")
|
||||
assert play_tts_idx > 0
|
||||
after_play = source[play_tts_idx:]
|
||||
finally_idx = after_play.find("finally")
|
||||
remove_idx = after_play.find("os.remove")
|
||||
assert finally_idx > 0, "play_tts must be in a try/finally block"
|
||||
assert remove_idx > 0, "finally block must call os.remove on _tts_path"
|
||||
assert remove_idx > finally_idx, "os.remove must be inside the finally block"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue