2 changed files with 137 additions and 64 deletions
@@ -1,77 +1,170 @@
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from peft import PeftModel
from datasets import load_dataset
from nltk.translate.bleu_score import corpus_bleu

# ----------------------------
# Configuration
# ----------------------------
MODEL_DIR = "./qwen2.5-7b-uk-fr-lora"    # folder where the LoRA adapter was saved
VALIDATION_FILE = "validation.jsonl"     # small test subset (5-50 sentences)
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"  # base model
LORA_DIR = "./qwen2.5-7b-uk-fr-lora"     # fine-tuned LoRA adapter
VALIDATION_FILE = "validation.jsonl"     # small validation subset
MAX_INPUT_LENGTH = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print("=== Loading model and tokenizer ===")
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
# List of prompts to test (the prompt texts are intentionally left in French, since the model translates into French)
PROMPTS_TO_TEST = [
    {
        "name": "Prompt de base",
        "prompt": "Traduis la phrase ukrainienne suivante en français: {text}"
    },
    {
        "name": "Prompt spécialisé mémoires",
        "prompt": (
            "Tu es un traducteur spécialisé dans les mémoires ukrainiennes des années 1910.\n"
            "- Garde le style narratif et les tournures orales de l'auteur.\n"
            "- Respecte les règles de traduction suivantes :\n\n"
            "Règles strictes :\n"
            "1. **Conserve tous les noms de lieux** dans leur forme originale (ex. : Львів → Lviv, mais ajoute une note si nécessaire entre [ ]).\n"
            "2. **Respecte le style narratif** : garde les tournures orales et les expressions propres à l'auteur.\n\n"
            "Voici la phrase à traduire :\nUkrainien : {text}\nFrançais :"
        )
    },
    {
        "name": "Prompt détaillé",
        "prompt": (
            "Tu es un expert en traduction littéraire spécialisé dans les textes historiques ukrainiens.\n"
            "Règles à suivre absolument :\n"
            "1. Conserve tous les noms propres et toponymes dans leur forme originale\n"
            "2. Préserve le style et le registre de l'auteur original\n"
            "3. Ajoute des notes entre crochets pour expliquer les références culturelles si nécessaire\n"
            "4. Traduis de manière naturelle en français tout en restant fidèle au texte source\n\n"
            "Texte à traduire :\nUkrainien : {text}\nTraduction française :"
        )
    },
    {
        "name": "Prompt minimaliste",
        "prompt": "Traduction fidèle de l'ukrainien vers le français : {text}"
    }
]
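# Each template's {text} placeholder is filled in via str.format() inside translate() below.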

print("=== Loading tokenizer and model ===")

# ----------------------------
# Load tokenizer
# ----------------------------
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = MAX_INPUT_LENGTH

model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    device_map="auto",
# ----------------------------
# Load base model directly on GPU
# ----------------------------
print(f"{80 * '_'}\nLoading base model on GPU...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    device_map={"": 0},  # all on GPU
    trust_remote_code=True
)
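# Note: device_map={"": 0} places the entire model on GPU 0; this assumes a single GPU
# with enough memory to hold the 7B model in float16 (roughly 15 GB of VRAM).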
model.eval()

print("Model loaded.")
# ----------------------------
# Apply LoRA adapter
# ----------------------------
print(f"{80 * '_'}\nApplying LoRA adapter...")
model = PeftModel.from_pretrained(base_model, LORA_DIR)
model.eval()
model.to(DEVICE)  # ensure everything on GPU
print("Model ready for validation.")

# ----------------------------
# Load validation dataset
# ----------------------------
print("Loading validation dataset...")
print(f"{80 * '_'}\nLoading validation dataset...")
dataset = load_dataset("json", data_files=VALIDATION_FILE)
examples = dataset["train"]  # small subset
examples = dataset["train"]
print(f"{len(examples)} examples loaded for testing.")

# ----------------------------
# Function to generate translation
# Translation function
# ----------------------------
def translate(text):
    prompt = f"Translate the following Ukrainian text into French:\nUkrainian: {text}\nFrench:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_INPUT_LENGTH).to(DEVICE)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,  # deterministic
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id
        )
@torch.inference_mode()
def translate(text, prompt_template):
    prompt = prompt_template.format(text=text)
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_LENGTH
    ).to(DEVICE)

    # Use GenerationConfig to avoid generation warnings
    generation_config = GenerationConfig.from_model_config(model.config)
    generation_config.max_new_tokens = 256
    generation_config.do_sample = False

    outputs = model.generate(
        **inputs,
        generation_config=generation_config
    )

    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Remove prompt from result
    return result.replace(prompt, "").strip()

    # Extract the translation part
    if "Français :" in result:
        translation_part = result.split("Français :")[-1].strip()
    elif "Traduction française :" in result:
        translation_part = result.split("Traduction française :")[-1].strip()
    else:
        translation_part = result.split(text)[-1].strip()

    return translation_part
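
# Quick sanity check (hypothetical example sentence, for illustration only):
#   translate("Я народився у Львові.", PROMPTS_TO_TEST[0]["prompt"])
#   should produce something like "Je suis né à Lviv."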

# ----------------------------
# Test all examples and compute BLEU
# Evaluate all prompts and select best BLEU
# ----------------------------
print("Generating translations...")
references = []
hypotheses = []
best_bleu = 0
best_prompt = None
all_results = {}

print(f"{80 * '_'}\nTesting all prompts and computing BLEU scores...")

for prompt_config in PROMPTS_TO_TEST:
    print(f"\n{80 * '='}\nTesting prompt: {prompt_config['name']}\n{80 * '='}")
    references = []
    hypotheses = []

for i, example in enumerate(examples):
    src_text = example["text"]
    ref_text = example["translation"]
    pred_text = translate(src_text)

    print(f"\n[{i+1}] Source: {src_text}")
    print(f" Reference: {ref_text}")
    print(f" Prediction: {pred_text}")
    for i, example in enumerate(examples):
        src_text = example["text"]
        ref_text = example["translation"]
        pred_text = translate(src_text, prompt_config["prompt"])

    # Prepare for BLEU (tokenized by space)
    references.append([ref_text.split()])
    hypotheses.append(pred_text.split())
        print(f"\n[{i+1}] Source: {src_text}")
        print(f" Reference: {ref_text}")
        print(f" Prediction: {pred_text}")

        references.append([ref_text.split()])
        hypotheses.append(pred_text.split())

    bleu_score = corpus_bleu(references, hypotheses) * 100
    print(f"\n=== Corpus BLEU score for '{prompt_config['name']}': {bleu_score:.4f} ===")

    all_results[prompt_config["name"]] = bleu_score

    if bleu_score > best_bleu:
        best_bleu = bleu_score
        best_prompt = prompt_config

# ----------------------------
# Display results
# ----------------------------
print(f"\n{80 * '='}\nFINAL RESULTS\n{80 * '='}")
for prompt_name, score in all_results.items():
    print(f"{prompt_name}: {score:.4f}")

# Compute corpus BLEU
bleu_score = corpus_bleu(references, hypotheses)
print(f"\n=== Corpus BLEU score: {bleu_score:.4f} ===")
print(f"\nBEST PROMPT: {best_prompt['name']} with BLEU score: {best_bleu:.4f}")
print(f"Prompt content:\n{best_prompt['prompt']}")