210 lines
8.2 KiB
Python
210 lines
8.2 KiB
Python
"""
|
||
recognition.py — модуль распознавания аудиофайла с ответами ученика ЕГЭ (аудирование, английский язык).
|
||
|
||
Зависимости:
|
||
pip install faster-whisper
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import re
|
||
import logging
|
||
from dataclasses import dataclass, field
|
||
from pathlib import Path
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Константы
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# Модели faster-whisper по убыванию скорости / возрастанию качества:
|
||
# tiny, base, small, medium, large-v2, large-v3
|
||
DEFAULT_MODEL = "medium"
|
||
|
||
# Подсказка для Whisper — описывает формат ответов ученика.
|
||
# Критически важна для правильного распознавания "one/two/three" как цифр
|
||
# и "A equals 3" как ответа на задание 1.
|
||
WHISPER_PROMPT = (
|
||
"Student answers to EGE English listening exam. "
|
||
"Task one matching: speaker A answer three, speaker B answer one, "
|
||
"speaker C answer five, speaker D answer seven, speaker E answer two, speaker F answer four. "
|
||
"Tasks two through nine True False Not Stated: "
|
||
"task two true, task three false, task four not stated. "
|
||
"Tasks ten through eighteen multiple choice one two or three: "
|
||
"task ten two, task eleven one, task twelve three."
|
||
)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Структуры данных
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@dataclass
|
||
class TranscriptResult:
|
||
"""Результат транскрипции аудиофайла."""
|
||
text: str # Полный текст транскрипта
|
||
language: str # Определённый язык ("en")
|
||
duration_seconds: float # Длительность аудио в секундах
|
||
segments: list[dict] = field(default_factory=list) # Детальные сегменты с таймкодами
|
||
model_used: str = DEFAULT_MODEL
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Транскрипция
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def transcribe(
|
||
audio_path: str | Path,
|
||
model_size: str = DEFAULT_MODEL,
|
||
device: str = "auto",
|
||
compute_type: str = "auto",
|
||
language: str = "en",
|
||
beam_size: int = 5,
|
||
) -> TranscriptResult:
|
||
"""
|
||
Транскрибирует аудиофайл с ответами ученика.
|
||
|
||
Args:
|
||
audio_path: Путь к аудиофайлу (MP3, WAV, M4A, OGG, WEBM, FLAC).
|
||
model_size: Размер модели Whisper: tiny/base/small/medium/large-v2/large-v3.
|
||
medium — хороший баланс скорость/качество для ЕГЭ.
|
||
large-v3 — максимальное качество, медленнее.
|
||
device: "auto" | "cpu" | "cuda". "auto" выберет GPU если доступен.
|
||
compute_type: "auto" | "int8" | "float16" | "float32".
|
||
"auto" подберёт оптимальный тип для устройства.
|
||
language: Язык аудио. "en" для ответов на английском.
|
||
beam_size: Ширина луча beam search. 5 — стандарт, выше = точнее но медленнее.
|
||
|
||
Returns:
|
||
TranscriptResult с текстом, языком, длительностью и сегментами.
|
||
|
||
Raises:
|
||
FileNotFoundError: Если аудиофайл не найден.
|
||
RuntimeError: Если faster-whisper не установлен.
|
||
"""
|
||
try:
|
||
from faster_whisper import WhisperModel
|
||
except ImportError:
|
||
raise RuntimeError(
|
||
"faster-whisper не установлен. Установите: pip install faster-whisper"
|
||
)
|
||
|
||
audio_path = Path(audio_path)
|
||
if not audio_path.exists():
|
||
raise FileNotFoundError(f"Аудиофайл не найден: {audio_path}")
|
||
|
||
# Автовыбор устройства и типа вычислений
|
||
resolved_device, resolved_compute = _resolve_device(device, compute_type)
|
||
|
||
logger.info(
|
||
"Загрузка модели %s на %s (%s)...",
|
||
model_size, resolved_device, resolved_compute
|
||
)
|
||
|
||
model = WhisperModel(
|
||
model_size,
|
||
device=resolved_device,
|
||
compute_type=resolved_compute,
|
||
)
|
||
|
||
logger.info("Транскрибирую: %s", audio_path.name)
|
||
|
||
segments_gen, info = model.transcribe(
|
||
str(audio_path),
|
||
language=language,
|
||
beam_size=beam_size,
|
||
initial_prompt=WHISPER_PROMPT,
|
||
word_timestamps=False,
|
||
vad_filter=True, # Фильтрация тишины — полезно для записей с паузами
|
||
vad_parameters={
|
||
"min_silence_duration_ms": 500, # Паузы >0.5с считаются тишиной
|
||
"speech_pad_ms": 200,
|
||
},
|
||
)
|
||
|
||
# Материализуем генератор сегментов
|
||
segments = []
|
||
full_text_parts = []
|
||
|
||
for seg in segments_gen:
|
||
segments.append({
|
||
"start": round(seg.start, 2),
|
||
"end": round(seg.end, 2),
|
||
"text": seg.text.strip(),
|
||
})
|
||
full_text_parts.append(seg.text.strip())
|
||
|
||
full_text = " ".join(full_text_parts)
|
||
|
||
logger.info(
|
||
"Транскрипция завершена. Длительность: %.1f сек, слов ~%d",
|
||
info.duration, len(full_text.split())
|
||
)
|
||
|
||
return TranscriptResult(
|
||
text=full_text,
|
||
language=info.language,
|
||
duration_seconds=round(info.duration, 1),
|
||
segments=segments,
|
||
model_used=model_size,
|
||
)
|
||
|
||
|
||
def _resolve_device(device: str, compute_type: str) -> tuple[str, str]:
|
||
"""Определяет оптимальное устройство и тип вычислений."""
|
||
if device != "auto" and compute_type != "auto":
|
||
return device, compute_type
|
||
|
||
# Проверяем наличие CUDA
|
||
try:
|
||
from torch import cuda
|
||
has_cuda = cuda.is_available()
|
||
except ImportError:
|
||
has_cuda = False
|
||
|
||
if device == "auto":
|
||
device = "cuda" if has_cuda else "cpu"
|
||
|
||
if compute_type == "auto":
|
||
if device == "cuda":
|
||
compute_type = "float16" # GPU: float16 быстрее и точнее чем int8
|
||
else:
|
||
compute_type = "int8" # CPU: int8 значительно быстрее float32
|
||
|
||
return device, compute_type
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# CLI
|
||
# ---------------------------------------------------------------------------
|
||
|
||
if __name__ == "__main__":
|
||
import argparse
|
||
import sys
|
||
|
||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||
|
||
parser = argparse.ArgumentParser(
|
||
description="Распознавание аудиоответов ЕГЭ (аудирование, английский язык)"
|
||
)
|
||
parser.add_argument("audio", help="Путь к аудиофайлу")
|
||
parser.add_argument(
|
||
"--model", default=DEFAULT_MODEL,
|
||
choices=["tiny", "base", "small", "medium", "large-v2", "large-v3"],
|
||
help=f"Размер модели Whisper (по умолчанию: {DEFAULT_MODEL})"
|
||
)
|
||
parser.add_argument(
|
||
"--device", default="auto",
|
||
choices=["auto", "cpu", "cuda"],
|
||
help="Устройство для инференса (по умолчанию: auto)"
|
||
)
|
||
|
||
args = parser.parse_args()
|
||
|
||
try:
|
||
result = transcribe(args.audio, model_size=args.model, device=args.device)
|
||
print(result.text)
|
||
except (FileNotFoundError, RuntimeError) as e:
|
||
print(f"Ошибка: {e}", file=sys.stderr)
|
||
sys.exit(1)
|