"""
Whisper Transcriber Module.
===========================

Handles audio transcription using faster-whisper.
Runs IN-PROCESS (no subprocess) to ensure stability on all systems.
"""
import logging
import os
from typing import Optional

import numpy as np

from src.core.config import ConfigManager
from src.core.paths import get_models_path

# Optional dependency: torch is only needed to release CUDA cache on unload.
try:
    import torch
except ImportError:
    torch = None

# Import directly - valid since we are now running in the full environment
class WhisperTranscriber:
    """Manages the faster-whisper model and the transcription process."""

    def __init__(self):
        """Set up configuration access and an empty model slot."""
        self.config = ConfigManager()
        # Lazily-loaded faster-whisper model, plus the settings it was built
        # with so load_model() can skip a reload when nothing has changed.
        self.model = None
        self.current_model_size = None
        self.current_compute_device = None
        self.current_compute_type = None
def load_model(self):
|
|
"""
|
|
Loads the model specified in config.
|
|
Safe to call multiple times (checks if reload needed).
|
|
"""
|
|
size = self.config.get("model_size")
|
|
device = self.config.get("compute_device")
|
|
compute = self.config.get("compute_type")
|
|
|
|
# Check if already loaded
|
|
if (self.model and
|
|
self.current_model_size == size and
|
|
self.current_compute_device == device and
|
|
self.current_compute_type == compute):
|
|
return
|
|
|
|
logging.info(f"Loading Model: {size} on {device} ({compute})...")
|
|
|
|
try:
|
|
# Construct path to local model for offline support
|
|
new_path = get_models_path() / f"faster-whisper-{size}"
|
|
model_input = str(new_path) if new_path.exists() else size
|
|
|
|
# Force offline if path exists to avoid HF errors
|
|
local_only = new_path.exists()
|
|
|
|
try:
|
|
from faster_whisper import WhisperModel
|
|
self.model = WhisperModel(
|
|
model_input,
|
|
device=device,
|
|
compute_type=compute,
|
|
download_root=str(get_models_path()),
|
|
local_files_only=local_only
|
|
)
|
|
except Exception as load_err:
|
|
# CRITICAL FALLBACK: If CUDA/cublas fails (AMD/Intel users), fallback to CPU
|
|
err_str = str(load_err).lower()
|
|
if "cublas" in err_str or "cudnn" in err_str or "library" in err_str or "device" in err_str:
|
|
logging.warning(f"CUDA Init Failed ({load_err}). Falling back to CPU...")
|
|
self.config.set("compute_device", "cpu") # Update config for persistence/UI
|
|
self.current_compute_device = "cpu"
|
|
|
|
self.model = WhisperModel(
|
|
model_input,
|
|
device="cpu",
|
|
compute_type="int8", # CPU usually handles int8 well with newer extensions, or standard
|
|
download_root=str(get_models_path()),
|
|
local_files_only=local_only
|
|
)
|
|
else:
|
|
raise load_err
|
|
|
|
self.current_model_size = size
|
|
self.current_compute_device = device
|
|
self.current_compute_type = compute
|
|
logging.info("Model loaded successfully.")
|
|
|
|
except Exception as e:
|
|
logging.error(f"Failed to load model: {e}")
|
|
self.model = None
|
|
|
|
# Auto-Repair: Detect vocabulary/corrupt errors
|
|
err_str = str(e).lower()
|
|
if "vocabulary" in err_str or "tokenizer" in err_str or "config.json" in err_str:
|
|
# ... existing auto-repair logic ...
|
|
logging.warning("Corrupt model detected on load. Attempting to delete and reset...")
|
|
try:
|
|
import shutil
|
|
# Differentiate between simple path and HF path
|
|
new_path = get_models_path() / f"faster-whisper-{size}"
|
|
if new_path.exists():
|
|
shutil.rmtree(new_path)
|
|
logging.info(f"Deleted corrupt model at {new_path}")
|
|
else:
|
|
# Try legacy HF path
|
|
hf_path = get_models_path() / f"models--Systran--faster-whisper-{size}"
|
|
if hf_path.exists():
|
|
shutil.rmtree(hf_path)
|
|
logging.info(f"Deleted corrupt HF model at {hf_path}")
|
|
|
|
# Notify UI to refresh state (will show 'Download' button now)
|
|
# We can't reach bridge easily here without passing it in,
|
|
# but the UI polls or listens to logs.
|
|
# The user will simply see "Model Missing" in settings after this.
|
|
except Exception as del_err:
|
|
logging.error(f"Failed to delete corrupt model: {del_err}")
|
|
|
|
def transcribe(self, audio_data, is_file: bool = False, task: Optional[str] = None) -> str:
|
|
"""
|
|
Transcribe audio data.
|
|
"""
|
|
logging.info(f"Starting transcription... (is_file={is_file}, task={task})")
|
|
|
|
# Ensure model is loaded
|
|
if not self.model:
|
|
self.load_model()
|
|
if not self.model:
|
|
return "Error: Model failed to load. Please check Settings -> Model Info."
|
|
|
|
try:
|
|
# Config
|
|
beam_size = int(self.config.get("beam_size"))
|
|
best_of = int(self.config.get("best_of"))
|
|
vad = False if is_file else self.config.get("vad_filter")
|
|
language = self.config.get("language")
|
|
|
|
# Use task override if provided, otherwise config
|
|
# Ensure safe string and lowercase ("transcribe" vs "Transcribe")
|
|
raw_task = task if task else self.config.get("task")
|
|
final_task = str(raw_task).strip().lower() if raw_task else "transcribe"
|
|
|
|
# Sanity check for valid Whisper tasks
|
|
if final_task not in ["transcribe", "translate"]:
|
|
logging.warning(f"Invalid task '{final_task}' detected. Defaulting to 'transcribe'.")
|
|
final_task = "transcribe"
|
|
|
|
# Language handling
|
|
final_language = language if language != "auto" else None
|
|
|
|
# Anti-Hallucination: Force condition_on_previous_text=False for translation
|
|
condition_prev = self.config.get("condition_on_previous_text")
|
|
|
|
# Helper options for Translation Stability
|
|
initial_prompt = self.config.get("initial_prompt")
|
|
|
|
if final_task == "translate":
|
|
condition_prev = False
|
|
# Force beam search if user has set it to greedy (1)
|
|
# Translation requires more search breadth to find the English mapping
|
|
if beam_size < 5:
|
|
logging.info("Forcing beam_size=5 for Translation task.")
|
|
beam_size = 5
|
|
|
|
# Inject guidance prompt if none exists
|
|
if not initial_prompt:
|
|
initial_prompt = "Translate this to English."
|
|
|
|
logging.info(f"Model Dispatch: Task='{final_task}', Language='{final_language}', ConditionPrev={condition_prev}, Beam={beam_size}")
|
|
|
|
# Build arguments dynamically to avoid passing None if that's the issue
|
|
transcribe_opts = {
|
|
"beam_size": beam_size,
|
|
"best_of": best_of,
|
|
"vad_filter": vad,
|
|
"task": final_task,
|
|
"vad_parameters": dict(min_silence_duration_ms=500),
|
|
"condition_on_previous_text": condition_prev,
|
|
"without_timestamps": True
|
|
}
|
|
|
|
if initial_prompt:
|
|
transcribe_opts["initial_prompt"] = initial_prompt
|
|
|
|
# Only add language if it's explicitly set (not None/Auto)
|
|
# This avoids potentially confusing the model with explicit None
|
|
if final_language:
|
|
transcribe_opts["language"] = final_language
|
|
|
|
# Transcribe
|
|
segments, info = self.model.transcribe(audio_data, **transcribe_opts)
|
|
|
|
# Aggregate text
|
|
text_result = ""
|
|
for segment in segments:
|
|
text_result += segment.text + " "
|
|
|
|
text_result = text_result.strip()
|
|
|
|
# Low VRAM Mode: Unload Whisper Model immediately
|
|
if self.config.get("unload_models_after_use"):
|
|
self.unload_model()
|
|
|
|
logging.info(f"Final Transcription Output: '{text_result}'")
|
|
return text_result
|
|
|
|
except Exception as e:
|
|
logging.error(f"Transcription failed: {e}")
|
|
return f"Error: {str(e)}"
|
|
|
|
def model_exists(self, size: str) -> bool:
|
|
"""Checks if a model size is already downloaded."""
|
|
new_path = get_models_path() / f"faster-whisper-{size}"
|
|
if new_path.exists():
|
|
# Strict check
|
|
required = ["config.json", "model.bin", "vocabulary.json"]
|
|
if all((new_path / f).exists() for f in required):
|
|
return True
|
|
|
|
# Legacy HF cache check
|
|
folder_name = f"models--Systran--faster-whisper-{size}"
|
|
path = get_models_path() / folder_name / "snapshots"
|
|
if path.exists() and any(path.iterdir()):
|
|
return True
|
|
|
|
return False
|
|
|
|
def unload_model(self):
|
|
"""
|
|
Unloads model to free memory.
|
|
"""
|
|
if self.model:
|
|
del self.model
|
|
|
|
self.model = None
|
|
self.current_model_size = None
|
|
|
|
# Force garbage collection
|
|
import gc
|
|
gc.collect()
|
|
if torch.cuda.is_available():
|
|
torch.cuda.empty_cache()
|
|
|
|
logging.info("Whisper Model unloaded (Low VRAM Mode).")
|