| # ---------------------------- | # ---------------------------- | ||||
| # [2/7] Quantization config (QLoRA) | # [2/7] Quantization config (QLoRA) | ||||
| # ---------------------------- | # ---------------------------- | ||||
| print(f"{80 * '_'}\n[2/7] Configuring 4-bit quantization (BitsAndBytes)...") | |||||
| print(f"{80 * '_'}\n[2/7] Loading model in 4-bit mode (optimized QLoRA)...") | |||||
| assert torch.cuda.is_available(), "CUDA GPU not detected!" | |||||
| print(f"Using GPU: {torch.cuda.get_device_name(0)}") | |||||
| bnb_config = BitsAndBytesConfig( | bnb_config = BitsAndBytesConfig( | ||||
| load_in_4bit=True, | load_in_4bit=True, | ||||
| bnb_4bit_quant_type="nf4", | bnb_4bit_quant_type="nf4", | ||||
| bnb_4bit_use_double_quant=True, | bnb_4bit_use_double_quant=True, | ||||
| ) | ) | ||||
| print("4-bit NF4 quantization configured.") | |||||
| print("Loading model...") | |||||
| model = AutoModelForCausalLM.from_pretrained( | model = AutoModelForCausalLM.from_pretrained( | ||||
| MODEL_NAME, | MODEL_NAME, | ||||
| device_map="auto", | |||||
| device_map="cuda", # 🔥 SAFE | |||||
| quantization_config=bnb_config, | quantization_config=bnb_config, | ||||
| dtype=torch.float16, | |||||
| low_cpu_mem_usage=True, | |||||
| trust_remote_code=True, | trust_remote_code=True, | ||||
| ) | ) | ||||
| print("Model loaded successfully.") | |||||
| print("Model loaded successfully in 4-bit mode on GPU.") | |||||
| # ---------------------------- | # ---------------------------- | ||||
| # [3/7] Prepare model for k-bit training | # [3/7] Prepare model for k-bit training | ||||
| def format_prompt(example): | def format_prompt(example): | ||||
| return { | return { | ||||
| "text": ( | |||||
| "<|user|>\n" | |||||
| "text": ("<|user|>\n" | |||||
| "Translate the following Ukrainian text into French.\n" | "Translate the following Ukrainian text into French.\n" | ||||
| f"Ukrainian: {example['text']}\n" | f"Ukrainian: {example['text']}\n" | ||||
| "<|assistant|>\n" | "<|assistant|>\n" | ||||
| save_steps=500, | save_steps=500, | ||||
| save_total_limit=2, | save_total_limit=2, | ||||
| report_to="none", | report_to="none", | ||||
| dataloader_pin_memory=False, | |||||
| ) | ) | ||||
| print("Training arguments ready.") | print("Training arguments ready.") | ||||
| print(f"Output directory: {OUTPUT_DIR}") | print(f"Output directory: {OUTPUT_DIR}") | ||||
| print(f"Epochs: {training_args.num_train_epochs}") | print(f"Epochs: {training_args.num_train_epochs}") | ||||
| print( | |||||
| f"Effective batch size: " | |||||
| print(f"Effective batch size: " | |||||
| f"{training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}" | f"{training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}" | ||||
| ) | ) | ||||
| trainer = SFTTrainer( | trainer = SFTTrainer( | ||||
| model=model, | model=model, | ||||
| train_dataset=dataset["train"], | train_dataset=dataset["train"], | ||||
| tokenizer=tokenizer, | |||||
| processing_class=tokenizer, | |||||
| args=training_args, | args=training_args, | ||||
| ) | ) | ||||
| print("Trainer initialized.") | print("Trainer initialized.") | ||||
| # ---------------------------- | # ---------------------------- | ||||
| print(f"{80 * '_'}\n[7/7] Starting training...") | print(f"{80 * '_'}\n[7/7] Starting training...") | ||||
| try: | try: | ||||
| trainer.train(resume_from_checkpoint=True) | |||||
| train_output = trainer.train(resume_from_checkpoint=True) | |||||
| except Exception as e: | except Exception as e: | ||||
| print("No checkpoint found or resume failed, starting fresh training.") | print("No checkpoint found or resume failed, starting fresh training.") | ||||
| print(f"Reason: {e}") | print(f"Reason: {e}") | ||||
| trainer.train() | |||||
| train_output = trainer.train() | |||||
| print("\n=== Training summary ===") | |||||
| print(f"Global steps: {train_output.global_step}") | |||||
| print(f"Training loss: {train_output.training_loss}") | |||||
| print(f"Metrics: {train_output.metrics}") | |||||
| print("Training completed successfully.") | print("Training completed successfully.") | ||||
| # ---------------------------- | # ---------------------------- |