surfaces/bot-examples/asr.py

"""ASR via OpenAI-compatible STT server (GigaAM, Whisper, etc).
Default: GigaAM (Russian-optimized, handles long-form natively via pyannote).
Fallback: Whisper (multilingual, needs client-side chunking for long audio).
Truncation detection and chunked retry only applies to Whisper-based backends.
GigaAM handles long-form audio server-side via pyannote segmentation.
"""
import asyncio
import logging
import os
import re
import tempfile
from pathlib import Path

import httpx

logger = logging.getLogger(__name__)

MAX_RETRIES = 3
TIMEOUT = 300.0
# If Whisper covers less than this fraction of the audio, retry with chunks
COVERAGE_THRESHOLD = 0.85


def _is_whisper(stt_url: str) -> bool:
    """Heuristic: URL points to a Whisper-based server."""
    return "whisper" in stt_url.lower()


async def _get_duration(audio_path: str) -> float | None:
    """Get audio duration in seconds via ffprobe."""
    try:
        proc = await asyncio.create_subprocess_exec(
            "ffprobe", "-v", "quiet", "-show_entries", "format=duration",
            "-of", "default=noprint_wrappers=1:nokey=1", audio_path,
            stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.DEVNULL,
        )
        stdout, _ = await proc.communicate()
        return float(stdout.decode().strip())
    except Exception:
        return None


async def _find_split_points(audio_path: str, target_chunk: float = 30.0) -> list[float]:
    """Find silence gaps for splitting audio into ~target_chunk second pieces."""
    try:
        proc = await asyncio.create_subprocess_exec(
            "ffmpeg", "-i", audio_path,
            "-af", "silencedetect=noise=-35dB:d=0.4",
            "-f", "null", "-",
            stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.PIPE,
        )
        _, stderr = await proc.communicate()
        output = stderr.decode("utf-8", errors="replace")

        # silencedetect logs "silence_end: <t> | silence_duration: <d>" lines
        # to stderr; the end-of-silence timestamps are the split candidates.
        silences = []
        for m in re.finditer(r"silence_end:\s*([\d.]+)", output):
            silences.append(float(m.group(1)))
        if not silences:
            return []

        duration = await _get_duration(audio_path) or silences[-1] + 10
        # Walk the file in target_chunk steps, snapping each split to the
        # nearest silence; drop candidates within 10 s of the previous split
        # and leave the final 10 s to the last chunk.
        splits = []
        target = target_chunk
        while target < duration - 10:
            best = min(silences, key=lambda s: abs(s - target))
            if not splits or best > splits[-1] + 10:
                splits.append(best)
            target += target_chunk
        return splits
    except Exception:
        return []


async def _stt_request(
    url: str, audio_path: str, language: str | None = None,
    response_format: str = "json",
) -> dict:
    """Single STT API call. Returns the JSON response dict."""
    last_exc = None
    for attempt in range(MAX_RETRIES):
        try:
            async with httpx.AsyncClient(timeout=TIMEOUT) as client:
                with open(audio_path, "rb") as f:
                    data = {"response_format": response_format}
                    if _is_whisper(url):
                        # The Whisper backend is pinned to a specific model.
                        data["model"] = "Systran/faster-whisper-large-v3"
                    if language:
                        data["language"] = language
                    files = {"file": (Path(audio_path).name, f, "application/octet-stream")}
                    resp = await client.post(url, data=data, files=files)
                    if resp.status_code != 200:
                        raise RuntimeError(
                            f"STT API returned {resp.status_code}: {resp.text[:200]}"
                        )
                    return resp.json()
        except (httpx.ConnectError, httpx.TimeoutException) as e:
            # Transient network failure; retry up to MAX_RETRIES times.
            last_exc = e
            if attempt < MAX_RETRIES - 1:
                logger.warning(
                    "STT connection error (attempt %d/%d): %s",
                    attempt + 1, MAX_RETRIES, e,
                )
                continue
        except RuntimeError:
            raise
        except Exception as e:
            raise RuntimeError(f"STT transcription failed: {e}") from e
    raise RuntimeError(f"STT unavailable after {MAX_RETRIES} attempts: {last_exc}")


async def _transcribe_chunked(
    url: str, audio_path: str, split_points: list[float],
    language: str | None = None,
) -> str:
    """Split audio at silence boundaries and transcribe each chunk."""
    tmpdir = tempfile.mkdtemp(prefix="asr_chunk_")
    chunks = []
    try:
        # Cut consecutive [boundary, next boundary) slices without re-encoding.
        boundaries = [0.0] + split_points
        for i, start in enumerate(boundaries):
            chunk_path = os.path.join(tmpdir, f"chunk{i}.ogg")
            args = ["ffmpeg", "-y", "-i", audio_path, "-ss", str(start)]
            if i < len(split_points):
                args += ["-t", str(split_points[i] - start)]
            args += ["-c", "copy", chunk_path]
            proc = await asyncio.create_subprocess_exec(
                *args,
                stdout=asyncio.subprocess.DEVNULL,
                stderr=asyncio.subprocess.DEVNULL,
            )
            await proc.wait()
            chunks.append(chunk_path)

        texts = []
        for chunk in chunks:
            # Skip chunks ffmpeg failed to produce or that are too small to
            # contain real audio.
            if not os.path.exists(chunk) or os.path.getsize(chunk) < 100:
                continue
            result = await _stt_request(url, chunk, language=language)
            text = result.get("text", "").strip()
            if text:
                texts.append(text)
        return " ".join(texts)
    finally:
        for f in chunks:
            try:
                os.unlink(f)
            except OSError:
                pass
        try:
            os.rmdir(tmpdir)
        except OSError:
            pass

HYBRID_THRESHOLD = 30.0 # seconds — use Whisper for short, GigaAM for long


async def transcribe(
    audio_path: str,
    stt_url: str,
    language: str | None = None,
    whisper_url: str | None = None,
) -> tuple[str, str]:
    """Transcribe an audio file via an OpenAI-compatible STT server.

    Hybrid mode: if both stt_url and whisper_url are provided, uses Whisper
    for audio shorter than HYBRID_THRESHOLD (30 s) and the primary STT for
    longer audio.

    Returns:
        (transcribed_text, engine_tag) — engine_tag is "w" for Whisper,
        otherwise the first letter of the server host (e.g. "g" for a
        GigaAM host).

    Raises:
        RuntimeError: If transcription fails after retries.
    """
    # Hybrid: pick the engine based on duration.
    chosen_url = stt_url
    if whisper_url and whisper_url != stt_url:
        duration = await _get_duration(audio_path)
        if duration is not None and duration < HYBRID_THRESHOLD:
            chosen_url = whisper_url

    url = f"{chosen_url.rstrip('/')}/v1/audio/transcriptions"
    whisper = _is_whisper(chosen_url)
    engine_tag = "w" if whisper else chosen_url.split("//")[-1][0]

    # For Whisper: use verbose_json to detect truncation via segment timestamps.
    # For others: plain json is enough.
    fmt = "verbose_json" if whisper else "json"
    result = await _stt_request(url, audio_path, language=language, response_format=fmt)
    text = result.get("text", "").strip()
    if not text:
        raise RuntimeError("STT returned empty transcription")

    # Whisper truncation detection — only for Whisper backends.
    if whisper:
        file_duration = await _get_duration(audio_path)
        segments = result.get("segments", [])
        if file_duration and segments and file_duration > 30:
            last_segment_end = segments[-1].get("end", 0)
            coverage = last_segment_end / file_duration
            if coverage < COVERAGE_THRESHOLD:
                logger.warning(
                    "Whisper truncated %s: covered %.0f/%.0fs (%.0f%%), retrying with chunks",
                    Path(audio_path).name, last_segment_end, file_duration, coverage * 100,
                )
                split_points = await _find_split_points(audio_path, target_chunk=30.0)
                if not split_points:
                    # No usable silences: fall back to evenly spaced splits.
                    n_chunks = max(2, int(file_duration / 30))
                    split_points = [file_duration * i / n_chunks for i in range(1, n_chunks)]
                chunked_text = await _transcribe_chunked(
                    url, audio_path, split_points, language=language,
                )
                if len(chunked_text) > len(text):
                    text = chunked_text
                    logger.info(
                        "Chunked transcription recovered %d chars (was %d)",
                        len(text), len(result.get("text", "")),
                    )

    logger.info("Transcribed %s: %d chars [%s]", Path(audio_path).name, len(text), engine_tag)
    return text, engine_tag
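

# Minimal CLI smoke test (a sketch; not used by the bot). The default URL is a
# placeholder; point it at whatever OpenAI-compatible STT server you run:
#
#     python asr.py voice.ogg http://localhost:8000
if __name__ == "__main__":
    import sys

    async def _main() -> None:
        if len(sys.argv) < 2:
            sys.exit("usage: python asr.py <audio> [stt_url]")
        path = sys.argv[1]
        stt = sys.argv[2] if len(sys.argv) > 2 else "http://localhost:8000"
        text, engine = await transcribe(path, stt)
        print(f"[{engine}] {text}")

    asyncio.run(_main())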