Add local text correction engine

This commit is contained in:
2026-01-31 01:02:24 +02:00
parent 6a98142c1d
commit 32d4e328ff
10 changed files with 601 additions and 61 deletions

View File

@@ -17,6 +17,7 @@ from src.core.paths import get_base_path
DEFAULT_SETTINGS = {
"hotkey": "f8",
"hotkey_translate": "f10",
"hotkey_correct": "f9", # New: Transcribe + Correct
"model_size": "small",
"input_device": None, # Device ID (int) or Name (str), None = Default
"save_recordings": False, # Save .wav files for debugging
@@ -49,6 +50,11 @@ DEFAULT_SETTINGS = {
"condition_on_previous_text": True,
"initial_prompt": "Mm-hmm. Okay, let's go. I speak in full sentences.", # Default: Forces punctuation
# LLM Correction
"llm_enabled": False,
"llm_mode": "Standard", # "Grammar", "Standard", "Rewrite"
"llm_model_name": "llama-3.2-1b-instruct",
# Low VRAM Mode
@@ -102,9 +108,9 @@ class ConfigManager:
except Exception as e:
logging.error(f"Failed to save settings: {e}")
def get(self, key: str) -> Any:
def get(self, key: str, default: Any = None) -> Any:
"""Get a setting value."""
return self.data.get(key, DEFAULT_SETTINGS.get(key))
return self.data.get(key, DEFAULT_SETTINGS.get(key, default))

185
src/core/llm_engine.py Normal file
View File

@@ -0,0 +1,185 @@
"""
LLM Engine Module.
==================
Handles interaction with the local Llama 3.2 1B model for transcription correction.
Uses llama-cpp-python for efficient local inference.
"""
import os
import logging
from typing import Optional
from src.core.paths import get_models_path
from src.core.config import ConfigManager
try:
from llama_cpp import Llama
except ImportError:
Llama = None
class LLMEngine:
"""
Manages the Llama model and performs text correction/rewriting.
"""
def __init__(self):
self.config = ConfigManager()
self.model = None
self.current_model_path = None
# --- Mode 1: Grammar Only (Strict) ---
self.prompt_grammar = (
"You are a text correction tool. "
"Correct the grammar/spelling. Do not change punctuation or capitalization styles. "
"Do not remove any words (including profanity). Output ONLY the result."
"\n\nExample:\nInput: 'damn it works'\nOutput: 'damn it works'"
)
# --- Mode 2: Standard (Grammar + Punctuation + Caps) ---
self.prompt_standard = (
"You are a text correction tool. "
"Standardize the grammar, punctuation, and capitalization. "
"Do not remove any words (including profanity). Output ONLY the result."
"\n\nExample:\nInput: 'damn it works'\nOutput: 'Damn it works.'"
)
# --- Mode 3: Rewrite (Tone-Aware Polish) ---
self.prompt_rewrite = (
"You are a text rewriting tool. Improve flow/clarity but keep the exact tone and vocabulary. "
"Do not remove any words (including profanity). Output ONLY the result."
"\n\nExample:\nInput: 'damn it works'\nOutput: 'Damn, it works.'"
)
def load_model(self) -> bool:
"""
Loads the LLM model if it exists.
Returns True if successful, False otherwise.
"""
if Llama is None:
logging.error("llama-cpp-python not installed.")
return False
model_name = self.config.get("llm_model_name", "llama-3.2-1b-instruct")
model_dir = get_models_path() / "llm" / model_name
model_file = model_dir / "llama-3.2-1b-instruct-q4_k_m.gguf"
if not model_file.exists():
logging.warning(f"LLM Model not found at: {model_file}")
return False
if self.model and self.current_model_path == str(model_file):
return True
try:
logging.info(f"Loading LLM from {model_file}...")
n_gpu_layers = 0
try:
import torch
if torch.cuda.is_available():
n_gpu_layers = -1
except:
pass
self.model = Llama(
model_path=str(model_file),
n_gpu_layers=n_gpu_layers,
n_ctx=2048,
verbose=False
)
self.current_model_path = str(model_file)
logging.info("LLM loaded successfully.")
return True
except Exception as e:
logging.error(f"Failed to load LLM: {e}")
self.model = None
return False
def correct_text(self, text: str, mode: str = "Standard") -> str:
"""Corrects or rewrites the provided text."""
if not text or not text.strip():
return text
if not self.model:
if not self.load_model():
return text
logging.info(f"LLM Processing ({mode}): '{text}'")
system_prompt = self.prompt_standard
if mode == "Grammar": system_prompt = self.prompt_grammar
elif mode == "Rewrite": system_prompt = self.prompt_rewrite
# PREFIX INJECTION TECHNIQUE
# We end the prompt with the start of the assistant's answer specifically phrased to force compliance.
# "Here is the processed output:" forces it into a completion mode rather than a refusal mode.
prefix_injection = "Here is the processed output:\n"
prompt = (
f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\nProcess this input:\n{text}<|eot_id|>"
f"<|start_header_id|>assistant<|end_header_id|>\n\n{prefix_injection}"
)
try:
output = self.model(
prompt,
max_tokens=512,
stop=["<|eot_id|>"],
echo=False,
temperature=0.1
)
result = output['choices'][0]['text'].strip()
# 1. Fallback: If result is empty, it might have just outputted nothing because we prefilled?
# Actually llama-cpp-python usually returns the *continuation*.
# So if it outputted "My corrected text.", the full logical response is "Here is...: My corrected text."
# We just want the result.
# Refusal Detection (Safety Net)
refusal_triggers = [
"I cannot", "I can't", "I am unable", "I apologize", "sorry",
"As an AI", "explicit content", "harmful content", "safety guidelines"
]
lower_res = result.lower()
if any(trig in lower_res for trig in refusal_triggers) and len(result) < 150:
logging.warning(f"LLM Refusal Detected: '{result}'. Falling back to original.")
return text # Return original text on refusal!
# --- Post-Processing ---
# 1. Strip quotes
if result.startswith('"') and result.endswith('"') and len(result) > 2 and '"' not in result[1:-1]:
result = result[1:-1]
if result.startswith("'") and result.endswith("'") and len(result) > 2 and "'" not in result[1:-1]:
result = result[1:-1]
# 2. Split by newline
if "\n" in result:
lines = result.split('\n')
clean_lines = [l.strip() for l in lines if l.strip()]
if clean_lines:
result = clean_lines[0]
# 3. Aggressive Preamble Stripping (Updates for new prefix)
import re
prefixes = [
r"^Here is the processed output:?\s*", # The one we injected
r"^Here is the corrected text:?\s*",
r"^Here is the rewritten text:?\s*",
r"^Here's the result:?\s*",
r"^Sure,? here is regex.*:?\s*",
r"^Output:?\s*",
r"^Processing result:?\s*",
]
for p in prefixes:
result = re.sub(p, "", result, flags=re.IGNORECASE).strip()
if result.startswith('"') and result.endswith('"') and len(result) > 2 and '"' not in result[1:-1]:
result = result[1:-1]
logging.info(f"LLM Result: '{result}'")
return result
except Exception as e:
logging.error(f"LLM inference failed: {e}")
return text # Fail safe logic