-
Notifications
You must be signed in to change notification settings - Fork 13
/
tool_export_to_hf.py
executable file
·406 lines (341 loc) · 16.5 KB
/
tool_export_to_hf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import json
import os
import sys
import torch
import transformers
from tqdm import tqdm
import types
from transformers import LlamaForCausalLM
def add_arguments(parser):
    '''Register the Llama-2 HF loader's command-line options on ``parser``.

    All options are added to one argument group titled 'Llama-2 HF loader.'.
    Only ``--tokenizer-model`` is required; the rest default to ``None``.

    Args:
        parser: an ``argparse.ArgumentParser`` (or anything exposing
            ``add_argument_group``).
    '''
    group = parser.add_argument_group(title='Llama-2 HF loader.')

    group.add_argument('--checkpoint', type=str, default=None,
                       help='Path to the checkpoint to load.')
    group.add_argument('--true-vocab-size', type=int, default=None,
                       help='Original size of vocab; if specified, padding is trimmed from the embedding table.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file. If specified will use this to get vocab size and '
                       'trim padding from the embedding table.')
    group.add_argument('--tokenizer-model', required=True,
                       help='Sentencepiece tokenizer model.')
    # The path is prepended to sys.path so that `megatron` can be imported.
    group.add_argument('--megatron-path', type=str, default=None,
                       help='Base directory of the Megatron repository.')
def verify_transformers_version(version=None):
    '''Require HF transformers >= 4.31.0, which Llama-2 needs.

    Only the leading ``major.minor`` numbers are compared. Suffixed versions
    such as ``4.36.0.dev0`` are accepted, and newer majors such as ``5.0`` pass.

    Args:
        version: version string to check. Defaults to
            ``transformers.__version__``.

    Raises:
        AssertionError: if the version cannot be parsed or is older than 4.31.
            The check raises explicitly, so it still runs under ``python -O``.
    '''
    import re

    if version is None:
        version = transformers.__version__
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise AssertionError(f"Cannot parse transformers version {version!r}.")
    major, minor = int(match.group(1)), int(match.group(2))
    # Compare as a tuple: a per-field check (minor >= 31) would reject 5.0.
    if (major, minor) < (4, 31):
        raise AssertionError(
            f"Llama-2 requires transformers >= 4.31.0, found {version}.")
def load_args_from_checkpoint(args):
    '''Fill Megatron ``args`` in place from the HF Llama ``config.json`` in ``args.load``.

    Sizes (hidden, heads, layers, FFN, vocab, RMSNorm epsilon) come from the
    HF config. All other settings are fixed Llama-2 values. The raw config
    dict is kept on ``args.llama``. If the config has ``num_key_value_heads``,
    grouped-query attention is enabled with that many query groups.
    '''
    config_file = os.path.join(args.load, "config.json")
    with open(config_file) as handle:
        hf_cfg = json.load(handle)

    # Fixed settings. The iteration is 1 because '0' and 'release' don't work.
    constant_settings = {
        "seq_length": 4096,
        "max_position_embeddings": 4096,
        "global_batch_size": 1024,
        "iteration": 1,
        "add_position_embedding": False,
        "use_rotary_position_embeddings": True,
        "swiglu": True,
        "tokenizer_type": "Llama2Tokenizer",
        "fp16": True,
        "normalization": "RMSNorm",
        "add_bias_linear": False,
        "untie_embeddings_and_output_weights": True,
    }
    # Settings read from the HF config.
    derived_settings = {
        "hidden_size": hf_cfg["hidden_size"],
        "num_attention_heads": hf_cfg["num_attention_heads"],
        "num_layers": hf_cfg["num_hidden_layers"],
        "norm_epsilon": hf_cfg["rms_norm_eps"],
        "vocab_size": hf_cfg["vocab_size"],
        "padded_vocab_size": hf_cfg["vocab_size"],
        "llama": hf_cfg,
        "ffn_hidden_size": hf_cfg["intermediate_size"],
    }
    for attr_name, attr_value in {**constant_settings, **derived_settings}.items():
        setattr(args, attr_name, attr_value)

    if "num_key_value_heads" in hf_cfg:
        args.group_query_attention = True
        args.num_query_groups = hf_cfg["num_key_value_heads"]
def set_preprocess_state(args, model, hf_model):
    '''Copy HF ``embed_tokens`` into the Megatron word-embedding weight (``args`` is unused).'''
    source_table = hf_model.model.embed_tokens.weight
    target_param = model.language_model.embedding.word_embeddings.weight
    target_param.data.copy_(source_table)
def set_postprocess_state(args, model, hf_model):
    '''Copy the HF final norm and ``lm_head`` into the Megatron model (``args`` is unused).'''
    megatron_lm = model.language_model
    megatron_lm.encoder.final_norm.weight.data.copy_(hf_model.model.norm.weight)
    megatron_lm.output_layer.weight.data.copy_(hf_model.lm_head.weight)
def set_preprocess_state_hf(args, hf_model, megatron_ckpt):
    '''Copy the Megatron word-embedding table into HF ``embed_tokens`` (``args`` is unused).'''
    embedding_state = megatron_ckpt['model']['language_model']['embedding']
    source_table = embedding_state['word_embeddings']['weight']
    hf_model.model.embed_tokens.weight.data.copy_(source_table)
def set_postprocess_state_hf(args, hf_model, megatron_ckpt):
    '''Copy the Megatron final norm and output layer into HF ``norm`` / ``lm_head`` (``args`` is unused).'''
    lm_state = megatron_ckpt['model']['language_model']
    norm_source = lm_state['encoder']['final_norm.weight']
    head_source = lm_state['output_layer']['weight']
    hf_model.model.norm.weight.data.copy_(norm_source)
    hf_model.lm_head.weight.data.copy_(head_source)
def set_attn_state(args, layer, hf_layer):
    '''Fuse HF q/k/v projections into Megatron's ``query_key_value`` and copy ``o_proj``.

    Query rows are grouped per KV group, and each group is followed by its
    key rows and value rows, i.e. ``[q_g, k_g, v_g]`` for g in groups.
    The result is flattened back to ``(-1, args.hidden_size)``.
    '''
    megatron_attn = layer.self_attention
    hf_attn = hf_layer.self_attn

    tp_size = args.tensor_model_parallel_size
    heads_per_rank = args.num_attention_heads // tp_size
    if args.group_query_attention:
        groups_per_rank = args.num_query_groups // tp_size
    else:
        groups_per_rank = args.num_attention_heads // tp_size
    head_dim = args.kv_channels
    assert heads_per_rank % groups_per_rank == 0

    q_rows_per_group = head_dim * heads_per_rank // groups_per_rank
    q_grouped = hf_attn.q_proj.weight.reshape((groups_per_rank, q_rows_per_group, -1))
    k_grouped = hf_attn.k_proj.weight.reshape((groups_per_rank, head_dim, -1))
    v_grouped = hf_attn.v_proj.weight.reshape((groups_per_rank, head_dim, -1))
    fused_qkv = torch.cat([q_grouped, k_grouped, v_grouped], dim=1)

    megatron_attn.query_key_value.weight.data.copy_(fused_qkv.reshape((-1, args.hidden_size)))
    megatron_attn.dense.weight.data.copy_(hf_attn.o_proj.weight)
def set_attn_state_hf(args, hf_layer, megatron_ckpt, layer_idx, hidden_size):
    '''Split one layer's fused Megatron QKV into HF q/k/v projections and copy ``o_proj``.

    Assumes an unsharded checkpoint (tensor parallel size 1). Head counts,
    the hidden size and the optional ``head_dim`` are read from
    ``<args.hf_ckpt>/config.json``.

    Number of KV groups:
      * ``args.num_query_groups`` when ``args.group_query_attention`` is set
        and a group count was given;
      * otherwise ``num_key_value_heads`` from config.json, falling back to
        ``num_attention_heads`` (plain multi-head attention).
    Previously the non-flag path always used ``num_attention_heads``, so the
    QKV reshape failed for GQA configs unless the CLI flags were passed.

    Megatron stores QKV per group as ``[q_g, k_g, v_g]``. The query slice of
    each group has ``head_dim * nh // ng`` rows.

    Args:
        args: namespace with ``hf_ckpt``, and optionally
            ``group_query_attention`` / ``num_query_groups``.
        hf_layer: HF decoder layer whose ``self_attn`` weights are overwritten.
        megatron_ckpt: loaded Megatron checkpoint dict.
        layer_idx: index of the layer in the encoder state dict.
        hidden_size: unused; kept for interface compatibility.

    Raises:
        ValueError: if the head count is not divisible by the group count.
    '''
    hf_attn = hf_layer.self_attn

    llama_args_path = os.path.join(args.hf_ckpt, "config.json")
    with open(llama_args_path, encoding="utf-8") as f:
        llama_args = json.load(f)

    nh = llama_args["num_attention_heads"]
    model_hidden = llama_args["hidden_size"]
    # Newer HF configs may store an explicit head_dim; otherwise derive it.
    dim = llama_args.get("head_dim") or model_hidden // nh

    num_query_groups = getattr(args, "num_query_groups", None)
    if getattr(args, "group_query_attention", False) and num_query_groups is not None:
        ng = num_query_groups
    else:
        ng = llama_args.get("num_key_value_heads", nh)
    if nh % ng != 0:
        raise ValueError(
            f"num_attention_heads ({nh}) must be divisible by the number of query groups ({ng}).")

    q_rows = dim * nh // ng  # query rows belonging to one KV group
    encoder_state = megatron_ckpt['model']['language_model']['encoder']
    qkv_weight = encoder_state[f"layers.{layer_idx}.self_attention.query_key_value.weight"]

    # Copy weights from layer to hf layer
    qkv_grouped = qkv_weight.reshape((ng, q_rows + dim + dim, -1))
    hf_attn.q_proj.weight.data.copy_(
        qkv_grouped[:, :q_rows, :].reshape((-1, model_hidden)))
    hf_attn.k_proj.weight.data.copy_(
        qkv_grouped[:, q_rows:q_rows + dim, :].reshape((-1, model_hidden)))
    hf_attn.v_proj.weight.data.copy_(
        qkv_grouped[:, q_rows + dim:, :].reshape((-1, model_hidden)))

    dense_weight = encoder_state[f"layers.{layer_idx}.self_attention.dense.weight"]
    hf_attn.o_proj.weight.data.copy_(dense_weight)
def set_mlp_state_hf(args, hf_layer, megatron_ckpt, layer_idx, hidden_size):
    '''Split Megatron's fused ``dense_h_to_4h`` into HF gate/up projections and copy ``down_proj``.

    The first ``gate_proj.weight.shape[0]`` rows of ``dense_h_to_4h`` go to
    ``gate_proj`` and the remaining rows go to ``up_proj``. ``args`` and
    ``hidden_size`` are unused.
    '''
    mlp = hf_layer.mlp
    encoder_state = megatron_ckpt['model']['language_model']['encoder']
    fused_in = encoder_state[f"layers.{layer_idx}.mlp.dense_h_to_4h.weight"]
    proj_out = encoder_state[f"layers.{layer_idx}.mlp.dense_4h_to_h.weight"]

    split_at = mlp.gate_proj.weight.shape[0]
    mlp.gate_proj.weight.data.copy_(fused_in[:split_at])
    mlp.up_proj.weight.data.copy_(fused_in[split_at:])
    mlp.down_proj.weight.data.copy_(proj_out)
def set_layer_state_hf(args, hf_model, megatron_ckpt, layer_idx):
    '''Copy one decoder layer (attention, MLP and both norms) from Megatron into ``hf_model``.'''
    target_layer = hf_model.model.layers[layer_idx]
    model_hidden = hf_model.config.hidden_size

    set_attn_state_hf(args, target_layer, megatron_ckpt, layer_idx, model_hidden)
    set_mlp_state_hf(args, target_layer, megatron_ckpt, layer_idx, model_hidden)

    # Both norm tensors are looked up before either copy happens.
    encoder_state = megatron_ckpt['model']['language_model']['encoder']
    pre_attn_norm = encoder_state[f"layers.{layer_idx}.input_norm.weight"]
    post_attn_norm = encoder_state[f"layers.{layer_idx}.post_attention_norm.weight"]
    target_layer.input_layernorm.weight.data.copy_(pre_attn_norm)
    target_layer.post_attention_layernorm.weight.data.copy_(post_attn_norm)
def load_checkpoint_to_model(hf_model, megatron_ckpt, args=None):
    '''Copy every weight of a Megatron checkpoint into ``hf_model`` and return it.

    Copies embeddings, the final norm and output layer, and then each of
    ``hf_model.config.num_hidden_layers`` decoder layers (with a tqdm bar).

    Args:
        hf_model: HF Llama model whose parameters are overwritten in place.
        megatron_ckpt: loaded Megatron checkpoint dict.
        args: namespace forwarded to the per-layer helpers (needs at least
            ``hf_ckpt``). If omitted, the module-level ``args`` set by the
            ``__main__`` block is used. The function previously read that
            global silently and so failed when imported as a module.

    Returns:
        The same ``hf_model`` object.

    Raises:
        TypeError: if ``args`` is omitted and no module-level ``args`` exists.
    '''
    if args is None:
        args = globals().get("args")
        if args is None:
            raise TypeError(
                "load_checkpoint_to_model() needs an `args` namespace "
                "(pass it explicitly when not running this file as a script).")

    # Set model state.
    set_preprocess_state_hf(args, hf_model, megatron_ckpt)
    set_postprocess_state_hf(args, hf_model, megatron_ckpt)
    num_layers = hf_model.config.num_hidden_layers
    for layer_idx in tqdm(range(num_layers), "set layer states"):
        set_layer_state_hf(args, hf_model, megatron_ckpt, layer_idx)
    return hf_model
def _load_checkpoint(queue, args):
    '''Build Megatron args from an HF Llama config and stream weights to ``queue``.

    Steps:
      1. Make ``megatron`` importable, optionally via ``args.megatron_path``.
      2. Parse Megatron args, then override them from ``<args.load_dir>/config.json``
         via ``load_args_from_checkpoint``.
      3. Put a ``types.SimpleNamespace`` of metadata on ``queue``.
      4. Put dict messages named "embeddings", "transformer layer N",
         "final norm" and (if untied) "output layer".
      5. Put "done".
    On import or argument failure it puts "exit" and calls ``exit(1)``.

    NOTE(review): this loader looks inconsistent with the rest of this file.
    It calls ``load_checkpoint_to_model(margs)`` with one argument, while the
    helper of that name here takes ``(hf_model, megatron_ckpt)``. It then
    reads ``model.language_model...``, a Megatron-style attribute path. Both
    look likely to fail as written; TODO confirm before relying on this path.
    '''
    # Llama-2 requires HF transformers >=4.31.0.
    verify_transformers_version()

    # Search in directory above this.
    sys.path.append(os.path.abspath(
        os.path.join(os.path.dirname(__file__),
                     os.path.pardir,
                     os.path.pardir)))
    if args.megatron_path is not None:
        sys.path.insert(0, args.megatron_path)

    try:
        from megatron.arguments import parse_args, validate_args
        from megatron.global_vars import set_args, set_global_variables
        from megatron.model import module
        from megatron.core import mpu
        from megatron.core.enums import ModelType
        from megatron import fused_kernels
    except ModuleNotFoundError:
        print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
        queue.put("exit")
        exit(1)

    # We want all arguments to come from us.
    sys.argv = ['script.py',
                '--no-masked-softmax-fusion',
                '--no-bias-gelu-fusion',
                '--no-bias-dropout-fusion',
                '--no-async-tensor-model-parallel-allreduce',
                '--use-cpu-initialization',
                '--micro-batch-size', '1',
                '--no-load-optim',
                '--no-load-rng',
                '--no-save-optim',
                '--no-save-rng',
                '--no-initialization',
                '--load', args.load_dir
                ]

    margs = parse_args()
    margs.tokenizer_model = args.tokenizer_model
    # Overwrites architecture/training fields from config.json in margs.load.
    load_args_from_checkpoint(margs)

    # Arguments do sanity checks on the world size, but we don't care,
    # so trick it into thinking we are plenty of processes.
    margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size

    margs = validate_args(margs)

    def check_for_arg(arg_name, default=None):
        # Fill a missing (None) arg with `default`. Without a default,
        # signal the consumer with "exit" and terminate.
        if getattr(margs, arg_name, None) is None:
            if default is not None:
                setattr(margs, arg_name, default)
            else:
                print(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
                print(f"Arguments: {margs}")
                queue.put("exit")
                exit(1)

    check_for_arg('tensor_model_parallel_size')
    check_for_arg('pipeline_model_parallel_size')
    check_for_arg('num_layers')
    check_for_arg('hidden_size')
    check_for_arg('seq_length')
    check_for_arg('num_attention_heads')
    check_for_arg('max_position_embeddings')
    check_for_arg('position_embedding_type')
    check_for_arg('tokenizer_type')
    check_for_arg('iteration')
    check_for_arg('bert_binary_head')
    # NOTE(review): a False default counts as "not None", so these two are set to False.
    check_for_arg('disable_bias_linear', False)
    check_for_arg('params_dtype')
    check_for_arg('swiglu', False)

    # Determine how to make our models.
    assert args.model_type == 'GPT', 'Llama-2 is a GPT model.'
    margs.model_type = ModelType.encoder_or_decoder

    # Suppress warning about torch.distributed not being initialized.
    module.MegatronModule.embedding_warning_printed = True

    set_global_variables(margs, build_tokenizer=False)
    mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
    mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
    mpu.set_virtual_pipeline_model_parallel_world_size(margs.virtual_pipeline_model_parallel_size)
    fused_kernels.load(margs)

    # Short aliases.
    tp_size = margs.tensor_model_parallel_size
    pp_size = margs.pipeline_model_parallel_size
    vp_size = margs.virtual_pipeline_model_parallel_size
    if vp_size is None:
        vp_size = 1

    # Metadata.
    md = types.SimpleNamespace()
    md.model_type = args.model_type
    md.num_layers = margs.num_layers
    md.hidden_size = margs.hidden_size
    md.seq_length = margs.seq_length
    md.num_attention_heads = margs.num_attention_heads
    md.max_position_embeddings = margs.max_position_embeddings
    md.tokenizer_type = margs.tokenizer_type
    md.iteration = margs.iteration
    md.params_dtype = margs.params_dtype
    md.bert_binary_head = margs.bert_binary_head
    md.output_layer = margs.untie_embeddings_and_output_weights
    md.position_embedding_type = margs.position_embedding_type
    md.linear_bias = margs.add_bias_linear
    md.norm_has_bias = False
    md.swiglu = margs.swiglu
    md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
    md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
    md.true_vocab_size = None # skips padding in saver
    md.make_vocab_size_divisible_by = None
    md.checkpoint_args = margs
    md.consumed_train_samples = 0
    md.consumed_valid_samples = 0

    # Get first pipe stage.
    mpu.set_tensor_model_parallel_rank(0)
    mpu.set_pipeline_model_parallel_rank(0)
    # NOTE(review): one-argument call to a two-parameter helper; see docstring.
    model = load_checkpoint_to_model(margs)

    queue.put(md)

    def queue_put(name, msg):
        # Tag the message with its name, log it, and enqueue it.
        print(f"sending {name}")
        msg["name"] = name
        queue.put(msg)

    # Send embeddings.
    message = {
        "word embeddings": model.language_model.embedding.word_embeddings.weight.data
    }
    if md.position_embedding_type == 'learned_absolute':
        message["position embeddings"] = model.language_model.embedding.position_embeddings.weight.data
    else:
        assert not hasattr(model.language_model.embedding, 'position_embeddings')

    queue_put("embeddings", message)

    for layer_num in range(margs.num_layers):
        message = {}

        # Get non-parallel tensors from tp_rank 0.
        layer = model.language_model.encoder.layers[layer_num]
        message["input norm weight"] = layer.input_norm.weight.data
        message["post norm weight"] = layer.post_attention_norm.weight.data
        if md.linear_bias:
            message["dense bias"] = layer.self_attention.dense.bias.data
            message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data

        # Grab all parallel tensors for this layer.
        # NOTE(review): each list gets exactly one tensor below, yet the
        # swiglu loops index it by range(tp_size). tp_size > 1 would raise
        # IndexError, so this presumably assumes tp_size == 1.
        qkv_weight = []
        qkv_bias = []
        dense_weight = []
        mlp_l0_weight = []
        mlp_l0_bias = []
        mlp_l1_weight = []
        layer = model.language_model.encoder.layers[layer_num]
        qkv_weight.append(layer.self_attention.query_key_value.weight.data)
        dense_weight.append(layer.self_attention.dense.weight.data)
        mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
        mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)
        if md.linear_bias:
            qkv_bias.append(layer.self_attention.query_key_value.bias.data)
            mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)

        # Handle gated linear units.
        if md.swiglu:
            # Concat all the first halves ('W's) and all the second halves ('V's).
            for tp_rank in range(tp_size):
                mlp_l0_weight[tp_rank] = torch.chunk(mlp_l0_weight[tp_rank], 2, dim=0)
            message["mlp l0 weight W"] = torch.cat([w[0] for w in mlp_l0_weight], dim=0)
            message["mlp l0 weight V"] = torch.cat([w[1] for w in mlp_l0_weight], dim=0)
        else:
            message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)

        # Simple concat of the rest.
        message["qkv weight"] = torch.cat(qkv_weight, dim=0)
        message["dense weight"] = torch.cat(dense_weight, dim=1)
        message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
        if md.linear_bias:
            message["qkv bias"] = torch.cat(qkv_bias, dim=0)
            if md.swiglu:
                for tp_rank in range(tp_size):
                    mlp_l0_bias[tp_rank] = torch.chunk(mlp_l0_bias[tp_rank], 2, dim=0)
                message["mlp l0 bias W"] = torch.cat([b[0] for b in mlp_l0_bias],dim=0)
                message["mlp l0 bias V"] = torch.cat([b[1] for b in mlp_l0_bias],dim=0)
            else:
                message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)

        queue_put(f"transformer layer {layer_num}", message)

    # Send final norm from tp_rank 0.
    message = {
        "weight": model.language_model.encoder.final_norm.weight.data,
    }
    queue_put("final norm", message)

    # The output layer is sent only when embeddings and output weights are untied.
    if md.output_layer:
        message = {
            "weight": model.language_model.output_layer.weight.data
        }
        queue_put("output layer", message)

    queue.put("done")
def load_checkpoint(queue, args):
    '''Run ``_load_checkpoint``, making sure the consumer is told to stop on failure.

    Any exception, including ``SystemExit`` and ``KeyboardInterrupt``, puts
    "exit" on ``queue`` and is then re-raised unchanged.
    '''
    try:
        _load_checkpoint(queue, args)
    except BaseException:
        # Unblock the consumer before propagating the original error.
        queue.put("exit")
        raise
import megatron
from megatron.initialize import initialize_megatron
import argparse
if __name__ == '__main__':
    # CLI entry point: export a Megatron Llama checkpoint into HF format,
    # using an existing HF checkpoint as the model skeleton.
    parser = argparse.ArgumentParser(description='convert llama checkpoint')
    # The three paths have no usable default (None crashes from_pretrained,
    # torch.load and save_pretrained), so argparse now requires them.
    parser.add_argument('--hf_ckpt', type=str, default=None, required=True,
                        help='Path to the reference HF Llama checkpoint directory '
                             '(provides config.json and the model skeleton).')
    parser.add_argument('--megatron_ckpt', type=str, default=None, required=True,
                        help='Path to the Megatron checkpoint file to export (loaded with torch.load).')
    parser.add_argument('--save_ckpt', type=str, default=None, required=True,
                        help='Output directory for the converted HF checkpoint.')
    parser.add_argument('--num_query_groups', type=int, default=None,
                        help='Number of KV groups in the Megatron checkpoint '
                             '(used together with --group_query_attention).')
    parser.add_argument('--group_query_attention', action='store_true', default=False,
                        help='Treat the Megatron checkpoint as using grouped-query attention.')
    # Kept as a module-level name: load_checkpoint_to_model reads `args`.
    args = parser.parse_args()

    hf_model = LlamaForCausalLM.from_pretrained(args.hf_ckpt, device_map="cpu")
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only run this on trusted checkpoints.
    megatron_ckpt = torch.load(args.megatron_ckpt, map_location='cpu')
    print(megatron_ckpt['model']['language_model']['encoder'].keys())
    converted_model = load_checkpoint_to_model(hf_model, megatron_ckpt)
    print(converted_model)
    converted_model.save_pretrained(args.save_ckpt)