import os

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
)
from trl import SFTTrainer

# ----------------------------
# Environment safety (Windows)
# ----------------------------
os.environ["TORCHDYNAMO_DISABLE"] = "1"

# ----------------------------
# Model configuration
# ----------------------------
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"

print(f"=== Starting fine-tuning script {MODEL_NAME} ===")
| print(f"{80 * '_'}\n[1/7] Loading tokenizer...") | print(f"{80 * '_'}\n[1/7] Loading tokenizer...") | ||||
| tokenizer = AutoTokenizer.from_pretrained( | tokenizer = AutoTokenizer.from_pretrained( | ||||
| trust_remote_code=True | trust_remote_code=True | ||||
| ) | ) | ||||
| # Ensure padding token is defined | |||||
| # Ensure padding is defined | |||||
| tokenizer.pad_token = tokenizer.eos_token | tokenizer.pad_token = tokenizer.eos_token | ||||
| tokenizer.model_max_length = 1024 | tokenizer.model_max_length = 1024 | ||||

print(f"{80 * '_'}\n[2/7] Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    load_in_4bit=True,
    device_map="auto",
    torch_dtype=torch.float16,  # weights in fp16, gradients stay in fp32
    trust_remote_code=True,
)
| print("Model loaded.") | print("Model loaded.") | ||||
| print(f"{80 * '_'}\n[3/7] Preparing model for k-bit training...") | print(f"{80 * '_'}\n[3/7] Preparing model for k-bit training...") | ||||
| model = prepare_model_for_kbit_training(model) | model = prepare_model_for_kbit_training(model) | ||||
| # Fix future PyTorch checkpointing behavior | |||||
| model.gradient_checkpointing_enable( | |||||
| gradient_checkpointing_kwargs={"use_reentrant": False} | |||||
| ) | |||||
| print("Model prepared for k-bit training.") | print("Model prepared for k-bit training.") | ||||

# ----------------------------
# LoRA configuration
# ----------------------------
print(f"{80 * '_'}\n[4/7] Attaching LoRA adapters...")
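# The LoraConfig definition is missing from this excerpt; the values below are
# a typical QLoRA setup for Qwen2.5-style models (assumption), not the original
# hyperparameters. Adjust r, lora_alpha and target_modules as needed.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)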
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
print("LoRA adapters attached.")

# ----------------------------
# Dataset loading
# ----------------------------
print(f"{80 * '_'}\n[5/7] Loading dataset...")
dataset = load_dataset(
    "json",
    data_files="traductions.json"
)
print(f"Dataset loaded with {len(dataset['train'])} samples.")
| print("Formatting dataset for Ukrainian → French translation...") | print("Formatting dataset for Ukrainian → French translation...") | ||||
    )
    return {"text": prompt}

dataset = dataset.map(
    format_prompt,
    remove_columns=dataset["train"].column_names
)
print("Dataset formatting completed.")

# ----------------------------
# Training arguments (AMP OFF)
# ----------------------------
print(f"{80 * '_'}\n[6/7] Initializing training arguments...")
training_args = TrainingArguments(
    output_dir="./qwen2.5-7b-uk-fr-lora",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    num_train_epochs=2,  # 2 epochs are usually enough for translation fine-tuning
    fp16=False,          # AMP off: LoRA weights and gradients stay in fp32
    bf16=False,
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
    # Use a 32-bit (paged) AdamW optimizer
    optim="paged_adamw_32bit",
    report_to="none",
)
print("Training arguments ready.")

# ----------------------------
# Train
# ----------------------------
print(f"{80 * '_'}\n[7/7] Starting training...")
# resume_from_checkpoint=True requires an existing checkpoint in output_dir;
# remove the argument (plain trainer.train()) for a fresh run.
trainer.train(resume_from_checkpoint=True)
print("Training completed successfully.")

# ----------------------------
# Save LoRA adapter
# ----------------------------
print("Saving LoRA adapter and tokenizer...")
trainer.model.save_pretrained("./qwen2.5-7b-uk-fr-lora")
tokenizer.save_pretrained("./qwen2.5-7b-uk-fr-lora")

print("=== Fine-tuning finished ===")
print("LoRA adapter saved in ./qwen2.5-7b-uk-fr-lora")