-
Notifications
You must be signed in to change notification settings - Fork 3
/
utils.py
executable file
·52 lines (46 loc) · 1.78 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
'''
code by TaeHwan Jung(@graykode)
Original Paper and repository here : https://github.com/openai/gpt-2
GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT
'''
import logging
# Module-level logger, named after this module (standard logging convention).
logger = logging.getLogger(__name__)
def load_weight(model, state_dict):
    """Load checkpoint weights from ``state_dict`` into ``model`` and re-tie embeddings.

    Key renaming: keys ending in ``.g`` or ``.w`` are renamed to end in
    ``.weight``, and keys ending in ``.b`` to end in ``.bias``. These are
    presumably TensorFlow-style names from the original GPT-2 checkpoint.

    Loading: every submodule is then loaded recursively through
    ``Module._load_from_state_dict``. If ``model`` has a ``transformer``
    attribute and no key starts with ``"transformer."``, loading starts at
    ``model.transformer``, so that un-prefixed checkpoint keys line up.

    Copying: the caller's ``state_dict`` is NOT modified. Renaming happens on
    a shallow copy, and any ``_metadata`` attribute is carried over to it.

    Reporting: missing keys, unexpected keys and load errors (e.g. size
    mismatches reported by PyTorch) are logged through the module logger
    rather than raised. A partial load still completes, but no longer
    silently.

    Args:
        model: module to load into. It must provide ``set_tied()``, which is
            called after loading.
        state_dict: mapping from state-dict key to tensor.

    Returns:
        ``model`` itself, after loading and re-tying.
    """
    # Copy BEFORE renaming so the caller's mapping is left untouched.
    # dict/OrderedDict.copy() drops instance attributes, so re-attach the
    # _metadata that is later handed to _load_from_state_dict.
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    suffix_map = {".g": ".weight", ".b": ".bias", ".w": ".weight"}
    # Snapshot the matching keys first, because the dict is mutated in the loop.
    # New keys end in ".weight"/".bias", so they never match a suffix again.
    for old_key in [k for k in state_dict if k.endswith(tuple(suffix_map))]:
        new_key = old_key[:-2] + suffix_map[old_key[-2:]]
        state_dict[new_key] = state_dict.pop(old_key)

    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    def load(module, prefix=""):
        # Load this module's own params/buffers with strict=True, so missing
        # and unexpected keys are collected into the lists above. Then recurse
        # into the child modules with a dotted prefix.
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
        )
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    # Keys without a "transformer." prefix are loaded into model.transformer
    # directly, so that they match its parameter names.
    start_model = model
    if hasattr(model, "transformer") and not any(s.startswith("transformer.") for s in state_dict):
        start_model = model.transformer
    load(start_model, prefix="")

    # Surface what was previously collected and then discarded.
    model_name = type(model).__name__
    if missing_keys:
        logger.info("Weights of %s not initialized from checkpoint: %s", model_name, missing_keys)
    if unexpected_keys:
        logger.info("Checkpoint weights not used by %s: %s", model_name, unexpected_keys)
    if error_msgs:
        logger.warning(
            "Error(s) in loading state_dict for %s:\n\t%s", model_name, "\n\t".join(error_msgs)
        )

    # Make sure we are still sharing the output and input embeddings after loading weights
    model.set_tied()
    return model