Python script for translating a long text

finetunning.py 5.2KB

import os
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    BitsAndBytesConfig,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from trl import SFTTrainer

# ----------------------------
# Environment safety (Windows)
# ----------------------------
os.environ["TORCHDYNAMO_DISABLE"] = "1"

# ----------------------------
# Global configuration
# ----------------------------
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
OUTPUT_DIR = "./qwen2.5-7b-uk-fr-lora"
DATA_FILE = "paires_clean.json"
MAX_SEQ_LENGTH = 1024

print(f"\n=== Starting fine-tuning script for {MODEL_NAME} ===\n")

# ----------------------------
# [1/7] Tokenizer
# ----------------------------
print(f"{80 * '_'}\n[1/7] Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True
)
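# Reuse the EOS token as the padding token, a common setup for causal-LM fine-tuning.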
tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = MAX_SEQ_LENGTH
print("Tokenizer loaded.")
print(f"Pad token id: {tokenizer.pad_token_id}")
print(f"Max sequence length: {tokenizer.model_max_length}")

# ----------------------------
# [2/7] Quantization config (QLoRA)
# ----------------------------
print(f"{80 * '_'}\n[2/7] Configuring 4-bit quantization (BitsAndBytes)...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
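# NF4 weights with double quantization and fp16 compute is the standard QLoRA recipe.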
  51. print("4-bit NF4 quantization configured.")
  52. print("Loading model...")
  53. model = AutoModelForCausalLM.from_pretrained(
  54. MODEL_NAME,
  55. device_map="auto",
  56. quantization_config=bnb_config,
  57. dtype=torch.float16,
  58. trust_remote_code=True,
  59. )
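# Note: `dtype` is the current name of this argument; older transformers releases
# expect `torch_dtype=torch.float16` instead.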
  60. print("Model loaded successfully.")
  61. # ----------------------------
  62. # [3/7] Prepare model for k-bit training
  63. # ----------------------------
  64. print(f"{80 * '_'}\n[3/7] Preparing model for k-bit training...")
  65. model = prepare_model_for_kbit_training(model)
  66. model.gradient_checkpointing_enable(
  67. gradient_checkpointing_kwargs={"use_reentrant": False}
  68. )
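# Optional: the generation cache is unused during training and transformers warns
# about it when gradient checkpointing is active; it can be switched off explicitly.
# model.config.use_cache = False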
  69. print("Model prepared for k-bit training.")
  70. print("Gradient checkpointing enabled (non-reentrant).")
  71. # ----------------------------
  72. # [4/7] LoRA configuration
  73. # ----------------------------
  74. print(f"{80 * '_'}\n[4/7] Configuring LoRA adapters...")
  75. lora_config = LoraConfig(
  76. r=32,
  77. lora_alpha=64,
  78. lora_dropout=0.02,
  79. bias="none",
  80. task_type="CAUSAL_LM",
  81. target_modules=[
  82. "q_proj",
  83. "k_proj",
  84. "v_proj",
  85. "o_proj",
  86. "gate_proj",
  87. "up_proj",
  88. "down_proj",
  89. ],
  90. )
  91. model = get_peft_model(model, lora_config)
  92. model.print_trainable_parameters()
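# The 4-bit base weights stay frozen; print_trainable_parameters() should report
# that only the LoRA matrices train (roughly 1% of total parameters with r=32 on
# all attention and MLP projections).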
  93. print("LoRA adapters successfully attached.")
  94. # ----------------------------
  95. # [5/7] Dataset loading & formatting
  96. # ----------------------------
  97. print(f"{80 * '_'}\n[5/7] Loading dataset from JSON file...")
dataset = load_dataset("json", data_files=DATA_FILE)
print(f"Dataset loaded with {len(dataset['train'])} samples.")
print("Formatting dataset for Ukrainian → French translation...")


def format_prompt(example):
    return {
        "text": (
            "<|user|>\n"
            "Translate the following Ukrainian text into French.\n"
            f"Ukrainian: {example['text']}\n"
            "<|assistant|>\n"
            f"{example['translation']}"
        )
    }
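
# Note: <|user|> / <|assistant|> is a custom prompt format, not Qwen2.5's native
# ChatML template (<|im_start|>...<|im_end|>); that is fine for a dedicated
# translation adapter as long as inference reuses exactly the same format.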
dataset = dataset.map(
    format_prompt,
    remove_columns=dataset["train"].column_names
)
print("Dataset formatting completed.")
print("Example prompt:\n")
print(dataset["train"][0]["text"])

# ----------------------------
# [6/7] Training arguments
# ----------------------------
print(f"{80 * '_'}\n[6/7] Initializing training arguments...")
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    num_train_epochs=3,
    fp16=False,
    bf16=False,
    optim="paged_adamw_32bit",
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
    report_to="none",
)
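# per_device_train_batch_size=1 with gradient_accumulation_steps=8 gives an
# effective batch of 8 per optimizer step while keeping VRAM usage low; paged
# AdamW (bitsandbytes) is the optimizer recommended in the QLoRA setup.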
  136. print("Training arguments ready.")
  137. print(f"Output directory: {OUTPUT_DIR}")
  138. print(f"Epochs: {training_args.num_train_epochs}")
  139. print(
  140. f"Effective batch size: "
  141. f"{training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}"
  142. )
  143. # ----------------------------
  144. # Trainer
  145. # ----------------------------
  146. print("Initializing SFTTrainer...")
  147. trainer = SFTTrainer(
  148. model=model,
  149. train_dataset=dataset["train"],
  150. tokenizer=tokenizer,
  151. args=training_args,
  152. )
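# Depending on the installed trl version, this keyword is `tokenizer` or
# `processing_class`, and the "text" column may need to be named explicitly
# (dataset_text_field="text", typically via SFTConfig).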
  153. print("Trainer initialized.")
  154. # ----------------------------
  155. # [7/7] Training
  156. # ----------------------------
  157. print(f"{80 * '_'}\n[7/7] Starting training...")
  158. try:
  159. trainer.train(resume_from_checkpoint=True)
  160. except Exception as e:
  161. print("No checkpoint found or resume failed, starting fresh training.")
  162. print(f"Reason: {e}")
  163. trainer.train()
  164. print("Training completed successfully.")
  165. # ----------------------------
  166. # Save LoRA adapter
  167. # ----------------------------
  168. print(f"{80 * '_'}\nSaving LoRA adapter and tokenizer...")
  169. trainer.model.save_pretrained(OUTPUT_DIR)
  170. tokenizer.save_pretrained(OUTPUT_DIR)
  171. print("\n=== Fine-tuning finished ===")
  172. print(f"LoRA adapter saved in: {OUTPUT_DIR}")