Python script for translating a long text

mergeLora.py (1.8 KB) — merges the fine-tuned LoRA adapter into the Qwen2.5-7B-Instruct base model and saves the merged weights.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# ----------------------------
# Configuration
# ----------------------------
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
LORA_DIR = "./qwen2.5-7b-uk-fr-lora"      # directory produced by the fine-tuning run
OUTPUT_DIR = "./qwen2.5-7b-uk-fr-merged"  # final merged model
DTYPE = torch.float16  # GGUF-friendly
DEVICE = "cpu"         # merge on CPU (stable, safe)

print("=== LoRA merge script started ===")

# ----------------------------
# Load base model
# ----------------------------
print(f"{80 * '_'}\n[1/4] Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=DTYPE,
    device_map=DEVICE,
    trust_remote_code=True,
)
print("Base model loaded.")

# ----------------------------
# Load tokenizer
# ----------------------------
print(f"{80 * '_'}\n[2/4] Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
)
tokenizer.pad_token = tokenizer.eos_token
print("Tokenizer loaded.")

# ----------------------------
# Load LoRA adapter
# ----------------------------
print(f"{80 * '_'}\n[3/4] Loading LoRA adapter...")
model = PeftModel.from_pretrained(
    base_model,
    LORA_DIR,
)
print("LoRA adapter loaded.")

# ----------------------------
# Merge LoRA into base model
# ----------------------------
print(f"{80 * '_'}\n[4/4] Merging LoRA into base model...")
model = model.merge_and_unload()
print("LoRA successfully merged.")

# ----------------------------
# Save merged model
# ----------------------------
print("Saving merged model...")
model.save_pretrained(
    OUTPUT_DIR,
    safe_serialization=True,
)
tokenizer.save_pretrained(OUTPUT_DIR)
print("=== Merge completed successfully ===")
print(f"Merged model saved in: {OUTPUT_DIR}")
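To sanity-check the result, a minimal sketch below reloads the merged model from OUTPUT_DIR and generates one short completion. The prompt and generation settings are illustrative assumptions, not part of the original script; no peft import is needed since the adapter is now baked into the weights.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MERGED_DIR = "./qwen2.5-7b-uk-fr-merged"  # same path as OUTPUT_DIR above

# Reload the merged checkpoint like any standalone Hugging Face model.
tokenizer = AutoTokenizer.from_pretrained(MERGED_DIR, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MERGED_DIR,
    torch_dtype=torch.float16,
    device_map="cpu",
    trust_remote_code=True,
)

# Hypothetical test prompt; replace with a real input for the fine-tuned task.
messages = [{"role": "user", "content": "Translate into French: Good morning, how are you?"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
)
output_ids = model.generate(input_ids, max_new_tokens=64)
# Print only the newly generated tokens, not the echoed prompt.
print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))

Since the weights are saved in float16 with safetensors, the merged directory can then typically be converted to GGUF with llama.cpp's convert_hf_to_gguf.py, for example: python convert_hf_to_gguf.py ./qwen2.5-7b-uk-fr-merged --outtype f16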