optimisation

This commit is contained in:
Alex
2026-01-14 23:54:12 +01:00
parent aee2716a41
commit 70e4932cd0
3 changed files with 62 additions and 21 deletions

View File

@@ -5,7 +5,7 @@ from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
BitsAndBytesConfig
BitsAndBytesConfig,
)
from peft import (
LoraConfig,
@@ -46,17 +46,27 @@ print(f"Pad token id: {tokenizer.pad_token_id}")
print(f"Max sequence length: {tokenizer.model_max_length}")
# ----------------------------
# [2/7] Model loading (QLoRA)
# [2/7] Quantization config (QLoRA)
# ----------------------------
print(f"{80 * '_'}\n[2/7] Loading model in 4-bit mode (QLoRA)...")
print(f"{80 * '_'}\n[2/7] Configuring 4-bit quantization (BitsAndBytes)...")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
)
print("4-bit NF4 quantization configured.")
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
load_in_4bit=True,
device_map="auto",
quantization_config=bnb_config,
dtype=torch.float16,
trust_remote_code=True,
)
print("Model loaded.")
print("Model loaded successfully.")
# ----------------------------
# [3/7] Prepare model for k-bit training
@@ -82,10 +92,14 @@ lora_config = LoraConfig(
bias="none",
task_type="CAUSAL_LM",
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
)
model = get_peft_model(model, lora_config)
@@ -97,10 +111,7 @@ print("LoRA adapters successfully attached.")
# [5/7] Dataset loading & formatting
# ----------------------------
print(f"{80 * '_'}\n[5/7] Loading dataset from JSON file...")
dataset = load_dataset(
"json",
data_files=DATA_FILE
)
dataset = load_dataset("json", data_files=DATA_FILE)
print(f"Dataset loaded with {len(dataset['train'])} samples.")
@@ -123,7 +134,8 @@ dataset = dataset.map(
)
print("Dataset formatting completed.")
print(f"Example prompt:\n{dataset['train'][0]['text']}")
print("Example prompt:\n")
print(dataset["train"][0]["text"])
# ----------------------------
# [6/7] Training arguments
@@ -147,7 +159,10 @@ training_args = TrainingArguments(
print("Training arguments ready.")
print(f"Output directory: {OUTPUT_DIR}")
print(f"Epochs: {training_args.num_train_epochs}")
print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(
f"Effective batch size: "
f"{training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}"
)
# ----------------------------
# Trainer