-
Notifications
You must be signed in to change notification settings - Fork 228
/
convert_phi3_checkpoints.py
445 lines (387 loc) · 13.9 KB
/
convert_phi3_checkpoints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import json
import os
import re
import sys
os.environ["KERAS_BACKEND"] = "torch"
import huggingface_hub # noqa: E402
import keras # noqa: E402
import torch # noqa: E402
import transformers # noqa: E402
from keras_nlp import upload_preset # noqa: E402
from keras_nlp.src.models import Phi3Backbone # noqa: E402
from keras_nlp.src.models import Phi3Preprocessor # noqa: E402
from keras_nlp.src.models import Phi3Tokenizer # noqa: E402
PRESET_MAP = {
"phi3_mini_4k_instruct_en": "microsoft/Phi-3-mini-4k-instruct",
"phi3_mini_128k_instruct_en": "microsoft/Phi-3-mini-128k-instruct",
}
def download_hf_model(hf_model_name, extract_dir):
hf_model_dir = huggingface_hub.snapshot_download(
repo_id=hf_model_name,
allow_patterns=["*.json", "*.safetensors", "*.py", "*.model"],
ignore_patterns=["*/*"],
local_dir=extract_dir,
)
return hf_model_dir
def convert_tokenizer(hf_model_dir):
# We can't import `sentencepiece_model_pb2` from `sentencepiece` because of
# this protobuffer lib error
# `TypeError: Couldn't build proto file into descriptor pool: duplicate
# file name sentencepiece_model.proto`
# transformers library build `sentencepiece_model.proto` once. we can't
# import again because it will try to build the proto again into the
# descriptor pool and protobuffer forbids that.
sp_pb2_module = sys.modules.get(
"transformers.utils.sentencepiece_model_pb2_new", None
)
if sp_pb2_module is None:
sp_pb2_module = sys.modules.get(
"transformers.utils.sentencepiece_model_pb2", None
)
if sp_pb2_module is None:
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_module
model_path = os.path.join(hf_model_dir, "tokenizer.model")
added_tokens_path = os.path.join(hf_model_dir, "added_tokens.json")
with open(model_path, "rb") as sp_model_file:
model_proto = sp_pb2_module.ModelProto()
model_proto.ParseFromString(sp_model_file.read())
with open(added_tokens_path, "rb") as added_tokens_file:
added_tokens = json.load(added_tokens_file)
# Add the new tokens to the model as user defined pieces.
for token in added_tokens.keys():
new_token = sp_pb2_module.ModelProto().SentencePiece()
new_token.piece = token
new_token.score = 0.0
new_token.type = 4 # user defined symbols.
model_proto.pieces.append(new_token)
tokenizer = Phi3Tokenizer(model_proto.SerializeToString())
for key, value in added_tokens.items():
assert key == tokenizer.id_to_token(
value
), f"{key} token have different id in the tokenizer"
return tokenizer
def convert_model(hf_model, device, dtype):
hf_config = hf_model.config.to_dict()
kwargs = {}
kwargs["vocabulary_size"] = hf_config["vocab_size"]
kwargs["num_layers"] = hf_config["num_hidden_layers"]
kwargs["num_query_heads"] = hf_config["num_attention_heads"]
kwargs["num_key_value_heads"] = hf_config["num_key_value_heads"]
kwargs["hidden_dim"] = hf_config["hidden_size"]
kwargs["intermediate_dim"] = hf_config["intermediate_size"]
kwargs["dropout"] = hf_config["attention_dropout"]
kwargs["layer_norm_epsilon"] = hf_config["rms_norm_eps"]
kwargs["max_sequence_length"] = hf_config["max_position_embeddings"]
kwargs["original_max_sequence_length"] = hf_config[
"original_max_position_embeddings"
]
kwargs["rope_max_wavelength"] = hf_config["rope_theta"]
if hf_config["rope_scaling"] is not None:
kwargs["rope_scaling_type"] = hf_config["rope_scaling"]["type"]
kwargs["rope_scaling_short_factor"] = hf_config["rope_scaling"][
"short_factor"
]
kwargs["rope_scaling_long_factor"] = hf_config["rope_scaling"][
"long_factor"
]
kwargs["dtype"] = dtype
with keras.device(device):
keras_model = Phi3Backbone(**kwargs)
return keras_model
def convert_weights(keras_model, hf_model):
hidden_dim = keras_model.hidden_dim
intermediate_dim = keras_model.intermediate_dim
num_query_heads = keras_model.num_query_heads
num_key_value_heads = keras_model.num_key_value_heads
head_dim = hidden_dim // num_query_heads
# get huggingface model weights.
hf_wts = hf_model.state_dict()
# Embedding layer.
keras_model.token_embedding.embeddings.assign(
hf_wts["model.embed_tokens.weight"]
)
keras_model.token_embedding.reverse_embeddings.assign(
hf_wts["lm_head.weight"].t()
)
# LayerNorm.
keras_model.layer_norm.scale.assign(hf_wts["model.norm.weight"])
# Decoder layers.
for i, decoder_layer in enumerate(keras_model.transformer_layers):
# LayrNorm.
decoder_layer.pre_attention_layernorm.scale.assign(
hf_wts[f"model.layers.{i}.input_layernorm.weight"]
)
decoder_layer.post_attention_layernorm.scale.assign(
hf_wts[f"model.layers.{i}.post_attention_layernorm.weight"]
)
# Attention layer.
attention_layer = decoder_layer.attention
fused_qkv_kernel = hf_wts[
f"model.layers.{i}.self_attn.qkv_proj.weight"
].t()
query_kernel = fused_qkv_kernel[:, :hidden_dim]
query_kernel = query_kernel.reshape(
hidden_dim, num_query_heads, head_dim
)
key_kernel = fused_qkv_kernel[
:, hidden_dim : hidden_dim + num_key_value_heads * head_dim
]
key_kernel = key_kernel.reshape(
hidden_dim, num_key_value_heads, head_dim
)
value_kernel = fused_qkv_kernel[
:, hidden_dim + num_key_value_heads * head_dim :
]
value_kernel = value_kernel.reshape(
hidden_dim, num_key_value_heads, head_dim
)
attention_layer.query_dense._kernel.assign(query_kernel)
attention_layer.key_dense._kernel.assign(key_kernel)
attention_layer.value_dense._kernel.assign(value_kernel)
attention_layer.output_dense.kernel.assign(
hf_wts[f"model.layers.{i}.self_attn.o_proj.weight"]
.t()
.reshape(num_query_heads, head_dim, hidden_dim)
)
# feed dorward layer.
fused_intermediate_gate_ff_kernel = hf_wts[
f"model.layers.{i}.mlp.gate_up_proj.weight"
].t()
decoder_layer.feedforward_gate_dense._kernel.assign(
fused_intermediate_gate_ff_kernel[:, :intermediate_dim]
)
decoder_layer.feedforward_intermediate_dense._kernel.assign(
fused_intermediate_gate_ff_kernel[:, intermediate_dim:]
)
decoder_layer.feedforward_output_dense._kernel.assign(
hf_wts[f"model.layers.{i}.mlp.down_proj.weight"].t()
)
def validate_output(
hf_model,
keras_model,
hf_device,
keras_device,
hf_tokenizer,
keras_preprocessor,
):
# Hf
tokens = hf_tokenizer(
["<|user|>How to win?<|end|><|assistant|>"],
max_length=9,
padding="max_length",
return_tensors="pt",
)
hf_model_input = {
"input_ids": tokens["input_ids"].to(hf_device),
"attention_mask": tokens["attention_mask"].to(hf_device),
"use_cache": False,
"output_attentions": False,
"output_hidden_states": False,
"return_dict": False,
}
hf_model_outputs = hf_model(**hf_model_input)[0]
# KerasNLP
keras_model_input = keras_preprocessor(
["<|user|>How to win?<|end|><|assistant|>"]
)
keras_model_input = {
k: v.to(keras_device) for k, v in keras_model_input.items()
}
keras_model_outputs = keras_model(keras_model_input)
# Comparing the outputs.
print("🔶 KerasNLP output:", keras_model_outputs[0, 0, :10])
print("🔶 HF output:", hf_model_outputs[0, 0, :10])
print(
"🔶 Difference:",
torch.mean(
torch.abs(
keras_model_outputs.detach().cpu()
- hf_model_outputs.detach().cpu()
)
),
)
def get_torch_dtype(str_dtype):
if str_dtype == "float32":
return torch.float32
elif str_dtype == "float16":
return torch.float16
elif str_dtype == "bfloat16":
return torch.bfloat16
def convert_and_validate(
hf_model_dir,
hf_device,
keras_device,
validate_dtype,
hf_tokenizer,
keras_preprocessor,
):
print(f"✅ Numerics Validation in {validate_dtype}.")
# Load the causal model to convert lm_head weights.
hf_causal_model = transformers.AutoModelForCausalLM.from_pretrained(
hf_model_dir,
device_map=hf_device,
torch_dtype=get_torch_dtype(validate_dtype),
trust_remote_code=True,
)
hf_model = hf_causal_model.model
print("✅ Huggingface model loaded.")
keras_model = convert_model(hf_causal_model, keras_device, validate_dtype)
print("✅ Keras model loaded.")
convert_weights(keras_model, hf_causal_model)
print("✅ Weights converted")
validate_output(
hf_model,
keras_model,
hf_device,
keras_device,
hf_tokenizer,
keras_preprocessor,
)
print("✅ Numerics validated")
# Clean memory.
del keras_model
del hf_causal_model
del hf_model
gc.collect()
if not (hf_device == "cpu" and keras_device == "cpu"):
torch.cuda.empty_cache()
def convert_and_save(
preset,
hf_model_dir,
hf_device,
keras_device,
save_dtype,
keras_preprocessor,
):
print(f"✅ Saving model in {save_dtype}.")
# Load the causal model to convert lm_head weights.
hf_causal_model = transformers.AutoModelForCausalLM.from_pretrained(
hf_model_dir,
device_map=hf_device,
torch_dtype=get_torch_dtype(save_dtype),
trust_remote_code=True,
)
print("✅ Huggingface model loaded.")
keras_model = convert_model(hf_causal_model, keras_device, save_dtype)
print("✅ Keras model loaded.")
convert_weights(keras_model, hf_causal_model)
print("✅ Weights converted")
keras_model.save_to_preset(preset)
keras_preprocessor.tokenizer.save_to_preset(preset)
print("✅ Preset saved")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--preset",
default="phi3_mini_4k_instruct_en",
choices=PRESET_MAP.keys(),
required=True,
help=f'Preset must be one of {", ".join(PRESET_MAP.keys())}',
)
def device_regex(arg_value, pattern=re.compile(r"^cpu$|^cuda:[0-9]+$")):
if not pattern.match(arg_value):
raise argparse.ArgumentTypeError(
"The device must be one of: "
"'cpu', 'cuda:0', 'cuda:1', ...'cuda:n'"
)
return arg_value
parser.add_argument(
"--hf_device",
default="cpu",
type=device_regex,
help=(
"The device where huggingface model will be loaded. can be one of: "
"'cpu', 'cuda:0', 'cuda:1', ...'cuda:n'"
),
)
parser.add_argument(
"--keras_device",
default="cpu",
type=device_regex,
help=(
"The device where keras model will be loaded. can be one of: "
"'cpu', 'cuda:0', 'cuda:1', ...'cuda:n'"
),
)
parser.add_argument(
"--validate_dtype",
choices=["float32", "float16", "bfloat16"],
default="float32",
help=(
"The dtype of the two models while validating numerics. "
"can be 'float32', 'float16', or 'bfloat16'"
),
)
parser.add_argument(
"--save_dtype",
choices=["float32", "float16", "bfloat16"],
default="bfloat16",
help=(
"The dtype that keras model will be saved with. "
"can be 'float32', 'float16', or 'bfloat16'"
),
)
parser.add_argument(
"--upload_link",
type=str,
help=(
"The link to upload the model. can be in these formats: "
"`kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>`, "
"`hf://[<HF_USERNAME>/]<MODEL>`"
),
)
args = parser.parse_args()
preset = args.preset
hf_device = args.hf_device
keras_device = args.keras_device
validate_dtype = args.validate_dtype
save_dtype = args.save_dtype
upload_link = args.upload_link
print(f"✅ Coverting {preset}.")
hf_model_name = PRESET_MAP[preset]
hf_model_dir = download_hf_model(hf_model_name, f"./{preset}_hf_model")
print("✅ Huggingface model downloaded from the hub.")
keras_tokenizer = convert_tokenizer(hf_model_dir)
keras_preprocessor = Phi3Preprocessor(
tokenizer=keras_tokenizer, sequence_length=9
)
# phi3 uses llama tokenizer
hf_tokenizer = transformers.LlamaTokenizer.from_pretrained(
hf_model_dir, padding_side="right"
)
convert_and_validate(
hf_model_dir=hf_model_dir,
hf_device=hf_device,
keras_device=keras_device,
validate_dtype=validate_dtype,
hf_tokenizer=hf_tokenizer,
keras_preprocessor=keras_preprocessor,
)
convert_and_save(
preset=preset,
hf_device=hf_device,
keras_device=keras_device,
save_dtype=save_dtype,
hf_model_dir=hf_model_dir,
keras_preprocessor=keras_preprocessor,
)
if upload_link is not None:
upload_preset(upload_link, preset)
print("✅ Preset uploaded")
if __name__ == "__main__":
main()