merge_lora.py
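"""Merge LoRA adapter weights into an InternVL2 checkpoint.

Loads a LoRA fine-tuned InternVLChatModel, folds the vision-backbone and
language-model adapters into the base weights, and saves a standalone
checkpoint that can be loaded without PEFT.
"""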
import torch
from internvl_chat.internvl.model.internvl_chat import InternVLChatModel
from transformers import AutoTokenizer
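# Directory of the LoRA fine-tuned checkpoint (adapters still attached).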
input_path = (
"checkpoints/web/internvl2_2b_1epoch-16batch_size-webqa-reranker-caption-lora"
)
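# Destination directory for the merged, standalone checkpoint.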
output_path = (
"checkpoints/web/internvl2_2b_1epoch-16batch_size-webqa-reranker-caption-lora-merge"
)
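# Load the fine-tuned model with its LoRA adapters still in place.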
print("Loading model...")
model = InternVLChatModel.from_pretrained(
input_path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
).eval()
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(input_path, trust_remote_code=True)
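# Fold the vision-backbone LoRA weights into the base weights, strip the
# PEFT wrapper, and clear the LoRA flag so the merged config loads as a plain model.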
if model.config.use_backbone_lora:
model.vision_model.merge_and_unload()
model.vision_model = model.vision_model.model
model.config.use_backbone_lora = 0
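# Do the same for the language-model LoRA.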
if model.config.use_llm_lora:
model.language_model.merge_and_unload()
model.language_model = model.language_model.model
model.config.use_llm_lora = 0
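# Write out the merged weights and tokenizer.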
print("Saving model...")
model.save_pretrained(output_path)
print("Saving tokenizer...")
tokenizer.save_pretrained(output_path)
print("Done!")