Script python permettant de traduire un long texte
Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

validation.py 5.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. import torch
  2. from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
  3. from peft import PeftModel
  4. from datasets import load_dataset
  5. from nltk.translate.bleu_score import corpus_bleu
  6. # ----------------------------
  7. # Configuration
  8. # ----------------------------
  9. BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct" # base model
  10. LORA_DIR = "./qwen2.5-7b-uk-fr-lora" # fine-tuned LoRA
  11. VALIDATION_FILE = "validation.jsonl" # small validation subset
  12. MAX_INPUT_LENGTH = 1024
  13. DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
  14. # Liste des prompts à tester
  15. PROMPTS_TO_TEST = [
  16. {
  17. "name": "Prompt de base",
  18. "prompt": "Traduis la phrase ukrainienne suivante en français: {text}"
  19. },
  20. {
  21. "name": "Prompt spécialisé mémoires",
  22. "prompt": (
  23. "Tu es un traducteur spécialisé dans les mémoires ukrainiennes des années 1910.\n"
  24. "- Garde le style narratif et les tournures orales de l'auteur.\n"
  25. "- Respecte les règles de traduction suivantes :\n\n"
  26. "Règles strictes :\n"
  27. "1. **Conserve tous les noms de lieux** dans leur forme originale (ex. : Львів → Lviv, mais ajoute une note si nécessaire entre [ ]).\n"
  28. "2. **Respecte le style narratif** : garde les tournures orales et les expressions propres à l'auteur.\n\n"
  29. "Voici la phrase à traduire :\nUkrainien : {text}\nFrançais :"
  30. )
  31. },
  32. {
  33. "name": "Prompt détaillé",
  34. "prompt": (
  35. "Tu es un expert en traduction littéraire spécialisé dans les textes historiques ukrainiens.\n"
  36. "Règles à suivre absolument :\n"
  37. "1. Conserve tous les noms propres et toponymes dans leur forme originale\n"
  38. "2. Préserve le style et le registre de l'auteur original\n"
  39. "3. Ajoute des notes entre crochets pour expliquer les références culturelles si nécessaire\n"
  40. "4. Traduis de manière naturelle en français tout en restant fidèle au texte source\n\n"
  41. "Texte à traduire :\nUkrainien : {text}\nTraduction française :"
  42. )
  43. },
  44. {
  45. "name": "Prompt minimaliste",
  46. "prompt": "Traduction fidèle de l'ukrainien vers le français : {text}"
  47. }
  48. ]
  49. print("=== Loading tokenizer and model ===")
  50. # ----------------------------
  51. # Load tokenizer
  52. # ----------------------------
  53. tokenizer = AutoTokenizer.from_pretrained(
  54. BASE_MODEL,
  55. trust_remote_code=True
  56. )
  57. tokenizer.pad_token = tokenizer.eos_token
  58. tokenizer.model_max_length = MAX_INPUT_LENGTH
  59. # ----------------------------
  60. # Load base model directly on GPU
  61. # ----------------------------
  62. print(f"{80 * '_'}\nLoading base model on GPU...")
  63. base_model = AutoModelForCausalLM.from_pretrained(
  64. BASE_MODEL,
  65. torch_dtype=torch.float16,
  66. device_map={"": 0}, # all on GPU
  67. trust_remote_code=True
  68. )
  69. # ----------------------------
  70. # Apply LoRA adapter
  71. # ----------------------------
  72. print(f"{80 * '_'}\nApplying LoRA adapter...")
  73. model = PeftModel.from_pretrained(base_model, LORA_DIR)
  74. model.eval()
  75. model.to(DEVICE) # ensure everything on GPU
  76. print("Model ready for validation.")
  77. # ----------------------------
  78. # Load validation dataset
  79. # ----------------------------
  80. print(f"{80 * '_'}\nLoading validation dataset...")
  81. dataset = load_dataset("json", data_files=VALIDATION_FILE)
  82. examples = dataset["train"]
  83. print(f"{len(examples)} examples loaded for testing.")
  84. # ----------------------------
  85. # Translation function
  86. # ----------------------------
  87. @torch.inference_mode()
  88. def translate(text, prompt_template):
  89. prompt = prompt_template.format(text=text)
  90. inputs = tokenizer(
  91. prompt,
  92. return_tensors="pt",
  93. truncation=True,
  94. max_length=MAX_INPUT_LENGTH
  95. ).to(DEVICE)
  96. # Utilisation de GenerationConfig pour éviter les avertissements
  97. generation_config = GenerationConfig.from_model_config(model.config)
  98. generation_config.max_new_tokens = 256
  99. generation_config.do_sample = False
  100. outputs = model.generate(
  101. **inputs,
  102. generation_config=generation_config
  103. )
  104. result = tokenizer.decode(outputs[0], skip_special_tokens=True)
  105. # Extraction de la partie traduction
  106. if "Français :" in result:
  107. translation_part = result.split("Français :")[-1].strip()
  108. elif "Traduction française :" in result:
  109. translation_part = result.split("Traduction française :")[-1].strip()
  110. else:
  111. translation_part = result.split(text)[-1].strip()
  112. return translation_part
  113. # ----------------------------
  114. # Evaluate all prompts and select best BLEU
  115. # ----------------------------
  116. best_bleu = 0
  117. best_prompt = None
  118. all_results = {}
  119. print(f"{80 * '_'}\nTesting all prompts and computing BLEU scores...")
  120. for prompt_config in PROMPTS_TO_TEST:
  121. print(f"\n{80 * '='}\nTesting prompt: {prompt_config['name']}\n{80 * '='}")
  122. references = []
  123. hypotheses = []
  124. for i, example in enumerate(examples):
  125. src_text = example["text"]
  126. ref_text = example["translation"]
  127. pred_text = translate(src_text, prompt_config["prompt"])
  128. print(f"\n[{i+1}] Source: {src_text}")
  129. print(f" Reference: {ref_text}")
  130. print(f" Prediction: {pred_text}")
  131. references.append([ref_text.split()])
  132. hypotheses.append(pred_text.split())
  133. bleu_score = corpus_bleu(references, hypotheses) * 100
  134. print(f"\n=== Corpus BLEU score for '{prompt_config['name']}': {bleu_score:.4f} ===")
  135. all_results[prompt_config["name"]] = bleu_score
  136. if bleu_score > best_bleu:
  137. best_bleu = bleu_score
  138. best_prompt = prompt_config
  139. # ----------------------------
  140. # Display results
  141. # ----------------------------
  142. print(f"\n{80 * '='}\nFINAL RESULTS\n{80 * '='}")
  143. for prompt_name, score in all_results.items():
  144. print(f"{prompt_name}: {score:.4f}")
  145. print(f"\nBEST PROMPT: {best_prompt['name']} with BLEU score: {best_bleu:.4f}")
  146. print(f"Prompt content:\n{best_prompt['prompt']}")