Python script for translating a long text

finetunning.py 5.5KB

import os
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    BitsAndBytesConfig,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from trl import SFTTrainer

# ----------------------------
# Environment safety (Windows)
# ----------------------------
os.environ["TORCHDYNAMO_DISABLE"] = "1"

# ----------------------------
# Global configuration
# ----------------------------
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
OUTPUT_DIR = "./qwen2.5-7b-uk-fr-lora"
DATA_FILE = "paires_clean.json"
MAX_SEQ_LENGTH = 1024

print(f"\n=== Starting fine-tuning script for {MODEL_NAME} ===\n")

# ----------------------------
# [1/7] Tokenizer
# ----------------------------
print(f"{80 * '_'}\n[1/7] Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = MAX_SEQ_LENGTH
print("Tokenizer loaded.")
print(f"Pad token id: {tokenizer.pad_token_id}")
print(f"Max sequence length: {tokenizer.model_max_length}")

# ----------------------------
# [2/7] Quantization config (QLoRA)
# ----------------------------
print(f"{80 * '_'}\n[2/7] Loading model in 4-bit mode (optimized QLoRA)...")
assert torch.cuda.is_available(), "CUDA GPU not detected!"
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="cuda",  # 🔥 SAFE: keep the whole model on the single CUDA device
    quantization_config=bnb_config,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
print("Model loaded successfully in 4-bit mode on GPU.")

# ----------------------------
# [3/7] Prepare model for k-bit training
# ----------------------------
print(f"{80 * '_'}\n[3/7] Preparing model for k-bit training...")
model = prepare_model_for_kbit_training(model)
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
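# use_reentrant=False selects PyTorch's non-reentrant checkpointing, the
# variant the PyTorch docs recommend; unlike the legacy reentrant one it does
# not require the checkpointed inputs to carry requires_grad.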
  69. print("Model prepared for k-bit training.")
  70. print("Gradient checkpointing enabled (non-reentrant).")
  71. # ----------------------------
  72. # [4/7] LoRA configuration
  73. # ----------------------------
  74. print(f"{80 * '_'}\n[4/7] Configuring LoRA adapters...")
  75. lora_config = LoraConfig(
  76. r=32,
  77. lora_alpha=64,
  78. lora_dropout=0.02,
  79. bias="none",
  80. task_type="CAUSAL_LM",
  81. target_modules=[
  82. "q_proj",
  83. "k_proj",
  84. "v_proj",
  85. "o_proj",
  86. "gate_proj",
  87. "up_proj",
  88. "down_proj",
  89. ],
  90. )
  91. model = get_peft_model(model, lora_config)
  92. model.print_trainable_parameters()
  93. print("LoRA adapters successfully attached.")
  94. # ----------------------------
  95. # [5/7] Dataset loading & formatting
  96. # ----------------------------
  97. print(f"{80 * '_'}\n[5/7] Loading dataset from JSON file...")
dataset = load_dataset("json", data_files=DATA_FILE)
print(f"Dataset loaded with {len(dataset['train'])} samples.")
print("Formatting dataset for Ukrainian → French translation...")


def format_prompt(example):
    return {
        "text": (
            "<|user|>\n"
            "Translate the following Ukrainian text into French.\n"
            f"Ukrainian: {example['text']}\n"
            "<|assistant|>\n"
            f"{example['translation']}"
        )
    }


dataset = dataset.map(
    format_prompt,
    remove_columns=dataset["train"].column_names
)
print("Dataset formatting completed.")
print("Example prompt:\n")
print(dataset["train"][0]["text"])

# ----------------------------
# [6/7] Training arguments
# ----------------------------
print(f"{80 * '_'}\n[6/7] Initializing training arguments...")
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    num_train_epochs=3,
    fp16=False,
    bf16=False,
    optim="paged_adamw_32bit",
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
    report_to="none",
    dataloader_pin_memory=False,
)
print("Training arguments ready.")
print(f"Output directory: {OUTPUT_DIR}")
print(f"Epochs: {training_args.num_train_epochs}")
print(
    f"Effective batch size: "
    f"{training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}"
)

# ----------------------------
# Trainer
# ----------------------------
print("Initializing SFTTrainer...")
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    processing_class=tokenizer,
    args=training_args,
)
print("Trainer initialized.")

# ----------------------------
# [7/7] Training
# ----------------------------
print(f"{80 * '_'}\n[7/7] Starting training...")
try:
    train_output = trainer.train(resume_from_checkpoint=True)
except Exception as e:
    print("No checkpoint found or resume failed, starting fresh training.")
    print(f"Reason: {e}")
    train_output = trainer.train()

print("\n=== Training summary ===")
print(f"Global steps: {train_output.global_step}")
print(f"Training loss: {train_output.training_loss}")
print(f"Metrics: {train_output.metrics}")
print("Training completed successfully.")

# ----------------------------
# Save LoRA adapter
# ----------------------------
print(f"{80 * '_'}\nSaving LoRA adapter and tokenizer...")
trainer.model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print("\n=== Fine-tuning finished ===")
print(f"LoRA adapter saved in: {OUTPUT_DIR}")