ege-skill/ege-checker/recognition.py

206 lines
7.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
recognition.py — модуль распознавания аудиофайла с ответами ученика ЕГЭ (говорение, английский язык).
Зависимости:
pip install faster-whisper
"""
from __future__ import annotations
import re
import logging
from dataclasses import dataclass, field
from pathlib import Path
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Константы
# ---------------------------------------------------------------------------
# Модели faster-whisper по убыванию скорости / возрастанию качества:
# tiny, base, small, medium, large-v2, large-v3
DEFAULT_MODEL = "medium"
# Подсказка для Whisper — описывает формат ответов ученика.
WHISPER_PROMPT = (
"Student answers to EGE English speaking exam. "
"Task one: read aloud. "
"Task two: ask four direct questions. "
"Task three: answer five questions. "
"Task four: monologue - compare two photos. "
)
# ---------------------------------------------------------------------------
# Структуры данных
# ---------------------------------------------------------------------------
@dataclass
class TranscriptResult:
"""Результат транскрипции аудиофайла."""
text: str # Полный текст транскрипта
language: str # Определённый язык ("en")
duration_seconds: float # Длительность аудио в секундах
segments: list[dict] = field(default_factory=list) # Детальные сегменты с таймкодами
model_used: str = DEFAULT_MODEL
# ---------------------------------------------------------------------------
# Транскрипция
# ---------------------------------------------------------------------------
def transcribe(
audio_path: str | Path,
model_size: str = DEFAULT_MODEL,
device: str = "auto",
compute_type: str = "auto",
language: str = "en",
beam_size: int = 5,
) -> TranscriptResult:
"""
Транскрибирует аудиофайл с ответами ученика.
Args:
audio_path: Путь к аудиофайлу (MP3, WAV, M4A, OGG, WEBM, FLAC).
model_size: Размер модели Whisper: tiny/base/small/medium/large-v2/large-v3.
medium — хороший баланс скорость/качество для ЕГЭ.
large-v3 — максимальное качество, медленнее.
device: "auto" | "cpu" | "cuda". "auto" выберет GPU если доступен.
compute_type: "auto" | "int8" | "float16" | "float32".
"auto" подберёт оптимальный тип для устройства.
language: Язык аудио. "en" для ответов на английском.
beam_size: Ширина луча beam search. 5 — стандарт, выше = точнее но медленнее.
Returns:
TranscriptResult с текстом, языком, длительностью и сегментами.
Raises:
FileNotFoundError: Если аудиофайл не найден.
RuntimeError: Если faster-whisper не установлен.
"""
try:
from faster_whisper import WhisperModel
except ImportError:
raise RuntimeError(
"faster-whisper не установлен. Установите: pip install faster-whisper"
)
audio_path = Path(audio_path)
if not audio_path.exists():
raise FileNotFoundError(f"Аудиофайл не найден: {audio_path}")
# Автовыбор устройства и типа вычислений
resolved_device, resolved_compute = _resolve_device(device, compute_type)
logger.info(
"Загрузка модели %s на %s (%s)...",
model_size, resolved_device, resolved_compute
)
model = WhisperModel(
model_size,
device=resolved_device,
compute_type=resolved_compute,
)
logger.info("Транскрибирую: %s", audio_path.name)
segments_gen, info = model.transcribe(
str(audio_path),
language=language,
beam_size=beam_size,
initial_prompt=WHISPER_PROMPT,
word_timestamps=False,
vad_filter=True, # Фильтрация тишины — полезно для записей с паузами
vad_parameters={
"min_silence_duration_ms": 500, # Паузы >0.5с считаются тишиной
"speech_pad_ms": 200,
},
)
# Материализуем генератор сегментов
segments = []
full_text_parts = []
for seg in segments_gen:
segments.append({
"start": round(seg.start, 2),
"end": round(seg.end, 2),
"text": seg.text.strip(),
})
full_text_parts.append(seg.text.strip())
full_text = " ".join(full_text_parts)
logger.info(
"Транскрипция завершена. Длительность: %.1f сек, слов ~%d",
info.duration, len(full_text.split())
)
return TranscriptResult(
text=full_text,
language=info.language,
duration_seconds=round(info.duration, 1),
segments=segments,
model_used=model_size,
)
def _resolve_device(device: str, compute_type: str) -> tuple[str, str]:
"""Определяет оптимальное устройство и тип вычислений."""
if device != "auto" and compute_type != "auto":
return device, compute_type
# Проверяем наличие CUDA
try:
from torch import cuda
has_cuda = cuda.is_available()
except ImportError:
has_cuda = False
if device == "auto":
device = "cuda" if has_cuda else "cpu"
if compute_type == "auto":
if device == "cuda":
compute_type = "float16" # GPU: float16 быстрее и точнее чем int8
else:
compute_type = "int8" # CPU: int8 значительно быстрее float32
return device, compute_type
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
if __name__ == "__main__":
import argparse
import sys
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
parser = argparse.ArgumentParser(
description="Распознавание аудиоответов ЕГЭ (говорение, английский язык)"
)
parser.add_argument("audio", help="Путь к аудиофайлу")
parser.add_argument(
"--model", default=DEFAULT_MODEL,
choices=["tiny", "base", "small", "medium", "large-v2", "large-v3"],
help=f"Размер модели Whisper (по умолчанию: {DEFAULT_MODEL})"
)
parser.add_argument(
"--device", default="auto",
choices=["auto", "cpu", "cuda"],
help="Устройство для инференса (по умолчанию: auto)"
)
args = parser.parse_args()
try:
result = transcribe(args.audio, model_size=args.model, device=args.device)
print(result.text)
except (FileNotFoundError, RuntimeError) as e:
print(f"Ошибка: {e}", file=sys.stderr)
sys.exit(1)