Python script for translating a long text

finetunning.py 4.8KB

import os
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    BitsAndBytesConfig,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from trl import SFTTrainer

# ----------------------------
# Environment safety (Windows)
# ----------------------------
os.environ["TORCHDYNAMO_DISABLE"] = "1"
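# Setting TORCHDYNAMO_DISABLE above skips torch.compile / TorchDynamo graph
# capture, which commonly fails on Windows where the Triton backend is
# typically unavailable.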

# ----------------------------
# Global configuration
# ----------------------------
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
OUTPUT_DIR = "./qwen2.5-7b-uk-fr-lora"
DATA_FILE = "paires_clean.json"
MAX_SEQ_LENGTH = 1024

print(f"\n=== Starting fine-tuning script for {MODEL_NAME} ===\n")

# ----------------------------
# [1/7] Tokenizer
# ----------------------------
print(f"{80 * '_'}\n[1/7] Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = MAX_SEQ_LENGTH
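# The EOS token doubles as the padding token so batches can be padded without
# adding a new special token, and the tokenizer's default maximum length is
# lowered to MAX_SEQ_LENGTH.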
  38. print("Tokenizer loaded.")
  39. print(f"Pad token id: {tokenizer.pad_token_id}")
  40. print(f"Max sequence length: {tokenizer.model_max_length}")
  41. # ----------------------------
  42. # [2/7] Model loading (QLoRA)
  43. # ----------------------------
  44. print(f"{80 * '_'}\n[2/7] Loading model in 4-bit mode (QLoRA)...")
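# NF4 4-bit weights with float16 compute (the usual QLoRA setup) keep the 7B
# base model small enough for a single consumer GPU while LoRA adapters are
# trained on top.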
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
  52. print("Model loaded.")
  53. # ----------------------------
  54. # [3/7] Prepare model for k-bit training
  55. # ----------------------------
  56. print(f"{80 * '_'}\n[3/7] Preparing model for k-bit training...")
  57. model = prepare_model_for_kbit_training(model)
  58. model.gradient_checkpointing_enable(
  59. gradient_checkpointing_kwargs={"use_reentrant": False}
  60. )
  61. print("Model prepared for k-bit training.")
  62. print("Gradient checkpointing enabled (non-reentrant).")
  63. # ----------------------------
  64. # [4/7] LoRA configuration
  65. # ----------------------------
  66. print(f"{80 * '_'}\n[4/7] Configuring LoRA adapters...")
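# Rank-32 adapters with alpha 64 (scaling factor alpha / r = 2) are attached to
# every attention and MLP projection, so only a small fraction of the weights
# are actually trained.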
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.02,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
print("LoRA adapters successfully attached.")

# ----------------------------
# [5/7] Dataset loading & formatting
# ----------------------------
print(f"{80 * '_'}\n[5/7] Loading dataset from JSON file...")
dataset = load_dataset(
    "json",
    data_files=DATA_FILE,
)
print(f"Dataset loaded with {len(dataset['train'])} samples.")
print("Formatting dataset for Ukrainian → French translation...")
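# Each JSON record is expected to provide a "text" field (Ukrainian source) and
# a "translation" field (French target); both are folded into a single prompt.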
def format_prompt(example):
    return {
        "text": (
            "<|user|>\n"
            "Translate the following Ukrainian text into French.\n"
            f"Ukrainian: {example['text']}\n"
            "<|assistant|>\n"
            f"{example['translation']}"
        )
    }

dataset = dataset.map(
    format_prompt,
    remove_columns=dataset["train"].column_names,
)
print("Dataset formatting completed.")
print(f"Example prompt:\n{dataset['train'][0]['text']}")

# ----------------------------
# [6/7] Training arguments
# ----------------------------
print(f"{80 * '_'}\n[6/7] Initializing training arguments...")
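# One sample per step with 8 accumulation steps gives an effective batch size
# of 8; paged_adamw_32bit keeps optimizer state in paged 32-bit buffers, which
# helps avoid out-of-memory spikes when training QLoRA on a single GPU.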
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    num_train_epochs=3,
    fp16=False,
    bf16=False,
    optim="paged_adamw_32bit",
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
    report_to="none",
)
print("Training arguments ready.")
print(f"Output directory: {OUTPUT_DIR}")
print(f"Epochs: {training_args.num_train_epochs}")
print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")

# ----------------------------
# Trainer
# ----------------------------
print("Initializing SFTTrainer...")
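# SFTTrainer tokenizes the formatted "text" column and trains with a standard
# causal-LM loss; depending on the installed trl version, the text column may
# need to be passed explicitly (e.g. dataset_text_field="text") and the
# tokenizer argument may be named processing_class instead.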
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    tokenizer=tokenizer,
    args=training_args,
)
print("Trainer initialized.")

# ----------------------------
# [7/7] Training
# ----------------------------
print(f"{80 * '_'}\n[7/7] Starting training...")
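# resume_from_checkpoint=True raises an error when OUTPUT_DIR holds no
# checkpoint; the except block catches it and falls back to a fresh run.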
try:
    trainer.train(resume_from_checkpoint=True)
except Exception as e:
    print("No checkpoint found or resume failed, starting fresh training.")
    print(f"Reason: {e}")
    trainer.train()
print("Training completed successfully.")

# ----------------------------
# Save LoRA adapter
# ----------------------------
print(f"{80 * '_'}\nSaving LoRA adapter and tokenizer...")
trainer.model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print("\n=== Fine-tuning finished ===")
print(f"LoRA adapter saved in: {OUTPUT_DIR}")
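
Once the adapter is saved, it can be reloaded for translation. The sketch below is a minimal, hypothetical example (not part of finetunning.py): it assumes the same base model and output directory, reuses the training prompt layout, and names like translate_uk_to_fr are illustrative only.

# inference_sketch.py - hypothetical usage of the saved LoRA adapter.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
ADAPTER_DIR = "./qwen2.5-7b-uk-fr-lora"  # matches OUTPUT_DIR above

tokenizer = AutoTokenizer.from_pretrained(ADAPTER_DIR, trust_remote_code=True)
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(base, ADAPTER_DIR)  # attach the trained adapter
model.eval()

def translate_uk_to_fr(ukrainian_text: str) -> str:
    # Reuse the exact prompt layout used during fine-tuning, minus the answer.
    prompt = (
        "<|user|>\n"
        "Translate the following Ukrainian text into French.\n"
        f"Ukrainian: {ukrainian_text}\n"
        "<|assistant|>\n"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=512)
    # Keep only the generated continuation, dropping the prompt tokens.
    generated = output[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)

print(translate_uk_to_fr("Доброго ранку! Як справи?"))

For a long text, the same function would be called on one chunk at a time (for example paragraph by paragraph) so that each prompt stays within MAX_SEQ_LENGTH.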