From a463e0e74f4908950ccdfe35198d13f38de6ae1b Mon Sep 17 00:00:00 2001 From: Pete Date: Sun, 2 May 2021 14:51:42 -0700 Subject: [PATCH] Add way of skipping pretrained weights download (#5172) * add way of skipping pretrained weights download * clarify docstring * add link to PR in CHANGELOG --- CHANGELOG.md | 4 ++ allennlp/common/cached_transformers.py | 36 ++++++++++- .../modules/seq2vec_encoders/bert_pooler.py | 16 ++++- .../pretrained_transformer_embedder.py | 13 ++++ ...trained_transformer_mismatched_embedder.py | 17 ++++++ .../modules/transformer/transformer_module.py | 32 ++++++---- .../modules/transformer/transformer_stack.py | 9 ++- tests/common/cached_transformers_test.py | 59 +++++++++++++++++-- 8 files changed, 164 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ed0d6e533f1..ae6b0fb6a3d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.sanity_checks.task_checklists` module. +- Added a way to avoid downloading and loading pretrained weights in modules that wrap transformers + such as the `PretrainedTransformerEmbedder` and `PretrainedTransformerMismatchedEmbedder`. + You can do this by setting the parameter `load_weights` to `False`. + See [PR #5172](https://github.com/allenai/allennlp/pull/5172) for more details. ## Unreleased diff --git a/allennlp/common/cached_transformers.py b/allennlp/common/cached_transformers.py index 37d88c1c41a..bc7cc4a6dfd 100644 --- a/allennlp/common/cached_transformers.py +++ b/allennlp/common/cached_transformers.py @@ -1,5 +1,7 @@ import logging +import warnings from typing import NamedTuple, Optional, Dict, Tuple + import transformers from transformers import AutoModel, AutoConfig @@ -21,6 +23,7 @@ def get( make_copy: bool, override_weights_file: Optional[str] = None, override_weights_strip_prefix: Optional[str] = None, + load_weights: bool = True, **kwargs, ) -> transformers.PreTrainedModel: """ @@ -34,18 +37,35 @@ def get( If this is `True`, return a copy of the model instead of the cached model itself. If you want to modify the parameters of the model, set this to `True`. If you want only part of the model, set this to `False`, but make sure to `copy.deepcopy()` the bits you are keeping. - override_weights_file : `str`, optional + override_weights_file : `str`, optional (default = `None`) If set, this specifies a file from which to load alternate weights that override the weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created with `torch.save()`. - override_weights_strip_prefix : `str`, optional + override_weights_strip_prefix : `str`, optional (default = `None`) If set, strip the given prefix from the state dict when loading it. + load_weights : `bool`, optional (default = `True`) + If set to `False`, no weights will be loaded. This is helpful when you only + want to initialize the architecture, like when you've already fine-tuned a model + and are going to load the weights from a state dict elsewhere. 
""" global _model_cache spec = TransformerSpec(model_name, override_weights_file, override_weights_strip_prefix) transformer = _model_cache.get(spec, None) if transformer is None: - if override_weights_file is not None: + if not load_weights: + if override_weights_file is not None: + warnings.warn( + "You specified an 'override_weights_file' in allennlp.common.cached_transformers.get(), " + "but 'load_weights' is set to False, so 'override_weights_file' will be ignored.", + UserWarning, + ) + transformer = AutoModel.from_config( + AutoConfig.from_pretrained( + model_name, + **kwargs, + ) + ) + elif override_weights_file is not None: from allennlp.common.file_utils import cached_path import torch @@ -121,3 +141,13 @@ def get_tokenizer(model_name: str, **kwargs) -> transformers.PreTrainedTokenizer ) _tokenizer_cache[cache_key] = tokenizer return tokenizer + + +def _clear_caches(): + """ + Clears in-memory transformer and tokenizer caches. + """ + global _model_cache + global _tokenizer_cache + _model_cache.clear() + _tokenizer_cache.clear() diff --git a/allennlp/modules/seq2vec_encoders/bert_pooler.py b/allennlp/modules/seq2vec_encoders/bert_pooler.py index 7509807f49f..727927a03cd 100644 --- a/allennlp/modules/seq2vec_encoders/bert_pooler.py +++ b/allennlp/modules/seq2vec_encoders/bert_pooler.py @@ -26,7 +26,15 @@ class BertPooler(Seq2VecEncoder): The pretrained BERT model to use. If this is a string, we will call `transformers.AutoModel.from_pretrained(pretrained_model)` and use that. - requires_grad : `bool`, optional, (default = `True`) + override_weights_file: `Optional[str]`, optional (default = `None`) + If set, this specifies a file from which to load alternate weights that override the + weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created + with `torch.save()`. + override_weights_strip_prefix: `Optional[str]`, optional (default = `None`) + If set, strip the given prefix from the state dict when loading it. + load_weights: `bool`, optional (default = `True`) + Whether to load the pretraiend weights. + requires_grad : `bool`, optional (default = `True`) If True, the weights of the pooler will be updated during training. Otherwise they will not. dropout : `float`, optional, (default = `0.0`) @@ -43,6 +51,7 @@ def __init__( *, override_weights_file: Optional[str] = None, override_weights_strip_prefix: Optional[str] = None, + load_weights: bool = True, requires_grad: bool = True, dropout: float = 0.0, transformer_kwargs: Optional[Dict[str, Any]] = None, @@ -54,8 +63,9 @@ def __init__( model = cached_transformers.get( pretrained_model, False, - override_weights_file, - override_weights_strip_prefix, + override_weights_file=override_weights_file, + override_weights_strip_prefix=override_weights_strip_prefix, + load_weights=load_weights, **(transformer_kwargs or {}), ) diff --git a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py index 1ce457b1adb..6ef220fbcca 100644 --- a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py +++ b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py @@ -49,6 +49,17 @@ class PretrainedTransformerEmbedder(TokenEmbedder): When `True` (the default), only the final layer of the pretrained transformer is taken for the embeddings. But if set to `False`, a scalar mix of all of the layers is used. 
+ override_weights_file: `Optional[str]`, optional (default = `None`) + If set, this specifies a file from which to load alternate weights that override the + weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created + with `torch.save()`. + override_weights_strip_prefix: `Optional[str]`, optional (default = `None`) + If set, strip the given prefix from the state dict when loading it. + load_weights: `bool`, optional (default = `True`) + Whether to load the pretrained weights. If you're loading your model/predictor from an AllenNLP archive + it usually makes sense to set this to `False` (via the `overrides` parameter) + to avoid unnecessarily caching and loading the original pretrained weights, + since the archive will already contain all of the weights needed. gradient_checkpointing: `bool`, optional (default = `None`) Enable or disable gradient checkpointing. tokenizer_kwargs: `Dict[str, Any]`, optional (default = `None`) @@ -74,6 +85,7 @@ def __init__( last_layer_only: bool = True, override_weights_file: Optional[str] = None, override_weights_strip_prefix: Optional[str] = None, + load_weights: bool = True, gradient_checkpointing: Optional[bool] = None, tokenizer_kwargs: Optional[Dict[str, Any]] = None, transformer_kwargs: Optional[Dict[str, Any]] = None, @@ -86,6 +98,7 @@ def __init__( True, override_weights_file=override_weights_file, override_weights_strip_prefix=override_weights_strip_prefix, + load_weights=load_weights, **(transformer_kwargs or {}), ) diff --git a/allennlp/modules/token_embedders/pretrained_transformer_mismatched_embedder.py b/allennlp/modules/token_embedders/pretrained_transformer_mismatched_embedder.py index b2501cde11c..9bad2aac9fc 100644 --- a/allennlp/modules/token_embedders/pretrained_transformer_mismatched_embedder.py +++ b/allennlp/modules/token_embedders/pretrained_transformer_mismatched_embedder.py @@ -32,6 +32,17 @@ class PretrainedTransformerMismatchedEmbedder(TokenEmbedder): When `True` (the default), only the final layer of the pretrained transformer is taken for the embeddings. But if set to `False`, a scalar mix of all of the layers is used. + override_weights_file: `Optional[str]`, optional (default = `None`) + If set, this specifies a file from which to load alternate weights that override the + weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created + with `torch.save()`. + override_weights_strip_prefix: `Optional[str]`, optional (default = `None`) + If set, strip the given prefix from the state dict when loading it. + load_weights: `bool`, optional (default = `True`) + Whether to load the pretrained weights. If you're loading your model/predictor from an AllenNLP archive + it usually makes sense to set this to `False` (via the `overrides` parameter) + to avoid unnecessarily caching and loading the original pretrained weights, + since the archive will already contain all of the weights needed. gradient_checkpointing: `bool`, optional (default = `None`) Enable or disable gradient checkpointing. 
tokenizer_kwargs: `Dict[str, Any]`, optional (default = `None`) @@ -56,6 +67,9 @@ def __init__( max_length: int = None, train_parameters: bool = True, last_layer_only: bool = True, + override_weights_file: Optional[str] = None, + override_weights_strip_prefix: Optional[str] = None, + load_weights: bool = True, gradient_checkpointing: Optional[bool] = None, tokenizer_kwargs: Optional[Dict[str, Any]] = None, transformer_kwargs: Optional[Dict[str, Any]] = None, @@ -68,6 +82,9 @@ def __init__( max_length=max_length, train_parameters=train_parameters, last_layer_only=last_layer_only, + override_weights_file=override_weights_file, + override_weights_strip_prefix=override_weights_strip_prefix, + load_weights=load_weights, gradient_checkpointing=gradient_checkpointing, tokenizer_kwargs=tokenizer_kwargs, transformer_kwargs=transformer_kwargs, diff --git a/allennlp/modules/transformer/transformer_module.py b/allennlp/modules/transformer/transformer_module.py index f03b5ad838b..861120deca2 100644 --- a/allennlp/modules/transformer/transformer_module.py +++ b/allennlp/modules/transformer/transformer_module.py @@ -147,21 +147,32 @@ def get_relevant_module( relevant_module: Optional[Union[str, List[str]]] = None, source: str = "huggingface", mapping: Optional[Dict[str, str]] = None, + load_weights: bool = True, ): """ Returns the relevant underlying module given a model name/object. - # Parameters: - - pretrained_module: Name of the transformer model containing the layer, - or the actual layer (not the model object). - relevant_module: Name of the desired module. Defaults to cls._relevant_module. - source: Where the model came from. Default - huggingface. - mapping: Optional mapping that determines any differences in the module names - between the class modules and the input model's modules. Default - cls._huggingface_mapping + # Parameters + + pretrained_module : `Union[str, torch.nn.Module]` + Name of the transformer model containing the layer, + or the actual layer (not the model object). + relevant_module : `Optional[Union[str, List[str]]]`, optional + Name of the desired module. Defaults to cls._relevant_module. + source : `str`, optional + Where the model came from. Default - huggingface. + mapping : `Dict[str, str]`, optional + Optional mapping that determines any differences in the module names + between the class modules and the input model's modules. + Default - cls._huggingface_mapping + load_weights : `bool`, optional + Whether or not to load the pretrained weights. + Default is `True`. 
""" if isinstance(pretrained_module, str): - pretrained_module = cached_transformers.get(pretrained_module, False) + pretrained_module = cached_transformers.get( + pretrained_module, False, load_weights=load_weights + ) relevant_module = relevant_module or cls._relevant_module @@ -192,6 +203,7 @@ def from_pretrained_module( pretrained_module: Union[str, torch.nn.Module], source: str = "huggingface", mapping: Optional[Dict[str, str]] = None, + load_weights: bool = True, **kwargs, ): """ @@ -208,7 +220,7 @@ def from_pretrained_module( ) pretrained_module = cls.get_relevant_module( - pretrained_module, source=source, mapping=mapping + pretrained_module, source=source, mapping=mapping, load_weights=load_weights ) final_kwargs = cls._get_input_arguments(pretrained_module, source, mapping) final_kwargs.update(kwargs) diff --git a/allennlp/modules/transformer/transformer_stack.py b/allennlp/modules/transformer/transformer_stack.py index edeefc27ba9..09fb1d2bc40 100644 --- a/allennlp/modules/transformer/transformer_stack.py +++ b/allennlp/modules/transformer/transformer_stack.py @@ -172,6 +172,7 @@ def from_pretrained_module( # type: ignore num_hidden_layers: Optional[Union[int, range]] = None, source="huggingface", mapping: Optional[Dict[str, str]] = None, + load_weights: bool = True, **kwargs, ): final_kwargs = {} @@ -185,4 +186,10 @@ def from_pretrained_module( # type: ignore else: final_kwargs["num_hidden_layers"] = num_hidden_layers - return super().from_pretrained_module(pretrained_module, source, mapping, **final_kwargs) + return super().from_pretrained_module( + pretrained_module, + source=source, + mapping=mapping, + load_weights=load_weights, + **final_kwargs, + ) diff --git a/tests/common/cached_transformers_test.py b/tests/common/cached_transformers_test.py index a09b4b4e84d..e0a8bee054b 100644 --- a/tests/common/cached_transformers_test.py +++ b/tests/common/cached_transformers_test.py @@ -10,6 +10,14 @@ class TestCachedTransformers(AllenNlpTestCase): + def setup_method(self): + super().setup_method() + cached_transformers._clear_caches() + + def teardown_method(self): + super().teardown_method() + cached_transformers._clear_caches() + def test_get_missing_from_cache_local_files_only(self): with pytest.raises((OSError, ValueError)): cached_transformers.get( @@ -19,17 +27,23 @@ def test_get_missing_from_cache_local_files_only(self): local_files_only=True, ) + def clear_test_dir(self): + for f in os.listdir(str(self.TEST_DIR)): + os.remove(str(self.TEST_DIR) + "/" + f) + assert len(os.listdir(str(self.TEST_DIR))) == 0 + def test_from_pretrained_avoids_weights_download_if_override_weights(self): + config = AutoConfig.from_pretrained("epwalsh/bert-xsmall-dummy", cache_dir=self.TEST_DIR) # only download config because downloading pretrained weights in addition takes too long transformer = AutoModel.from_config( AutoConfig.from_pretrained("epwalsh/bert-xsmall-dummy", cache_dir=self.TEST_DIR) ) + transformer = AutoModel.from_config(config) + # clear cache directory - for f in os.listdir(str(self.TEST_DIR)): - os.remove(str(self.TEST_DIR) + "/" + f) - assert len(os.listdir(str(self.TEST_DIR))) == 0 + self.clear_test_dir() - save_weights_path = str(self.TEST_DIR) + "/bert_weights.pth" + save_weights_path = str(self.TEST_DIR / "bert_weights.pth") torch.save(transformer.state_dict(), save_weights_path) override_transformer = cached_transformers.get( @@ -44,7 +58,7 @@ def test_from_pretrained_avoids_weights_download_if_override_weights(self): # so this assertion could fail in the future 
json_fnames = [fname for fname in os.listdir(str(self.TEST_DIR)) if fname.endswith(".json")] assert len(json_fnames) == 1 - json_data = json.load(open(str(self.TEST_DIR) + "/" + json_fnames[0])) + json_data = json.load(open(str(self.TEST_DIR / json_fnames[0]))) assert ( json_data["url"] == "https://huggingface.co/epwalsh/bert-xsmall-dummy/resolve/main/config.json" @@ -58,6 +72,41 @@ def test_from_pretrained_avoids_weights_download_if_override_weights(self): for p1, p2 in zip(transformer.parameters(), override_transformer.parameters()): assert p1.data.ne(p2.data).sum() == 0 + def test_from_pretrained_no_load_weights(self): + _ = cached_transformers.get( + "epwalsh/bert-xsmall-dummy", False, load_weights=False, cache_dir=self.TEST_DIR + ) + # check that only three files were downloaded (filename.json, filename, filename.lock), for config.json + # if more than three files were downloaded, then model weights were also (incorrectly) downloaded + # NOTE: downloaded files are not explicitly detailed in Huggingface's public API, + # so this assertion could fail in the future + json_fnames = [fname for fname in os.listdir(str(self.TEST_DIR)) if fname.endswith(".json")] + assert len(json_fnames) == 1 + json_data = json.load(open(str(self.TEST_DIR / json_fnames[0]))) + assert ( + json_data["url"] + == "https://huggingface.co/epwalsh/bert-xsmall-dummy/resolve/main/config.json" + ) + resource_id = os.path.splitext(json_fnames[0])[0] + assert set(os.listdir(str(self.TEST_DIR))) == set( + [json_fnames[0], resource_id, resource_id + ".lock"] + ) + + def test_from_pretrained_no_load_weights_local_config(self): + config = AutoConfig.from_pretrained("epwalsh/bert-xsmall-dummy", cache_dir=self.TEST_DIR) + self.clear_test_dir() + + # Save config to file. + local_config_path = str(self.TEST_DIR / "local_config.json") + config.to_json_file(local_config_path, use_diff=False) + + # Now load the model from the local config. + _ = cached_transformers.get( + local_config_path, False, load_weights=False, cache_dir=self.TEST_DIR + ) + # Make sure no other files were downloaded. + assert os.listdir(str(self.TEST_DIR)) == ["local_config.json"] + def test_get_tokenizer_missing_from_cache_local_files_only(self): with pytest.raises((OSError, ValueError)): cached_transformers.get_tokenizer(
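
Usage sketch (not part of the patch): a minimal example of how the new `load_weights` flag is meant to be used, assuming a build of AllenNLP that includes this change. The model name `epwalsh/bert-xsmall-dummy` is borrowed from the tests above; the archive path `model.tar.gz` and the override key path are hypothetical placeholders that depend on your own configuration.

```python
from allennlp.common import cached_transformers
from allennlp.models.archival import load_archive
from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

# Low-level API: build the architecture from the HuggingFace config alone.
# Only config.json is fetched; the parameters are left randomly initialized.
model = cached_transformers.get("epwalsh/bert-xsmall-dummy", make_copy=False, load_weights=False)

# Module-level API: the embedder forwards `load_weights` down to cached_transformers.get().
embedder = PretrainedTransformerEmbedder("epwalsh/bert-xsmall-dummy", load_weights=False)

# When restoring a trained model from an archive, the archive already contains the fine-tuned
# weights, so the original pretrained weights can be skipped via `overrides`. The key path
# below is hypothetical; adjust it to match your model's config.
archive = load_archive(
    "model.tar.gz",
    overrides='{"model.text_field_embedder.token_embedders.tokens.load_weights": false}',
)
```

Skipping the download only makes sense when the real weights arrive from somewhere else (an archive or a separate `state_dict`); otherwise leave `load_weights` at its default of `True`.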