Python script for translating a long text

finetunning.py 4.0KB

import os
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from trl import SFTTrainer

# ----------------------------
# Environment safety (Windows)
# ----------------------------
os.environ["TORCHDYNAMO_DISABLE"] = "1"

# ----------------------------
# Model configuration
# ----------------------------
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"

print(f"=== Starting fine-tuning script {MODEL_NAME} ===")

print(f"{80 * '_'}\n[1/7] Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True
)
# Ensure padding is defined
tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = 1024
print("Tokenizer loaded and configured.")
print(f"{80 * '_'}\n[2/7] Loading model in 4-bit mode (QLoRA)...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    load_in_4bit=True,
    device_map="auto",
    torch_dtype=torch.float16,  # weights in fp16, gradients fp32
    trust_remote_code=True,
)
print("Model loaded.")

print(f"{80 * '_'}\n[3/7] Preparing model for k-bit training...")
model = prepare_model_for_kbit_training(model)
# Fix future PyTorch checkpointing behavior
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
print("Model prepared for k-bit training.")
# ----------------------------
# LoRA configuration
# ----------------------------
print(f"{80 * '_'}\n[4/7] Configuring LoRA adapters...")
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
print("LoRA adapters attached.")
# ----------------------------
# Dataset loading
# ----------------------------
print(f"{80 * '_'}\n[5/7] Loading dataset from JSON file...")
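# traductions.json is assumed to contain one record per translation pair, each
# exposing a "text" field (Ukrainian source) and a "translation" field (French
# target), e.g. {"text": "...", "translation": "..."} -- these are the keys
# read by format_prompt() below. (Illustrative; the file itself is not shown here.)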
dataset = load_dataset(
    "json",
    data_files="traductions.json"
)
print(f"Dataset loaded with {len(dataset['train'])} samples.")

print("Formatting dataset for Ukrainian → French translation...")

def format_prompt(example):
    prompt = (
        "Translate the following Ukrainian text into French.\n\n"
        f"Ukrainian: {example['text']}\n"
        f"French: {example['translation']}"
    )
    return {"text": prompt}

dataset = dataset.map(
    format_prompt,
    remove_columns=dataset["train"].column_names
)
print("Dataset formatting completed.")
# ----------------------------
# Training arguments (AMP OFF)
# ----------------------------
print(f"{80 * '_'}\n[6/7] Initializing training arguments...")
training_args = TrainingArguments(
    output_dir="./qwen2.5-7b-uk-fr-lora",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    num_train_epochs=2,  # 2 epochs usually enough for translation
    fp16=False,
    bf16=False,
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
    optim="paged_adamw_32bit",
    report_to="none",
)
print("Training arguments ready.")
# ----------------------------
# Trainer
# ----------------------------
print("Initializing SFTTrainer...")
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    processing_class=tokenizer,
    args=training_args,
)
print("Trainer initialized.")
# ----------------------------
# Train
# ----------------------------
print(f"{80 * '_'}\n[7/7] Starting training...")
# Resume only if a checkpoint already exists in output_dir;
# resume_from_checkpoint=True would raise an error on the very first run.
has_checkpoint = os.path.isdir(training_args.output_dir) and any(
    name.startswith("checkpoint-") for name in os.listdir(training_args.output_dir)
)
trainer.train(resume_from_checkpoint=has_checkpoint)
print("Training completed successfully.")
# ----------------------------
# Save LoRA adapter
# ----------------------------
print("Saving LoRA adapter and tokenizer...")
trainer.model.save_pretrained("./qwen2.5-7b-uk-fr-lora")
tokenizer.save_pretrained("./qwen2.5-7b-uk-fr-lora")
print("=== Fine-tuning finished ===")
print("LoRA adapter saved in ./qwen2.5-7b-uk-fr-lora")
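
The script saves only the LoRA adapter, not a merged model, so translating with it means reloading the base model and attaching the adapter. Below is a minimal inference sketch that is not part of the repository: it assumes the adapter directory produced above, mirrors the prompt format used in format_prompt(), and uses a purely illustrative Ukrainian sentence and generation settings.

# inference_sketch.py -- minimal usage sketch for the saved adapter (assumption:
# the adapter and tokenizer were written to ./qwen2.5-7b-uk-fr-lora by finetunning.py)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
ADAPTER_DIR = "./qwen2.5-7b-uk-fr-lora"

# Tokenizer was saved alongside the adapter, so it can be loaded from the same directory
tokenizer = AutoTokenizer.from_pretrained(ADAPTER_DIR, trust_remote_code=True)

# Load the base model the same way as during training (4-bit, auto device placement)
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_4bit=True,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
# Attach the fine-tuned LoRA weights on top of the base model
model = PeftModel.from_pretrained(model, ADAPTER_DIR)
model.eval()

# Prompt mirrors the training format; the French part is left for the model to complete
prompt = (
    "Translate the following Ukrainian text into French.\n\n"
    "Ukrainian: Добрий день, як справи?\n"  # illustrative example sentence
    "French:"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=200)

# Decode only the newly generated tokens (the translation), skipping the prompt
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))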