Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix erniegen no model_config_file #3321

Merged
merged 3 commits into from
Sep 26, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 4 additions & 32 deletions paddlenlp/transformers/ernie_gen/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from paddle.utils.download import get_path_from_url
from paddlenlp.utils.log import logger
from paddlenlp.transformers import BertPretrainedModel, ElectraPretrainedModel, RobertaPretrainedModel, ErniePretrainedModel
from .. import PretrainedModel, register_base_model

from ..utils import InitTrackerMeta, fn_args_to_dict

Expand Down Expand Up @@ -216,7 +217,7 @@ def forward(self, inputs, attn_bias=None, past_cache=None):


@six.add_metaclass(InitTrackerMeta)
class ErnieGenPretrainedModel(object):
class ErnieGenPretrainedModel(PretrainedModel):
r"""
An abstract class for pretrained ErnieGen models. It provides ErnieGen related
`model_config_file`, `pretrained_init_configuration`, `resource_files_names`,
Expand Down Expand Up @@ -389,36 +390,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
weight_path)
return model

def save_pretrained(self, save_directory):
"""
Save model configuration and related resources (model state) to files
under `save_directory`.
Args:
save_directory (str): Directory to save files into.
"""
assert os.path.isdir(
save_directory
), "Saving directory ({}) should be a directory".format(save_directory)
# save model config
model_config_file = os.path.join(save_directory, self.model_config_file)
model_config = self.init_config
# If init_config contains a Layer, use the layer's init_config to save
for key, value in model_config.items():
if key == "init_args":
args = []
for arg in value:
args.append(arg.init_config if isinstance(
arg, ErnieGenPretrainedModel) else arg)
model_config[key] = tuple(args)
elif isinstance(value, ErnieGenPretrainedModel):
model_config[key] = value.init_config
with io.open(model_config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(model_config, ensure_ascii=False))
# save model
file_name = os.path.join(save_directory,
list(self.resource_files_names.values())[0])
paddle.save(self.state_dict(), file_name)

def _post_init(self, original_init, *args, **kwargs):
"""
It would be hooked after `__init__` to add a dict including arguments of
Expand All @@ -428,7 +399,8 @@ def _post_init(self, original_init, *args, **kwargs):
self.config = init_dict


class ErnieModel(nn.Layer, ErnieGenPretrainedModel):
@register_base_model
class ErnieModel(ErnieGenPretrainedModel):

def __init__(self, cfg, name=None):
"""
Expand Down