python.traduction/Finetunning/validation.py

171 lines
5.9 KiB
Python

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from peft import PeftModel
from datasets import load_dataset
from nltk.translate.bleu_score import corpus_bleu
# ----------------------------
# Configuration
# ----------------------------
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct" # base model
LORA_DIR = "./qwen2.5-7b-uk-fr-lora-2epoch" # fine-tuned LoRA
VALIDATION_FILE = "validation.jsonl" # small validation subset
MAX_INPUT_LENGTH = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Liste des prompts à tester
PROMPTS_TO_TEST = [
{
"name": "Prompt de base",
"prompt": "Traduis la phrase ukrainienne suivante en français: {text}"
},
{
"name": "Prompt spécialisé mémoires",
"prompt": (
"Tu es un traducteur spécialisé dans les mémoires ukrainiennes des années 1910.\n"
"- Garde le style narratif et les tournures orales de l'auteur.\n"
"- Respecte les règles de traduction suivantes :\n\n"
"Règles strictes :\n"
"1. **Conserve tous les noms de lieux** dans leur forme originale (ex. : Львів → Lviv, mais ajoute une note si nécessaire entre [ ]).\n"
"2. **Respecte le style narratif** : garde les tournures orales et les expressions propres à l'auteur.\n\n"
"Voici la phrase à traduire :\nUkrainien : {text}\nFrançais :"
)
},
{
"name": "Prompt détaillé",
"prompt": (
"Tu es un expert en traduction littéraire spécialisé dans les textes historiques ukrainiens.\n"
"Règles à suivre absolument :\n"
"1. Conserve tous les noms propres et toponymes dans leur forme originale\n"
"2. Préserve le style et le registre de l'auteur original\n"
"3. Ajoute des notes entre crochets pour expliquer les références culturelles si nécessaire\n"
"4. Traduis de manière naturelle en français tout en restant fidèle au texte source\n\n"
"Texte à traduire :\nUkrainien : {text}\nTraduction française :"
)
},
{
"name": "Prompt minimaliste",
"prompt": "Traduction fidèle de l'ukrainien vers le français : {text}"
}
]
print("=== Loading tokenizer and model ===")
# ----------------------------
# Load tokenizer
# ----------------------------
tokenizer = AutoTokenizer.from_pretrained(
BASE_MODEL,
trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = MAX_INPUT_LENGTH
# ----------------------------
# Load base model directly on GPU
# ----------------------------
print(f"{80 * '_'}\nLoading base model on GPU...")
base_model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
torch_dtype=torch.float16,
device_map={"": 0}, # all on GPU
trust_remote_code=True
)
# ----------------------------
# Apply LoRA adapter
# ----------------------------
print(f"{80 * '_'}\nApplying LoRA adapter...")
model = PeftModel.from_pretrained(base_model, LORA_DIR)
model.eval()
model.to(DEVICE) # ensure everything on GPU
print("Model ready for validation.")
# ----------------------------
# Load validation dataset
# ----------------------------
print(f"{80 * '_'}\nLoading validation dataset...")
dataset = load_dataset("json", data_files=VALIDATION_FILE)
examples = dataset["train"]
print(f"{len(examples)} examples loaded for testing.")
# ----------------------------
# Translation function
# ----------------------------
@torch.inference_mode()
def translate(text, prompt_template):
prompt = prompt_template.format(text=text)
inputs = tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=MAX_INPUT_LENGTH
).to(DEVICE)
# Utilisation de GenerationConfig pour éviter les avertissements
generation_config = GenerationConfig.from_model_config(model.config)
generation_config.max_new_tokens = 256
generation_config.do_sample = False
outputs = model.generate(
**inputs,
generation_config=generation_config
)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extraction de la partie traduction
if "Français :" in result:
translation_part = result.split("Français :")[-1].strip()
elif "Traduction française :" in result:
translation_part = result.split("Traduction française :")[-1].strip()
else:
translation_part = result.split(text)[-1].strip()
return translation_part
# ----------------------------
# Evaluate all prompts and select best BLEU
# ----------------------------
best_bleu = 0
best_prompt = None
all_results = {}
print(f"{80 * '_'}\nTesting all prompts and computing BLEU scores...")
for prompt_config in PROMPTS_TO_TEST:
print(f"\n{80 * '='}\nTesting prompt: {prompt_config['name']}\n{80 * '='}")
references = []
hypotheses = []
for i, example in enumerate(examples):
src_text = example["text"]
ref_text = example["translation"]
pred_text = translate(src_text, prompt_config["prompt"])
print(f"\n[{i+1}] Source: {src_text}")
print(f" Reference: {ref_text}")
print(f" Prediction: {pred_text}")
references.append([ref_text.split()])
hypotheses.append(pred_text.split())
bleu_score = corpus_bleu(references, hypotheses) * 100
print(f"\n=== Corpus BLEU score for '{prompt_config['name']}': {bleu_score:.4f} ===")
all_results[prompt_config["name"]] = bleu_score
if bleu_score > best_bleu:
best_bleu = bleu_score
best_prompt = prompt_config
# ----------------------------
# Display results
# ----------------------------
print(f"\n{80 * '='}\nFINAL RESULTS\n{80 * '='}")
for prompt_name, score in all_results.items():
print(f"{prompt_name}: {score:.4f}")
print(f"\nBEST PROMPT: {best_prompt['name']} with BLEU score: {best_bleu:.4f}")
print(f"Prompt content:\n{best_prompt['prompt']}")