optimisation
This commit is contained in:
@@ -5,7 +5,7 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForCausalLM,
|
||||
TrainingArguments,
|
||||
BitsAndBytesConfig
|
||||
BitsAndBytesConfig,
|
||||
)
|
||||
from peft import (
|
||||
LoraConfig,
|
||||
@@ -46,17 +46,27 @@ print(f"Pad token id: {tokenizer.pad_token_id}")
|
||||
print(f"Max sequence length: {tokenizer.model_max_length}")
|
||||
|
||||
# ----------------------------
|
||||
# [2/7] Model loading (QLoRA)
|
||||
# [2/7] Quantization config (QLoRA)
|
||||
# ----------------------------
|
||||
print(f"{80 * '_'}\n[2/7] Loading model in 4-bit mode (QLoRA)...")
|
||||
print(f"{80 * '_'}\n[2/7] Configuring 4-bit quantization (BitsAndBytes)...")
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.float16,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
)
|
||||
|
||||
print("4-bit NF4 quantization configured.")
|
||||
|
||||
print("Loading model...")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
MODEL_NAME,
|
||||
load_in_4bit=True,
|
||||
device_map="auto",
|
||||
quantization_config=bnb_config,
|
||||
dtype=torch.float16,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
print("Model loaded.")
|
||||
print("Model loaded successfully.")
|
||||
|
||||
# ----------------------------
|
||||
# [3/7] Prepare model for k-bit training
|
||||
@@ -82,10 +92,14 @@ lora_config = LoraConfig(
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
target_modules=[
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
"o_proj",
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
"down_proj",
|
||||
],
|
||||
|
||||
)
|
||||
|
||||
model = get_peft_model(model, lora_config)
|
||||
@@ -97,10 +111,7 @@ print("LoRA adapters successfully attached.")
|
||||
# [5/7] Dataset loading & formatting
|
||||
# ----------------------------
|
||||
print(f"{80 * '_'}\n[5/7] Loading dataset from JSON file...")
|
||||
dataset = load_dataset(
|
||||
"json",
|
||||
data_files=DATA_FILE
|
||||
)
|
||||
dataset = load_dataset("json", data_files=DATA_FILE)
|
||||
|
||||
print(f"Dataset loaded with {len(dataset['train'])} samples.")
|
||||
|
||||
@@ -123,7 +134,8 @@ dataset = dataset.map(
|
||||
)
|
||||
|
||||
print("Dataset formatting completed.")
|
||||
print(f"Example prompt:\n{dataset['train'][0]['text']}")
|
||||
print("Example prompt:\n")
|
||||
print(dataset["train"][0]["text"])
|
||||
|
||||
# ----------------------------
|
||||
# [6/7] Training arguments
|
||||
@@ -147,7 +159,10 @@ training_args = TrainingArguments(
|
||||
print("Training arguments ready.")
|
||||
print(f"Output directory: {OUTPUT_DIR}")
|
||||
print(f"Epochs: {training_args.num_train_epochs}")
|
||||
print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
|
||||
print(
|
||||
f"Effective batch size: "
|
||||
f"{training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}"
|
||||
)
|
||||
|
||||
# ----------------------------
|
||||
# Trainer
|
||||
|
||||
Reference in New Issue
Block a user