"""Manual smoke test of bigscience/mt0-base on multilingual grammar correction."""

import sys  # NOTE(review): unused here — kept in case a truncated part of the file needs it

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


def test_mt0():
    """Load mt0-base and print grammar corrections for six languages.

    Purely a manual smoke test: downloads the model, runs one
    grammar-correction prompt per language, and prints each result.
    No return value, no assertions.
    """
    model_name = "bigscience/mt0-base"
    print(f"Loading {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    # Inference only — disable dropout etc.
    model.eval()

    # Test cases: (Language, Prompt, Input)
    # MT0 is instruction tuned, so we should prompt it in the target language or English.
    # Cross-lingual prompting (English prompt -> Target tasks) is usually supported.
    test_cases = [
        ("English", "Correct grammar:", "he go to school yesterday"),
        ("Polish", "Popraw gramatykę:", "to jest testowe zdanie bez kropki"),
        ("Finnish", "Korjaa kielioppi:", "tämä on testilause ilman pistettä"),
        ("Russian", "Исправь грамматику:", "это тестовое предложение без точки"),
        ("Japanese", "文法を直してください:", "これは点のないテスト文です"),
        ("Spanish", "Corrige la gramática:", "esta es una oración de prueba sin punto"),
    ]

    print("\nStarting MT0 Tests:\n")

    for lang, prompt_text, input_text in test_cases:
        full_input = f"{prompt_text} {input_text}"
        inputs = tokenizer(full_input, return_tensors="pt")
        # Pass the full tokenizer output (input_ids AND attention_mask);
        # calling generate with input_ids alone discards the mask, which
        # triggers a warning and misbehaves once padding is involved.
        outputs = model.generate(**inputs, max_length=128)
        corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"[{lang}]")
        print(f"Input: {full_input}")
        print(f"Output: {corrected}")
        print("-" * 20)


if __name__ == "__main__":
    test_mt0()