ne pas versionner le lora
This commit is contained in:
1
Finetunning/.gitignore
vendored
Normal file
1
Finetunning/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
qwen2.5*/
|
||||
68
Finetunning/mergeLora.py
Normal file
68
Finetunning/mergeLora.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Merge a LoRA adapter into its base model and save the merged result.

Loads the base model on CPU in float16 (GGUF-friendly), loads the matching
tokenizer, applies the LoRA adapter from LORA_DIR, merges the adapter
weights into the base weights, and writes the merged model plus tokenizer
to OUTPUT_DIR using safetensors serialization.
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# ----------------------------
# Configuration
# ----------------------------
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
LORA_DIR = "./qwen2.5-7b-uk-fr-lora"      # directory produced by the fine-tuning run
OUTPUT_DIR = "./qwen2.5-7b-uk-fr-merged"  # final merged model

DTYPE = torch.float16  # GGUF-friendly
DEVICE = "cpu"         # merge on CPU (stable, safe)

print("=== LoRA merge script started ===")

# ----------------------------
# Load base model
# ----------------------------
print("[1/4] Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=DTYPE,
    device_map=DEVICE,
    trust_remote_code=True,
)
print("Base model loaded.")

# ----------------------------
# Load tokenizer
# ----------------------------
print("[2/4] Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True
)
# Only fall back to EOS when the tokenizer ships without a pad token;
# overwriting an existing pad token unconditionally would change behavior
# for tokenizers that already define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("Tokenizer loaded.")

# ----------------------------
# Load LoRA adapter
# ----------------------------
print("[3/4] Loading LoRA adapter...")
model = PeftModel.from_pretrained(
    base_model,
    LORA_DIR,
)
print("LoRA adapter loaded.")

# ----------------------------
# Merge LoRA into base model
# ----------------------------
print("[4/4] Merging LoRA into base model...")
# merge_and_unload folds the adapter deltas into the base weights and
# returns a plain transformers model (no PEFT wrapper).
model = model.merge_and_unload()
print("LoRA successfully merged.")

# ----------------------------
# Save merged model
# ----------------------------
print("Saving merged model...")
model.save_pretrained(
    OUTPUT_DIR,
    safe_serialization=True,  # write .safetensors instead of pickle-based .bin
)
tokenizer.save_pretrained(OUTPUT_DIR)

print("=== Merge completed successfully ===")
print(f"Merged model saved in: {OUTPUT_DIR}")
|
||||
Reference in New Issue
Block a user