ege-skill/ege-checker/recognition.py

210 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
recognition.py — модуль распознавания аудиофайла с ответами ученика ЕГЭ (аудирование, английский язык).
Зависимости:
pip install faster-whisper
"""
from __future__ import annotations
import re
import logging
from dataclasses import dataclass, field
from pathlib import Path
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Константы
# ---------------------------------------------------------------------------
# Модели faster-whisper по убыванию скорости / возрастанию качества:
# tiny, base, small, medium, large-v2, large-v3
DEFAULT_MODEL = "medium"
# Подсказка для Whisper — описывает формат ответов ученика.
# Критически важна для правильного распознавания "one/two/three" как цифр
# и "A equals 3" как ответа на задание 1.
WHISPER_PROMPT = (
"Student answers to EGE English listening exam. "
"Task one matching: speaker A answer three, speaker B answer one, "
"speaker C answer five, speaker D answer seven, speaker E answer two, speaker F answer four. "
"Tasks two through nine True False Not Stated: "
"task two true, task three false, task four not stated. "
"Tasks ten through eighteen multiple choice one two or three: "
"task ten two, task eleven one, task twelve three."
)
# ---------------------------------------------------------------------------
# Структуры данных
# ---------------------------------------------------------------------------
@dataclass
class TranscriptResult:
"""Результат транскрипции аудиофайла."""
text: str # Полный текст транскрипта
language: str # Определённый язык ("en")
duration_seconds: float # Длительность аудио в секундах
segments: list[dict] = field(default_factory=list) # Детальные сегменты с таймкодами
model_used: str = DEFAULT_MODEL
# ---------------------------------------------------------------------------
# Транскрипция
# ---------------------------------------------------------------------------
def transcribe(
audio_path: str | Path,
model_size: str = DEFAULT_MODEL,
device: str = "auto",
compute_type: str = "auto",
language: str = "en",
beam_size: int = 5,
) -> TranscriptResult:
"""
Транскрибирует аудиофайл с ответами ученика.
Args:
audio_path: Путь к аудиофайлу (MP3, WAV, M4A, OGG, WEBM, FLAC).
model_size: Размер модели Whisper: tiny/base/small/medium/large-v2/large-v3.
medium — хороший баланс скорость/качество для ЕГЭ.
large-v3 — максимальное качество, медленнее.
device: "auto" | "cpu" | "cuda". "auto" выберет GPU если доступен.
compute_type: "auto" | "int8" | "float16" | "float32".
"auto" подберёт оптимальный тип для устройства.
language: Язык аудио. "en" для ответов на английском.
beam_size: Ширина луча beam search. 5 — стандарт, выше = точнее но медленнее.
Returns:
TranscriptResult с текстом, языком, длительностью и сегментами.
Raises:
FileNotFoundError: Если аудиофайл не найден.
RuntimeError: Если faster-whisper не установлен.
"""
try:
from faster_whisper import WhisperModel
except ImportError:
raise RuntimeError(
"faster-whisper не установлен. Установите: pip install faster-whisper"
)
audio_path = Path(audio_path)
if not audio_path.exists():
raise FileNotFoundError(f"Аудиофайл не найден: {audio_path}")
# Автовыбор устройства и типа вычислений
resolved_device, resolved_compute = _resolve_device(device, compute_type)
logger.info(
"Загрузка модели %s на %s (%s)...",
model_size, resolved_device, resolved_compute
)
model = WhisperModel(
model_size,
device=resolved_device,
compute_type=resolved_compute,
)
logger.info("Транскрибирую: %s", audio_path.name)
segments_gen, info = model.transcribe(
str(audio_path),
language=language,
beam_size=beam_size,
initial_prompt=WHISPER_PROMPT,
word_timestamps=False,
vad_filter=True, # Фильтрация тишины — полезно для записей с паузами
vad_parameters={
"min_silence_duration_ms": 500, # Паузы >0.5с считаются тишиной
"speech_pad_ms": 200,
},
)
# Материализуем генератор сегментов
segments = []
full_text_parts = []
for seg in segments_gen:
segments.append({
"start": round(seg.start, 2),
"end": round(seg.end, 2),
"text": seg.text.strip(),
})
full_text_parts.append(seg.text.strip())
full_text = " ".join(full_text_parts)
logger.info(
"Транскрипция завершена. Длительность: %.1f сек, слов ~%d",
info.duration, len(full_text.split())
)
return TranscriptResult(
text=full_text,
language=info.language,
duration_seconds=round(info.duration, 1),
segments=segments,
model_used=model_size,
)
def _resolve_device(device: str, compute_type: str) -> tuple[str, str]:
"""Определяет оптимальное устройство и тип вычислений."""
if device != "auto" and compute_type != "auto":
return device, compute_type
# Проверяем наличие CUDA
try:
from torch import cuda
has_cuda = cuda.is_available()
except ImportError:
has_cuda = False
if device == "auto":
device = "cuda" if has_cuda else "cpu"
if compute_type == "auto":
if device == "cuda":
compute_type = "float16" # GPU: float16 быстрее и точнее чем int8
else:
compute_type = "int8" # CPU: int8 значительно быстрее float32
return device, compute_type
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
if __name__ == "__main__":
import argparse
import sys
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
parser = argparse.ArgumentParser(
description="Распознавание аудиоответов ЕГЭ (аудирование, английский язык)"
)
parser.add_argument("audio", help="Путь к аудиофайлу")
parser.add_argument(
"--model", default=DEFAULT_MODEL,
choices=["tiny", "base", "small", "medium", "large-v2", "large-v3"],
help=f"Размер модели Whisper (по умолчанию: {DEFAULT_MODEL})"
)
parser.add_argument(
"--device", default="auto",
choices=["auto", "cpu", "cuda"],
help="Устройство для инференса (по умолчанию: auto)"
)
args = parser.parse_args()
try:
result = transcribe(args.audio, model_size=args.model, device=args.device)
print(result.text)
except (FileNotFoundError, RuntimeError) as e:
print(f"Ошибка: {e}", file=sys.stderr)
sys.exit(1)