Python script for translating a long text

finetunning.py 4.0KB

import os
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from trl import SFTTrainer

# ----------------------------
# Environment safety (Windows)
# ----------------------------
os.environ["TORCHDYNAMO_DISABLE"] = "1"

# ----------------------------
# Model configuration
# ----------------------------
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"

print(f"=== Starting fine-tuning script {MODEL_NAME} ===")

print(f"{80 * '_'}\n[1/7] Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
)
# Ensure padding is defined
tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = 1024
print("Tokenizer loaded and configured.")

print(f"{80 * '_'}\n[2/7] Loading model in 4-bit mode (QLoRA)...")
# Configure 4-bit NF4 quantization explicitly; passing load_in_4bit directly to
# from_pretrained is deprecated in recent transformers releases.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,  # non-quantized modules kept in fp16
    trust_remote_code=True,
)
print("Model loaded.")

print(f"{80 * '_'}\n[3/7] Preparing model for k-bit training...")
model = prepare_model_for_kbit_training(model)
# Use the non-reentrant checkpointing path (the future PyTorch default)
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
print("Model prepared for k-bit training.")

# ----------------------------
# LoRA configuration
# ----------------------------
print(f"{80 * '_'}\n[4/7] Configuring LoRA adapters...")
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
print("LoRA adapters attached.")

# ----------------------------
# Dataset loading
# ----------------------------
print(f"{80 * '_'}\n[5/7] Loading dataset from JSON file...")
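# NOTE: illustrative layout of traductions.json (the file itself is not shown on this page);
# format_prompt() below expects a "text" field (Ukrainian source) and a
# "translation" field (French target) per record, for example:
#   [
#     {"text": "Доброго ранку!", "translation": "Bonjour !"},
#     {"text": "Дякую за допомогу.", "translation": "Merci pour votre aide."}
#   ]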
dataset = load_dataset(
    "json",
    data_files="traductions.json",
)
print(f"Dataset loaded with {len(dataset['train'])} samples.")

print("Formatting dataset for Ukrainian → French translation...")

def format_prompt(example):
    prompt = (
        "Translate the following Ukrainian text into French.\n\n"
        f"Ukrainian: {example['text']}\n"
        f"French: {example['translation']}"
    )
    return {"text": prompt}

dataset = dataset.map(
    format_prompt,
    remove_columns=dataset["train"].column_names,
)
print("Dataset formatting completed.")

# ----------------------------
# Training arguments (AMP OFF)
# ----------------------------
print(f"{80 * '_'}\n[6/7] Initializing training arguments...")
training_args = TrainingArguments(
    output_dir="./qwen2.5-7b-uk-fr-lora",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    num_train_epochs=2,  # 2 epochs are usually enough for translation
    fp16=False,
    bf16=False,
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
    optim="paged_adamw_32bit",
    report_to="none",
)
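# Effective batch size = per_device_train_batch_size (1) x gradient_accumulation_steps (8) = 8.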
print("Training arguments ready.")

# ----------------------------
# Trainer
# ----------------------------
print("Initializing SFTTrainer...")
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    processing_class=tokenizer,
    args=training_args,
)
print("Trainer initialized.")

# ----------------------------
# Train
# ----------------------------
print(f"{80 * '_'}\n[7/7] Starting training...")
# resume_from_checkpoint=True requires an existing checkpoint in output_dir;
# on a fresh run, call trainer.train() without this argument.
trainer.train(resume_from_checkpoint=True)
print("Training completed successfully.")

# ----------------------------
# Save LoRA adapter
# ----------------------------
print("Saving LoRA adapter and tokenizer...")
trainer.model.save_pretrained("./qwen2.5-7b-uk-fr-lora")
tokenizer.save_pretrained("./qwen2.5-7b-uk-fr-lora")
print("=== Fine-tuning finished ===")
print("LoRA adapter saved in ./qwen2.5-7b-uk-fr-lora")
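
Once training finishes, the adapter saved in ./qwen2.5-7b-uk-fr-lora can be applied on top of the base model to translate Ukrainian text into French. The sketch below is not part of finetunning.py: it assumes the standard peft/transformers loading path, an illustrative Ukrainian sentence, and greedy decoding settings chosen for the example.

# inference_example.py (illustrative sketch, not part of this repository page)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

ADAPTER_DIR = "./qwen2.5-7b-uk-fr-lora"  # directory written by finetunning.py
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(ADAPTER_DIR, trust_remote_code=True)
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)
# Attach the trained LoRA adapter on top of the base model
model = PeftModel.from_pretrained(base, ADAPTER_DIR)
model.eval()

# Prompt mirrors the training format built by format_prompt()
prompt = (
    "Translate the following Ukrainian text into French.\n\n"
    "Ukrainian: Доброго ранку, як справи?\n"
    "French:"
)
inputs = tokenizer(prompt, return_tensors="pt").to(base.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=128, do_sample=False)
# Print only the newly generated tokens (the French translation)
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))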