You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
170 lines
5.9 KiB
170 lines
5.9 KiB
import torch
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
|
|
from peft import PeftModel
|
|
from datasets import load_dataset
|
|
from nltk.translate.bleu_score import corpus_bleu
|
|
|
|
# ----------------------------
|
|
# Configuration
|
|
# ----------------------------
|
|
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct" # base model
|
|
LORA_DIR = "./qwen2.5-7b-uk-fr-lora-2epoch" # fine-tuned LoRA
|
|
VALIDATION_FILE = "validation.jsonl" # small validation subset
|
|
MAX_INPUT_LENGTH = 1024
|
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
# Liste des prompts à tester
|
|
PROMPTS_TO_TEST = [
|
|
{
|
|
"name": "Prompt de base",
|
|
"prompt": "Traduis la phrase ukrainienne suivante en français: {text}"
|
|
},
|
|
{
|
|
"name": "Prompt spécialisé mémoires",
|
|
"prompt": (
|
|
"Tu es un traducteur spécialisé dans les mémoires ukrainiennes des années 1910.\n"
|
|
"- Garde le style narratif et les tournures orales de l'auteur.\n"
|
|
"- Respecte les règles de traduction suivantes :\n\n"
|
|
"Règles strictes :\n"
|
|
"1. **Conserve tous les noms de lieux** dans leur forme originale (ex. : Львів → Lviv, mais ajoute une note si nécessaire entre [ ]).\n"
|
|
"2. **Respecte le style narratif** : garde les tournures orales et les expressions propres à l'auteur.\n\n"
|
|
"Voici la phrase à traduire :\nUkrainien : {text}\nFrançais :"
|
|
)
|
|
},
|
|
{
|
|
"name": "Prompt détaillé",
|
|
"prompt": (
|
|
"Tu es un expert en traduction littéraire spécialisé dans les textes historiques ukrainiens.\n"
|
|
"Règles à suivre absolument :\n"
|
|
"1. Conserve tous les noms propres et toponymes dans leur forme originale\n"
|
|
"2. Préserve le style et le registre de l'auteur original\n"
|
|
"3. Ajoute des notes entre crochets pour expliquer les références culturelles si nécessaire\n"
|
|
"4. Traduis de manière naturelle en français tout en restant fidèle au texte source\n\n"
|
|
"Texte à traduire :\nUkrainien : {text}\nTraduction française :"
|
|
)
|
|
},
|
|
{
|
|
"name": "Prompt minimaliste",
|
|
"prompt": "Traduction fidèle de l'ukrainien vers le français : {text}"
|
|
}
|
|
]
|
|
|
|
print("=== Loading tokenizer and model ===")
|
|
|
|
# ----------------------------
|
|
# Load tokenizer
|
|
# ----------------------------
|
|
tokenizer = AutoTokenizer.from_pretrained(
|
|
BASE_MODEL,
|
|
trust_remote_code=True
|
|
)
|
|
tokenizer.pad_token = tokenizer.eos_token
|
|
tokenizer.model_max_length = MAX_INPUT_LENGTH
|
|
|
|
# ----------------------------
|
|
# Load base model directly on GPU
|
|
# ----------------------------
|
|
print(f"{80 * '_'}\nLoading base model on GPU...")
|
|
base_model = AutoModelForCausalLM.from_pretrained(
|
|
BASE_MODEL,
|
|
torch_dtype=torch.float16,
|
|
device_map={"": 0}, # all on GPU
|
|
trust_remote_code=True
|
|
)
|
|
|
|
# ----------------------------
|
|
# Apply LoRA adapter
|
|
# ----------------------------
|
|
print(f"{80 * '_'}\nApplying LoRA adapter...")
|
|
model = PeftModel.from_pretrained(base_model, LORA_DIR)
|
|
model.eval()
|
|
model.to(DEVICE) # ensure everything on GPU
|
|
print("Model ready for validation.")
|
|
|
|
# ----------------------------
|
|
# Load validation dataset
|
|
# ----------------------------
|
|
print(f"{80 * '_'}\nLoading validation dataset...")
|
|
dataset = load_dataset("json", data_files=VALIDATION_FILE)
|
|
examples = dataset["train"]
|
|
print(f"{len(examples)} examples loaded for testing.")
|
|
|
|
# ----------------------------
|
|
# Translation function
|
|
# ----------------------------
|
|
@torch.inference_mode()
|
|
def translate(text, prompt_template):
|
|
prompt = prompt_template.format(text=text)
|
|
inputs = tokenizer(
|
|
prompt,
|
|
return_tensors="pt",
|
|
truncation=True,
|
|
max_length=MAX_INPUT_LENGTH
|
|
).to(DEVICE)
|
|
|
|
# Utilisation de GenerationConfig pour éviter les avertissements
|
|
generation_config = GenerationConfig.from_model_config(model.config)
|
|
generation_config.max_new_tokens = 256
|
|
generation_config.do_sample = False
|
|
|
|
outputs = model.generate(
|
|
**inputs,
|
|
generation_config=generation_config
|
|
)
|
|
|
|
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
# Extraction de la partie traduction
|
|
if "Français :" in result:
|
|
translation_part = result.split("Français :")[-1].strip()
|
|
elif "Traduction française :" in result:
|
|
translation_part = result.split("Traduction française :")[-1].strip()
|
|
else:
|
|
translation_part = result.split(text)[-1].strip()
|
|
|
|
return translation_part
|
|
|
|
# ----------------------------
|
|
# Evaluate all prompts and select best BLEU
|
|
# ----------------------------
|
|
best_bleu = 0
|
|
best_prompt = None
|
|
all_results = {}
|
|
|
|
print(f"{80 * '_'}\nTesting all prompts and computing BLEU scores...")
|
|
|
|
for prompt_config in PROMPTS_TO_TEST:
|
|
print(f"\n{80 * '='}\nTesting prompt: {prompt_config['name']}\n{80 * '='}")
|
|
references = []
|
|
hypotheses = []
|
|
|
|
for i, example in enumerate(examples):
|
|
src_text = example["text"]
|
|
ref_text = example["translation"]
|
|
pred_text = translate(src_text, prompt_config["prompt"])
|
|
|
|
print(f"\n[{i+1}] Source: {src_text}")
|
|
print(f" Reference: {ref_text}")
|
|
print(f" Prediction: {pred_text}")
|
|
|
|
references.append([ref_text.split()])
|
|
hypotheses.append(pred_text.split())
|
|
|
|
bleu_score = corpus_bleu(references, hypotheses) * 100
|
|
print(f"\n=== Corpus BLEU score for '{prompt_config['name']}': {bleu_score:.4f} ===")
|
|
|
|
all_results[prompt_config["name"]] = bleu_score
|
|
|
|
if bleu_score > best_bleu:
|
|
best_bleu = bleu_score
|
|
best_prompt = prompt_config
|
|
|
|
# ----------------------------
|
|
# Display results
|
|
# ----------------------------
|
|
print(f"\n{80 * '='}\nFINAL RESULTS\n{80 * '='}")
|
|
for prompt_name, score in all_results.items():
|
|
print(f"{prompt_name}: {score:.4f}")
|
|
|
|
print(f"\nBEST PROMPT: {best_prompt['name']} with BLEU score: {best_bleu:.4f}")
|
|
print(f"Prompt content:\n{best_prompt['prompt']}")
|
|
|