Python script for translating a long text

finetunning.py 3.7KB

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from trl import SFTTrainer
import os

os.environ["TORCHDYNAMO_DISABLE"] = "1"  # disable TorchDynamo (torch.compile) graph capture
# ----------------------------
# Model configuration
# ----------------------------
MODEL_NAME = "Qwen/Qwen2.5-14B-Instruct"

print("=== Starting fine-tuning script ===")

print(f"{80 * '_'}\n[1/7] Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True
)
# Ensure padding token is defined
tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = 1024
print("Tokenizer loaded and configured.")
print(f"{80 * '_'}\n[2/7] Loading model in 4-bit mode (QLoRA)...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    load_in_4bit=True,
    device_map="auto",
    torch_dtype=torch.float16,  # OK for weights
    trust_remote_code=True,
)
print("Model loaded.")

print(f"{80 * '_'}\n[3/7] Preparing model for k-bit training...")
model = prepare_model_for_kbit_training(model)
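# In current peft versions, prepare_model_for_kbit_training freezes the base
# model's weights, casts the remaining non-quantized parameters to float32 for
# numerical stability, and enables gradient checkpointing by default; only the
# LoRA adapters attached below will be trained.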
print("Model prepared for k-bit training.")

# ----------------------------
# LoRA configuration
# ----------------------------
print(f"{80 * '_'}\n[4/7] Configuring LoRA adapters...")
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)
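# target_modules covers the attention projections (q_proj, k_proj, v_proj, o_proj)
# and the MLP projections (gate_proj, up_proj, down_proj) of Qwen2-style decoder
# blocks, so every linear layer inside each transformer block receives a LoRA adapter.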
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
print("LoRA adapters attached to the model.")

# ----------------------------
# Dataset loading
# ----------------------------
print(f"{80 * '_'}\n[5/7] Loading dataset from JSON file...")
dataset = load_dataset(
    "json",
    data_files="traductions.json"
)
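# traductions.json is expected to hold one record per translation pair, with the
# two fields read in format_prompt() below. Assumed layout (the "json" loader of
# the datasets library accepts either a JSON array like this or JSON Lines):
# [
#   {"text": "Добрий день, як справи?", "translation": "Bonjour, comment allez-vous ?"},
#   {"text": "Дякую за допомогу.", "translation": "Merci pour votre aide."}
# ]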
print(f"Dataset loaded with {len(dataset['train'])} samples.")

print("Formatting dataset for Ukrainian → French translation...")
def format_prompt(example):
    prompt = (
        "Translate the following Ukrainian text into French.\n\n"
        f"Ukrainian: {example['text']}\n"
        f"French: {example['translation']}"
    )
    return {"text": prompt}

dataset = dataset.map(format_prompt, remove_columns=dataset["train"].column_names)
print("Dataset formatting completed.")
# ----------------------------
# Training arguments
# ----------------------------
print(f"{80 * '_'}\n[6/7] Initializing training arguments...")
training_args = TrainingArguments(
    output_dir="./qwen-uk-fr-lora",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    num_train_epochs=3,
    fp16=False,
    bf16=False,
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
    # Use 32-bit optimizer
    optim="paged_adamw_32bit",
    report_to="none",
)
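# Effective batch size per optimizer step:
# per_device_train_batch_size * gradient_accumulation_steps = 1 * 8 = 8 samples per device.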
print("Training arguments ready.")

# ----------------------------
# Trainer
# ----------------------------
print("Initializing SFTTrainer...")
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    processing_class=tokenizer,
    args=training_args,
)
print("Trainer initialized.")
# ----------------------------
# Train
# ----------------------------
print(f"{80 * '_'}\n[7/7] Starting training...")
trainer.train()
print("Training completed successfully.")

# ----------------------------
# Save LoRA adapter
# ----------------------------
print("Saving LoRA adapter and tokenizer...")
trainer.model.save_pretrained("./qwen-uk-fr-lora")
tokenizer.save_pretrained("./qwen-uk-fr-lora")
print("=== Fine-tuning finished ===")
print("LoRA adapter saved in ./qwen-uk-fr-lora")
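
To translate text with the fine-tuned model afterwards, the saved adapter is loaded back on top of the base model. A minimal inference sketch, assuming the same transformers/peft stack used above (the example sentence and generation settings are illustrative and not part of this repository):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE_MODEL = "Qwen/Qwen2.5-14B-Instruct"
ADAPTER_DIR = "./qwen-uk-fr-lora"

tokenizer = AutoTokenizer.from_pretrained(ADAPTER_DIR, trust_remote_code=True)
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_4bit=True,          # same 4-bit loading as during training
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(base, ADAPTER_DIR)  # loads only the small LoRA weights
model.eval()

# Reuse the exact prompt format used for fine-tuning, leaving "French:" open so
# the model completes it with the translation.
prompt = (
    "Translate the following Ukrainian text into French.\n\n"
    "Ukrainian: Добрий день, як справи?\n"
    "French:"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))

Note that PeftModel.from_pretrained only reads the adapter weights from ./qwen-uk-fr-lora; the 14B base model still has to be available locally or downloaded from the Hugging Face Hub.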