
How to implement a fine-tuned model by myself? #137

Open
kaneyxx opened this issue Apr 19, 2020 · 1 comment

Comments


kaneyxx commented Apr 19, 2020

Hello, thanks for sharing this awesome project!
I have fine-tuned the supNMT pre-trained model and saved the checkpoint. Now I want to build a model outside the MASS directory and deploy it in a chatbot to use with my friends. I'm not familiar with the fairseq module. How should I do this? Any suggestions?

For example: I have the weights, but I don't know how to build the model and load them. I found the build_model function in xmasked_seq2seq.py, but I don't know what to do next.
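For reference, the usual fairseq pattern for loading a checkpoint manually (rather than through from_pretrained) looks roughly like the pseudocode sketch below. It is untested against MASS; load_checkpoint_to_cpu, tasks.setup_task, and task.build_model exist in fairseq 0.9.0, but the MASS-specific wiring (the data path and which task class gets resolved) is an assumption:

```
# Pseudocode sketch -- untested, argument names may differ for MASS:
from fairseq import checkpoint_utils, tasks

state = checkpoint_utils.load_checkpoint_to_cpu("checkpoint_best.pt")
args = state["args"]              # the args the model was trained with
args.data = "./test/processed"    # point at the binarized data / dicts
task = tasks.setup_task(args)     # should resolve to the xmasked_seq2seq task
model = task.build_model(args)    # calls build_model in xmasked_seq2seq.py
model.load_state_dict(state["model"], strict=True)
model.eval()
```

The point of this route is that the task is built from the args stored inside the checkpoint, so the dictionaries and architecture are guaranteed to match what the model was trained with.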


kaneyxx commented Apr 20, 2020

I used the following code to load the pre-trained model and got some errors:

en2zh = TransformerModel.from_pretrained(
    "./",
    checkpoint_file="checkpoint_best.pt",
    task="xmasked_seq2seq",
    arch="xtransformer",
    langs="en,zh",
    source_langs="en",
    target_langs="zh",
    mt_steps="en-zh",
    mass_steps="",
    memt_steps="",
    valid_lang_pairs="",
    no_scale_embedding=True,
    src_dict="./test/processed/dict.en.txt",
    tgt_dict="./test/processed/dict.zh.txt",
)

<in fairseq 0.7.1, same version as MASS>

NotImplementedError Traceback (most recent call last)
<ipython-input> in <module>
12 no_scale_embedding=True,
13 src_dict="./test/processed/dict.en.txt",
---> 14 tgt_dict="./test/processed/dict.zh.txt"
15 )

/opt/conda/lib/python3.7/site-packages/fairseq/models/fairseq_model.py in from_pretrained(cls, model_name_or_path, checkpoint_file, data_name_or_path, **kwargs)
199 print(args)
200
--> 201 return hub_utils.Generator(args, task, models)
202
203 @classmethod

/opt/conda/lib/python3.7/site-packages/fairseq/hub_utils.py in __init__(self, args, task, models)
21 self.task = task
22 self.models = models
---> 23 self.src_dict = task.source_dictionary
24 self.tgt_dict = task.target_dictionary
25 self.use_cuda = torch.cuda.is_available() and not getattr(args, 'cpu', False)

/opt/conda/lib/python3.7/site-packages/fairseq/tasks/fairseq_task.py in source_dictionary(self)
264 """Return the source :class:~fairseq.data.Dictionary (if applicable
265 for this task)."""
--> 266 raise NotImplementedError
267
268 @property

NotImplementedError:


<fairseq 0.9.0>

RuntimeError Traceback (most recent call last)
<ipython-input> in <module>
12 no_scale_embedding=True,
13 src_dict="./test/processed/dict.en.txt",
---> 14 tgt_dict="./test/processed/dict.zh.txt"
15 )

~/fairseq/fairseq/models/fairseq_model.py in from_pretrained(cls, model_name_or_path, checkpoint_file, data_name_or_path, **kwargs)
216 data_name_or_path,
217 archive_map=cls.hub_models(),
--> 218 **kwargs,
219 )
220 logger.info(x["args"])

~/fairseq/fairseq/hub_utils.py in from_pretrained(model_name_or_path, checkpoint_file, data_name_or_path, archive_map, **kwargs)
71 models, args, task = checkpoint_utils.load_model_ensemble_and_task(
72 [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(os.pathsep)],
---> 73 arg_overrides=kwargs,
74 )
75

~/fairseq/fairseq/checkpoint_utils.py in load_model_ensemble_and_task(filenames, arg_overrides, task, strict, suffix)
209 # build model for ensemble
210 model = task.build_model(args)
--> 211 model.load_state_dict(state["model"], strict=strict, args=args)
212 ensemble.append(model)
213 return ensemble, args, task

~/fairseq/fairseq/models/fairseq_model.py in load_state_dict(self, state_dict, strict, args)
91 self.upgrade_state_dict(state_dict)
92 new_state_dict = prune_state_dict(state_dict, args)
---> 93 return super().load_state_dict(new_state_dict, strict)
94
95 def upgrade_state_dict(self, state_dict):

~/miniconda3/envs/torch/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
828 if len(error_msgs) > 0:
829 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 830 self.__class__.__name__, "\n\t".join(error_msgs)))
831 return _IncompatibleKeys(missing_keys, unexpected_keys)
832

RuntimeError: Error(s) in loading state_dict for XTransformerModel:
Missing key(s) in state_dict: "decoders.zh.output_projection.weight".


Any help please?
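One possible workaround for the missing-key error, sketched below with plain dictionaries standing in for tensors: fairseq decoders typically tie the output projection to the decoder embedding matrix, so the absent `output_projection` entry can be filled from the tied `embed_tokens` weight before calling load_state_dict. The key names follow the traceback above; whether weight tying actually holds for this checkpoint is an assumption:

```python
def patch_missing_output_projection(state_dict, lang="zh"):
    """If the checkpoint predates fairseq's explicit output_projection
    module, fill the missing key from the (usually tied) decoder
    embedding weight. Works on any mapping of name -> tensor."""
    missing = "decoders.{}.output_projection.weight".format(lang)
    tied = "decoders.{}.embed_tokens.weight".format(lang)
    if missing not in state_dict and tied in state_dict:
        state_dict[missing] = state_dict[tied]
    return state_dict

# Toy demonstration with a nested list standing in for a weight tensor:
sd = {"decoders.zh.embed_tokens.weight": [[0.1, 0.2], [0.3, 0.4]]}
sd = patch_missing_output_projection(sd)
print("decoders.zh.output_projection.weight" in sd)  # True
```

With a real checkpoint, one would apply this to state["model"] after torch.load and before model.load_state_dict; alternatively, passing strict=False skips the check, but then the projection weights stay randomly initialized.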
