feat: add Discord voice channel listening — STT transcription and agent response pipeline
Phase 2 of voice channel support: bot listens to users speaking in VC, transcribes speech via Groq Whisper, and processes through the agent pipeline. - Add VoiceReceiver class for RTP packet capture, NaCl/DAVE decryption, Opus decode - Add silence detection and per-user PCM buffering - Wire voice input callback from adapter to GatewayRunner - Fix adapter dict key: use Platform.DISCORD enum instead of string - Fix guild_id extraction for synthetic voice events via SimpleNamespace raw_message - Pause/resume receiver during TTS playback to prevent echo
This commit is contained in:
parent
cc974904f8
commit
c0c358d051
2 changed files with 454 additions and 17 deletions
|
|
@ -10,7 +10,13 @@ Uses discord.py library for:
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Dict, List, Optional, Any
|
import struct
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Callable, Dict, List, Optional, Any
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -65,6 +71,294 @@ def check_discord_requirements() -> bool:
|
||||||
return DISCORD_AVAILABLE
|
return DISCORD_AVAILABLE
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceReceiver:
|
||||||
|
"""Captures and decodes voice audio from a Discord voice channel.
|
||||||
|
|
||||||
|
Attaches to a VoiceClient's socket listener, decrypts RTP packets
|
||||||
|
(NaCl transport + DAVE E2EE), decodes Opus to PCM, and buffers
|
||||||
|
per-user audio. A polling loop detects silence and delivers
|
||||||
|
completed utterances via a callback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
SILENCE_THRESHOLD = 1.5 # seconds of silence → end of utterance
|
||||||
|
MIN_SPEECH_DURATION = 0.5 # minimum seconds to process (skip noise)
|
||||||
|
SAMPLE_RATE = 48000 # Discord native rate
|
||||||
|
CHANNELS = 2 # Discord sends stereo
|
||||||
|
|
||||||
|
def __init__(self, voice_client):
|
||||||
|
self._vc = voice_client
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
# Decryption
|
||||||
|
self._secret_key: Optional[bytes] = None
|
||||||
|
self._dave_session = None
|
||||||
|
self._bot_ssrc: int = 0
|
||||||
|
|
||||||
|
# SSRC -> user_id mapping (populated from SPEAKING events)
|
||||||
|
self._ssrc_to_user: Dict[int, int] = {}
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
# Per-user audio buffers
|
||||||
|
self._buffers: Dict[int, bytearray] = defaultdict(bytearray)
|
||||||
|
self._last_packet_time: Dict[int, float] = {}
|
||||||
|
|
||||||
|
# Opus decoder per SSRC (each user needs own decoder state)
|
||||||
|
self._decoders: Dict[int, object] = {}
|
||||||
|
|
||||||
|
# Pause flag: don't capture while bot is playing TTS
|
||||||
|
self._paused = False
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Lifecycle
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""Start listening for voice packets."""
|
||||||
|
conn = self._vc._connection
|
||||||
|
self._secret_key = bytes(conn.secret_key)
|
||||||
|
self._dave_session = conn.dave_session
|
||||||
|
self._bot_ssrc = conn.ssrc
|
||||||
|
|
||||||
|
self._install_speaking_hook(conn)
|
||||||
|
conn.add_socket_listener(self._on_packet)
|
||||||
|
self._running = True
|
||||||
|
logger.info("VoiceReceiver started (bot_ssrc=%d)", self._bot_ssrc)
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""Stop listening and clean up."""
|
||||||
|
self._running = False
|
||||||
|
try:
|
||||||
|
self._vc._connection.remove_socket_listener(self._on_packet)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._buffers.clear()
|
||||||
|
self._last_packet_time.clear()
|
||||||
|
self._decoders.clear()
|
||||||
|
self._ssrc_to_user.clear()
|
||||||
|
logger.info("VoiceReceiver stopped")
|
||||||
|
|
||||||
|
def pause(self):
|
||||||
|
self._paused = True
|
||||||
|
|
||||||
|
def resume(self):
|
||||||
|
self._paused = False
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# SSRC -> user_id mapping via SPEAKING opcode hook
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def map_ssrc(self, ssrc: int, user_id: int):
|
||||||
|
with self._lock:
|
||||||
|
self._ssrc_to_user[ssrc] = user_id
|
||||||
|
|
||||||
|
def _install_speaking_hook(self, conn):
|
||||||
|
"""Wrap the voice websocket hook to capture SPEAKING events (op 5).
|
||||||
|
|
||||||
|
VoiceConnectionState stores the hook as ``conn.hook`` (public attr).
|
||||||
|
It is passed to DiscordVoiceWebSocket on each (re)connect, so we
|
||||||
|
must wrap it on the VoiceConnectionState level AND on the current
|
||||||
|
live websocket instance.
|
||||||
|
"""
|
||||||
|
original_hook = conn.hook
|
||||||
|
receiver_self = self
|
||||||
|
|
||||||
|
async def wrapped_hook(ws, msg):
|
||||||
|
if isinstance(msg, dict) and msg.get("op") == 5:
|
||||||
|
data = msg.get("d", {})
|
||||||
|
ssrc = data.get("ssrc")
|
||||||
|
user_id = data.get("user_id")
|
||||||
|
if ssrc and user_id:
|
||||||
|
logger.info("SPEAKING event: ssrc=%d -> user=%s", ssrc, user_id)
|
||||||
|
receiver_self.map_ssrc(int(ssrc), int(user_id))
|
||||||
|
if original_hook:
|
||||||
|
await original_hook(ws, msg)
|
||||||
|
|
||||||
|
# Set on connection state (for future reconnects)
|
||||||
|
conn.hook = wrapped_hook
|
||||||
|
# Set on the current live websocket (for immediate effect)
|
||||||
|
try:
|
||||||
|
from discord.utils import MISSING
|
||||||
|
if hasattr(conn, 'ws') and conn.ws is not MISSING:
|
||||||
|
conn.ws._hook = wrapped_hook
|
||||||
|
logger.info("Speaking hook installed on live websocket")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Could not install hook on live ws: %s", e)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Packet handler (called from SocketReader thread)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
_packet_debug_count = 0 # class-level counter for debug logging
|
||||||
|
|
||||||
|
def _on_packet(self, data: bytes):
|
||||||
|
if not self._running or self._paused:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Log first few raw packets for debugging
|
||||||
|
VoiceReceiver._packet_debug_count += 1
|
||||||
|
if VoiceReceiver._packet_debug_count <= 5:
|
||||||
|
logger.info(
|
||||||
|
"Raw UDP packet: len=%d, first_bytes=%s",
|
||||||
|
len(data), data[:4].hex() if len(data) >= 4 else "short",
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(data) < 16:
|
||||||
|
return
|
||||||
|
|
||||||
|
# RTP version check: top 2 bits must be 10 (version 2).
|
||||||
|
# Lower bits may vary (padding, extension, CSRC count).
|
||||||
|
# Payload type (byte 1 lower 7 bits) = 0x78 (120) for voice.
|
||||||
|
if (data[0] >> 6) != 2 or (data[1] & 0x7F) != 0x78:
|
||||||
|
if VoiceReceiver._packet_debug_count <= 5:
|
||||||
|
logger.info("Skipped non-RTP: byte0=0x%02x byte1=0x%02x", data[0], data[1])
|
||||||
|
return
|
||||||
|
|
||||||
|
first_byte = data[0]
|
||||||
|
_, _, seq, timestamp, ssrc = struct.unpack_from(">BBHII", data, 0)
|
||||||
|
|
||||||
|
# Skip bot's own audio
|
||||||
|
if ssrc == self._bot_ssrc:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Calculate dynamic RTP header size (RFC 9335 / rtpsize mode)
|
||||||
|
cc = first_byte & 0x0F # CSRC count
|
||||||
|
has_extension = bool(first_byte & 0x10) # extension bit
|
||||||
|
header_size = 12 + (4 * cc) + (4 if has_extension else 0)
|
||||||
|
|
||||||
|
if len(data) < header_size + 4: # need at least header + nonce
|
||||||
|
return
|
||||||
|
|
||||||
|
# Read extension length from preamble (for skipping after decrypt)
|
||||||
|
ext_data_len = 0
|
||||||
|
if has_extension:
|
||||||
|
ext_preamble_offset = 12 + (4 * cc)
|
||||||
|
ext_words = struct.unpack_from(">H", data, ext_preamble_offset + 2)[0]
|
||||||
|
ext_data_len = ext_words * 4
|
||||||
|
|
||||||
|
if VoiceReceiver._packet_debug_count <= 10:
|
||||||
|
with self._lock:
|
||||||
|
known_user = self._ssrc_to_user.get(ssrc, "unknown")
|
||||||
|
logger.info(
|
||||||
|
"RTP packet: ssrc=%d, seq=%d, user=%s, hdr=%d, ext_data=%d",
|
||||||
|
ssrc, seq, known_user, header_size, ext_data_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
header = bytes(data[:header_size])
|
||||||
|
payload_with_nonce = data[header_size:]
|
||||||
|
|
||||||
|
# --- NaCl transport decrypt (aead_xchacha20_poly1305_rtpsize) ---
|
||||||
|
if len(payload_with_nonce) < 4:
|
||||||
|
return
|
||||||
|
nonce = bytearray(24)
|
||||||
|
nonce[:4] = payload_with_nonce[-4:]
|
||||||
|
encrypted = bytes(payload_with_nonce[:-4])
|
||||||
|
|
||||||
|
try:
|
||||||
|
import nacl.secret # noqa: delayed import – only in voice path
|
||||||
|
box = nacl.secret.Aead(self._secret_key)
|
||||||
|
decrypted = box.decrypt(encrypted, header, bytes(nonce))
|
||||||
|
except Exception as e:
|
||||||
|
if VoiceReceiver._packet_debug_count <= 10:
|
||||||
|
logger.warning("NaCl decrypt failed: %s (hdr=%d, enc=%d)", e, header_size, len(encrypted))
|
||||||
|
return
|
||||||
|
|
||||||
|
# Skip encrypted extension data to get the actual opus payload
|
||||||
|
if ext_data_len and len(decrypted) > ext_data_len:
|
||||||
|
decrypted = decrypted[ext_data_len:]
|
||||||
|
|
||||||
|
# --- DAVE E2EE decrypt ---
|
||||||
|
if self._dave_session:
|
||||||
|
with self._lock:
|
||||||
|
user_id = self._ssrc_to_user.get(ssrc, 0)
|
||||||
|
if user_id == 0:
|
||||||
|
if VoiceReceiver._packet_debug_count <= 10:
|
||||||
|
logger.warning("DAVE skip: unknown user for ssrc=%d", ssrc)
|
||||||
|
return # unknown user, can't DAVE-decrypt
|
||||||
|
try:
|
||||||
|
import davey
|
||||||
|
decrypted = self._dave_session.decrypt(
|
||||||
|
user_id, davey.MediaType.audio, decrypted
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
if VoiceReceiver._packet_debug_count <= 10:
|
||||||
|
logger.warning("DAVE decrypt failed for ssrc=%d: %s", ssrc, e)
|
||||||
|
return
|
||||||
|
|
||||||
|
# --- Opus decode -> PCM ---
|
||||||
|
try:
|
||||||
|
if ssrc not in self._decoders:
|
||||||
|
self._decoders[ssrc] = discord.opus.Decoder()
|
||||||
|
pcm = self._decoders[ssrc].decode(decrypted)
|
||||||
|
self._buffers[ssrc].extend(pcm)
|
||||||
|
self._last_packet_time[ssrc] = time.monotonic()
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Silence detection
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def check_silence(self) -> list:
|
||||||
|
"""Return list of (user_id, pcm_bytes) for completed utterances."""
|
||||||
|
now = time.monotonic()
|
||||||
|
completed = []
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
ssrc_user_map = dict(self._ssrc_to_user)
|
||||||
|
|
||||||
|
for ssrc in list(self._buffers.keys()):
|
||||||
|
last_time = self._last_packet_time.get(ssrc, now)
|
||||||
|
silence_duration = now - last_time
|
||||||
|
buf = self._buffers[ssrc]
|
||||||
|
# 48kHz, 16-bit, stereo = 192000 bytes/sec
|
||||||
|
buf_duration = len(buf) / (self.SAMPLE_RATE * self.CHANNELS * 2)
|
||||||
|
|
||||||
|
if silence_duration >= self.SILENCE_THRESHOLD and buf_duration >= self.MIN_SPEECH_DURATION:
|
||||||
|
user_id = ssrc_user_map.get(ssrc, 0)
|
||||||
|
if user_id:
|
||||||
|
completed.append((user_id, bytes(buf)))
|
||||||
|
self._buffers[ssrc] = bytearray()
|
||||||
|
self._last_packet_time.pop(ssrc, None)
|
||||||
|
elif silence_duration >= self.SILENCE_THRESHOLD * 2:
|
||||||
|
# Stale buffer with no valid user — discard
|
||||||
|
self._buffers.pop(ssrc, None)
|
||||||
|
self._last_packet_time.pop(ssrc, None)
|
||||||
|
|
||||||
|
return completed
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# PCM -> WAV conversion (for Whisper STT)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def pcm_to_wav(pcm_data: bytes, output_path: str,
|
||||||
|
src_rate: int = 48000, src_channels: int = 2):
|
||||||
|
"""Convert raw PCM to 16kHz mono WAV via ffmpeg."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".pcm", delete=False) as f:
|
||||||
|
f.write(pcm_data)
|
||||||
|
pcm_path = f.name
|
||||||
|
try:
|
||||||
|
subprocess.run(
|
||||||
|
[
|
||||||
|
"ffmpeg", "-y", "-loglevel", "error",
|
||||||
|
"-f", "s16le",
|
||||||
|
"-ar", str(src_rate),
|
||||||
|
"-ac", str(src_channels),
|
||||||
|
"-i", pcm_path,
|
||||||
|
"-ar", "16000",
|
||||||
|
"-ac", "1",
|
||||||
|
output_path,
|
||||||
|
],
|
||||||
|
check=True,
|
||||||
|
timeout=10,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
os.unlink(pcm_path)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DiscordAdapter(BasePlatformAdapter):
|
class DiscordAdapter(BasePlatformAdapter):
|
||||||
"""
|
"""
|
||||||
Discord bot adapter.
|
Discord bot adapter.
|
||||||
|
|
@ -94,6 +388,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||||
self._voice_clients: Dict[int, Any] = {} # guild_id -> VoiceClient
|
self._voice_clients: Dict[int, Any] = {} # guild_id -> VoiceClient
|
||||||
self._voice_text_channels: Dict[int, int] = {} # guild_id -> text_channel_id
|
self._voice_text_channels: Dict[int, int] = {} # guild_id -> text_channel_id
|
||||||
self._voice_timeout_tasks: Dict[int, asyncio.Task] = {} # guild_id -> timeout task
|
self._voice_timeout_tasks: Dict[int, asyncio.Task] = {} # guild_id -> timeout task
|
||||||
|
# Phase 2: voice listening
|
||||||
|
self._voice_receivers: Dict[int, VoiceReceiver] = {} # guild_id -> VoiceReceiver
|
||||||
|
self._voice_listen_tasks: Dict[int, asyncio.Task] = {} # guild_id -> listen loop
|
||||||
|
self._voice_input_callback: Optional[Callable] = None # set by run.py
|
||||||
|
|
||||||
async def connect(self) -> bool:
|
async def connect(self) -> bool:
|
||||||
"""Connect to Discord and start receiving events."""
|
"""Connect to Discord and start receiving events."""
|
||||||
|
|
@ -402,10 +700,30 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||||
vc = await channel.connect()
|
vc = await channel.connect()
|
||||||
self._voice_clients[guild_id] = vc
|
self._voice_clients[guild_id] = vc
|
||||||
self._reset_voice_timeout(guild_id)
|
self._reset_voice_timeout(guild_id)
|
||||||
|
|
||||||
|
# Start voice receiver (Phase 2: listen to users)
|
||||||
|
try:
|
||||||
|
receiver = VoiceReceiver(vc)
|
||||||
|
receiver.start()
|
||||||
|
self._voice_receivers[guild_id] = receiver
|
||||||
|
self._voice_listen_tasks[guild_id] = asyncio.ensure_future(
|
||||||
|
self._voice_listen_loop(guild_id)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Voice receiver failed to start: %s", e)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def leave_voice_channel(self, guild_id: int) -> None:
|
async def leave_voice_channel(self, guild_id: int) -> None:
|
||||||
"""Disconnect from the voice channel in a guild."""
|
"""Disconnect from the voice channel in a guild."""
|
||||||
|
# Stop voice receiver first
|
||||||
|
receiver = self._voice_receivers.pop(guild_id, None)
|
||||||
|
if receiver:
|
||||||
|
receiver.stop()
|
||||||
|
listen_task = self._voice_listen_tasks.pop(guild_id, None)
|
||||||
|
if listen_task:
|
||||||
|
listen_task.cancel()
|
||||||
|
|
||||||
vc = self._voice_clients.pop(guild_id, None)
|
vc = self._voice_clients.pop(guild_id, None)
|
||||||
if vc and vc.is_connected():
|
if vc and vc.is_connected():
|
||||||
await vc.disconnect()
|
await vc.disconnect()
|
||||||
|
|
@ -420,24 +738,33 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||||
if not vc or not vc.is_connected():
|
if not vc or not vc.is_connected():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Wait for current playback to finish
|
# Pause voice receiver while playing (echo prevention)
|
||||||
while vc.is_playing():
|
receiver = self._voice_receivers.get(guild_id)
|
||||||
await asyncio.sleep(0.1)
|
if receiver:
|
||||||
|
receiver.pause()
|
||||||
|
|
||||||
done = asyncio.Event()
|
try:
|
||||||
loop = asyncio.get_event_loop()
|
# Wait for current playback to finish
|
||||||
|
while vc.is_playing():
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
def _after(error):
|
done = asyncio.Event()
|
||||||
if error:
|
loop = asyncio.get_event_loop()
|
||||||
logger.error("Voice playback error: %s", error)
|
|
||||||
loop.call_soon_threadsafe(done.set)
|
|
||||||
|
|
||||||
source = discord.FFmpegPCMAudio(audio_path)
|
def _after(error):
|
||||||
source = discord.PCMVolumeTransformer(source, volume=1.0)
|
if error:
|
||||||
vc.play(source, after=_after)
|
logger.error("Voice playback error: %s", error)
|
||||||
await done.wait()
|
loop.call_soon_threadsafe(done.set)
|
||||||
self._reset_voice_timeout(guild_id)
|
|
||||||
return True
|
source = discord.FFmpegPCMAudio(audio_path)
|
||||||
|
source = discord.PCMVolumeTransformer(source, volume=1.0)
|
||||||
|
vc.play(source, after=_after)
|
||||||
|
await done.wait()
|
||||||
|
self._reset_voice_timeout(guild_id)
|
||||||
|
return True
|
||||||
|
finally:
|
||||||
|
if receiver:
|
||||||
|
receiver.resume()
|
||||||
|
|
||||||
async def get_user_voice_channel(self, guild_id: int, user_id: str):
|
async def get_user_voice_channel(self, guild_id: int, user_id: str):
|
||||||
"""Return the voice channel the user is currently in, or None."""
|
"""Return the voice channel the user is currently in, or None."""
|
||||||
|
|
@ -481,6 +808,67 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||||
vc = self._voice_clients.get(guild_id)
|
vc = self._voice_clients.get(guild_id)
|
||||||
return vc is not None and vc.is_connected()
|
return vc is not None and vc.is_connected()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Voice listening (Phase 2)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _voice_listen_loop(self, guild_id: int):
|
||||||
|
"""Periodically check for completed utterances and process them."""
|
||||||
|
receiver = self._voice_receivers.get(guild_id)
|
||||||
|
if not receiver:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
while receiver._running:
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
completed = receiver.check_silence()
|
||||||
|
for user_id, pcm_data in completed:
|
||||||
|
if not self._is_allowed_user(str(user_id)):
|
||||||
|
continue
|
||||||
|
await self._process_voice_input(guild_id, user_id, pcm_data)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Voice listen loop error: %s", e, exc_info=True)
|
||||||
|
|
||||||
|
async def _process_voice_input(self, guild_id: int, user_id: int, pcm_data: bytes):
|
||||||
|
"""Convert PCM -> WAV -> STT -> callback."""
|
||||||
|
from tools.voice_mode import is_whisper_hallucination
|
||||||
|
|
||||||
|
wav_path = tempfile.mktemp(suffix=".wav", prefix="vc_listen_")
|
||||||
|
try:
|
||||||
|
await asyncio.to_thread(VoiceReceiver.pcm_to_wav, pcm_data, wav_path)
|
||||||
|
|
||||||
|
from tools.transcription_tools import transcribe_audio
|
||||||
|
result = await asyncio.to_thread(transcribe_audio, wav_path)
|
||||||
|
|
||||||
|
if not result.get("success"):
|
||||||
|
return
|
||||||
|
transcript = result.get("transcript", "").strip()
|
||||||
|
if not transcript or is_whisper_hallucination(transcript):
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Voice input from user %d: %s", user_id, transcript[:100])
|
||||||
|
|
||||||
|
if self._voice_input_callback:
|
||||||
|
await self._voice_input_callback(
|
||||||
|
guild_id=guild_id,
|
||||||
|
user_id=user_id,
|
||||||
|
transcript=transcript,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Voice input processing failed: %s", e, exc_info=True)
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
os.unlink(wav_path)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _is_allowed_user(self, user_id: str) -> bool:
|
||||||
|
"""Check if user is in DISCORD_ALLOWED_USERS."""
|
||||||
|
if not self._allowed_user_ids:
|
||||||
|
return True
|
||||||
|
return user_id in self._allowed_user_ids
|
||||||
|
|
||||||
async def send_image_file(
|
async def send_image_file(
|
||||||
self,
|
self,
|
||||||
chat_id: str,
|
chat_id: str,
|
||||||
|
|
|
||||||
|
|
@ -1608,6 +1608,8 @@ class GatewayRunner:
|
||||||
(voice_mode == "all")
|
(voice_mode == "all")
|
||||||
or (voice_mode == "voice_only" and is_voice_input)
|
or (voice_mode == "voice_only" and is_voice_input)
|
||||||
)
|
)
|
||||||
|
logger.info("Voice reply check: chat_id=%s, voice_mode=%s, is_voice=%s, should_reply=%s, has_response=%s",
|
||||||
|
chat_id, voice_mode, is_voice_input, should_voice_reply, bool(response))
|
||||||
if should_voice_reply and response and not response.startswith("Error:"):
|
if should_voice_reply and response and not response.startswith("Error:"):
|
||||||
# Skip if agent already called TTS tool (avoid double voice)
|
# Skip if agent already called TTS tool (avoid double voice)
|
||||||
has_agent_tts = any(
|
has_agent_tts = any(
|
||||||
|
|
@ -1618,6 +1620,7 @@ class GatewayRunner:
|
||||||
)
|
)
|
||||||
for msg in agent_messages
|
for msg in agent_messages
|
||||||
)
|
)
|
||||||
|
logger.info("Voice reply: has_agent_tts=%s, calling _send_voice_reply", has_agent_tts)
|
||||||
if not has_agent_tts:
|
if not has_agent_tts:
|
||||||
await self._send_voice_reply(event, response)
|
await self._send_voice_reply(event, response)
|
||||||
|
|
||||||
|
|
@ -2201,9 +2204,12 @@ class GatewayRunner:
|
||||||
adapter._voice_text_channels[guild_id] = int(event.source.chat_id)
|
adapter._voice_text_channels[guild_id] = int(event.source.chat_id)
|
||||||
self._voice_mode[event.source.chat_id] = "all"
|
self._voice_mode[event.source.chat_id] = "all"
|
||||||
self._save_voice_modes()
|
self._save_voice_modes()
|
||||||
|
# Wire voice input callback so the adapter can deliver transcripts
|
||||||
|
if hasattr(adapter, "_voice_input_callback"):
|
||||||
|
adapter._voice_input_callback = self._handle_voice_channel_input
|
||||||
return (
|
return (
|
||||||
f"Joined voice channel **{voice_channel.name}**.\n"
|
f"Joined voice channel **{voice_channel.name}**.\n"
|
||||||
f"I'll speak my replies here. Use /voice leave to disconnect."
|
f"I'll speak my replies and listen to you. Use /voice leave to disconnect."
|
||||||
)
|
)
|
||||||
return "Failed to join voice channel. Check bot permissions (Connect + Speak)."
|
return "Failed to join voice channel. Check bot permissions (Connect + Speak)."
|
||||||
|
|
||||||
|
|
@ -2223,6 +2229,49 @@ class GatewayRunner:
|
||||||
self._save_voice_modes()
|
self._save_voice_modes()
|
||||||
return "Left voice channel."
|
return "Left voice channel."
|
||||||
|
|
||||||
|
async def _handle_voice_channel_input(
|
||||||
|
self, guild_id: int, user_id: int, transcript: str
|
||||||
|
):
|
||||||
|
"""Handle transcribed voice from a user in a voice channel.
|
||||||
|
|
||||||
|
Creates a synthetic MessageEvent and processes it through the
|
||||||
|
adapter's full message pipeline (session, typing, agent, TTS reply).
|
||||||
|
"""
|
||||||
|
adapter = self.adapters.get(Platform.DISCORD)
|
||||||
|
if not adapter:
|
||||||
|
return
|
||||||
|
|
||||||
|
text_ch_id = adapter._voice_text_channels.get(guild_id)
|
||||||
|
if not text_ch_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Show transcript in text channel
|
||||||
|
try:
|
||||||
|
channel = adapter._client.get_channel(text_ch_id)
|
||||||
|
if channel:
|
||||||
|
await channel.send(f"**[Voice]** <@{user_id}>: {transcript}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Build a synthetic MessageEvent and feed through the normal pipeline
|
||||||
|
source = SessionSource(
|
||||||
|
platform=Platform.DISCORD,
|
||||||
|
chat_id=str(text_ch_id),
|
||||||
|
user_id=str(user_id),
|
||||||
|
user_name=str(user_id),
|
||||||
|
)
|
||||||
|
# Use SimpleNamespace as raw_message so _get_guild_id() can extract
|
||||||
|
# guild_id and _send_voice_reply() plays audio in the voice channel.
|
||||||
|
from types import SimpleNamespace
|
||||||
|
event = MessageEvent(
|
||||||
|
source=source,
|
||||||
|
text=transcript,
|
||||||
|
message_type=MessageType.VOICE,
|
||||||
|
raw_message=SimpleNamespace(guild_id=guild_id, guild=None),
|
||||||
|
)
|
||||||
|
|
||||||
|
await adapter.handle_message(event)
|
||||||
|
|
||||||
async def _send_voice_reply(self, event: MessageEvent, text: str) -> None:
|
async def _send_voice_reply(self, event: MessageEvent, text: str) -> None:
|
||||||
"""Generate TTS audio and send as a voice message before the text reply."""
|
"""Generate TTS audio and send as a voice message before the text reply."""
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue