diff --git a/examples/pytorch/gptj/utils/huggingface_gptj_ckpt_convert.py b/examples/pytorch/gptj/utils/huggingface_gptj_ckpt_convert.py index 414e7c5b5..077a9bbd3 100644 --- a/examples/pytorch/gptj/utils/huggingface_gptj_ckpt_convert.py +++ b/examples/pytorch/gptj/utils/huggingface_gptj_ckpt_convert.py @@ -5,7 +5,7 @@ import torch import configparser -from transformers import PretrainedConfig +from transformers import PretrainedConfig, GPTJForCausalLM torch.set_printoptions(linewidth=130, sci_mode=False) np.set_printoptions(linewidth=130, suppress=True) @@ -111,9 +111,9 @@ def save(w, save_dir, n_inference_gpus, n_layers, layer_id): ) args = parser.parse_args() - ckpt_file = args.ckpt_dir + "/pytorch_model.bin" - checkpoint = torch.load(ckpt_file) - print(f"loading from {ckpt_file}") + checkpoint = GPTJForCausalLM.from_pretrained(args.ckpt_dir) + checkpoint = checkpoint.state_dict() + print(f"loading from {args.ckpt_dir}") out_path = args.output_dir output_dir = out_path + f"/{args.n_inference_gpus}-gpu/"