This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Add way of skipping pretrained weights download (#5172)
* add way of skipping pretrained weights download

* clarify docstring

* add link to PR in CHANGELOG
epwalsh authored and dirkgr committed May 10, 2021
1 parent 42e8202 commit b234edd
Showing 8 changed files with 164 additions and 22 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.sanity_checks.task_checklists` module.
- Added a way to avoid downloading and loading pretrained weights in modules that wrap transformers
such as the `PretrainedTransformerEmbedder` and `PretrainedTransformerMismatchedEmbedder`.
You can do this by setting the parameter `load_weights` to `False`.
  See [PR #5172](https://github.com/allenai/allennlp/pull/5172) for more details.
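
  For illustration, a minimal sketch of the new parameter used directly from Python (the model name below is just an example; any HuggingFace identifier works the same way):

  ```python
  from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

  # Build the embedder's architecture only; pretrained weights are neither
  # downloaded nor loaded. Useful when the real weights come from elsewhere,
  # e.g. a fine-tuned AllenNLP model archive.
  embedder = PretrainedTransformerEmbedder(
      "bert-base-uncased",
      load_weights=False,
  )
  ```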


## Unreleased
36 changes: 33 additions & 3 deletions allennlp/common/cached_transformers.py
@@ -1,5 +1,7 @@
import logging
import warnings
from typing import NamedTuple, Optional, Dict, Tuple

import transformers
from transformers import AutoModel, AutoConfig

@@ -21,6 +23,7 @@ def get(
make_copy: bool,
override_weights_file: Optional[str] = None,
override_weights_strip_prefix: Optional[str] = None,
load_weights: bool = True,
**kwargs,
) -> transformers.PreTrainedModel:
"""
@@ -34,18 +37,35 @@
If this is `True`, return a copy of the model instead of the cached model itself. If you want to modify the
parameters of the model, set this to `True`. If you want only part of the model, set this to `False`, but
make sure to `copy.deepcopy()` the bits you are keeping.
override_weights_file : `str`, optional
override_weights_file : `str`, optional (default = `None`)
If set, this specifies a file from which to load alternate weights that override the
weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created
with `torch.save()`.
override_weights_strip_prefix : `str`, optional
override_weights_strip_prefix : `str`, optional (default = `None`)
If set, strip the given prefix from the state dict when loading it.
load_weights : `bool`, optional (default = `True`)
If set to `False`, no weights will be loaded. This is helpful when you only
want to initialize the architecture, like when you've already fine-tuned a model
and are going to load the weights from a state dict elsewhere.
"""
global _model_cache
spec = TransformerSpec(model_name, override_weights_file, override_weights_strip_prefix)
transformer = _model_cache.get(spec, None)
if transformer is None:
if override_weights_file is not None:
if not load_weights:
if override_weights_file is not None:
warnings.warn(
"You specified an 'override_weights_file' in allennlp.common.cached_transformers.get(), "
"but 'load_weights' is set to False, so 'override_weights_file' will be ignored.",
UserWarning,
)
transformer = AutoModel.from_config(
AutoConfig.from_pretrained(
model_name,
**kwargs,
)
)
elif override_weights_file is not None:
from allennlp.common.file_utils import cached_path
import torch

@@ -121,3 +141,13 @@ def get_tokenizer(model_name: str, **kwargs) -> transformers.PreTrainedTokenizer
)
_tokenizer_cache[cache_key] = tokenizer
return tokenizer


def _clear_caches():
"""
Clears in-memory transformer and tokenizer caches.
"""
global _model_cache
global _tokenizer_cache
_model_cache.clear()
_tokenizer_cache.clear()
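
A hedged usage sketch of the new code path in `cached_transformers.get()`, mirroring the tests added later in this commit:

```python
from allennlp.common import cached_transformers

# Only the HuggingFace config is fetched; the model is built with
# AutoModel.from_config(), so no pretrained weights are downloaded or loaded.
# Passing override_weights_file together with load_weights=False would emit
# the UserWarning added above and the override file would be ignored.
model = cached_transformers.get(
    "epwalsh/bert-xsmall-dummy",  # tiny model used by this PR's tests
    False,                        # make_copy
    load_weights=False,
)

# The new private helper empties the in-memory model/tokenizer caches;
# the tests below call it in setup/teardown for isolation.
cached_transformers._clear_caches()
```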
16 changes: 13 additions & 3 deletions allennlp/modules/seq2vec_encoders/bert_pooler.py
@@ -26,7 +26,15 @@ class BertPooler(Seq2VecEncoder):
The pretrained BERT model to use. If this is a string,
we will call `transformers.AutoModel.from_pretrained(pretrained_model)`
and use that.
requires_grad : `bool`, optional, (default = `True`)
override_weights_file: `Optional[str]`, optional (default = `None`)
If set, this specifies a file from which to load alternate weights that override the
weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created
with `torch.save()`.
override_weights_strip_prefix: `Optional[str]`, optional (default = `None`)
If set, strip the given prefix from the state dict when loading it.
load_weights: `bool`, optional (default = `True`)
        Whether to load the pretrained weights.
requires_grad : `bool`, optional (default = `True`)
If True, the weights of the pooler will be updated during training.
Otherwise they will not.
dropout : `float`, optional, (default = `0.0`)
@@ -43,6 +51,7 @@ def __init__(
*,
override_weights_file: Optional[str] = None,
override_weights_strip_prefix: Optional[str] = None,
load_weights: bool = True,
requires_grad: bool = True,
dropout: float = 0.0,
transformer_kwargs: Optional[Dict[str, Any]] = None,
@@ -54,8 +63,9 @@
model = cached_transformers.get(
pretrained_model,
False,
override_weights_file,
override_weights_strip_prefix,
override_weights_file=override_weights_file,
override_weights_strip_prefix=override_weights_strip_prefix,
load_weights=load_weights,
**(transformer_kwargs or {}),
)

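A minimal sketch of the corresponding `BertPooler` usage (the model name is illustrative):

```python
from allennlp.modules.seq2vec_encoders import BertPooler

# Construct the pooler architecture without fetching pretrained weights,
# e.g. when the weights will later be restored from a trained archive.
pooler = BertPooler(
    "bert-base-uncased",
    load_weights=False,
    requires_grad=True,
    dropout=0.1,
)
```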
@@ -49,6 +49,17 @@ class PretrainedTransformerEmbedder(TokenEmbedder):
When `True` (the default), only the final layer of the pretrained transformer is taken
for the embeddings. But if set to `False`, a scalar mix of all of the layers
is used.
override_weights_file: `Optional[str]`, optional (default = `None`)
If set, this specifies a file from which to load alternate weights that override the
weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created
with `torch.save()`.
override_weights_strip_prefix: `Optional[str]`, optional (default = `None`)
If set, strip the given prefix from the state dict when loading it.
load_weights: `bool`, optional (default = `True`)
Whether to load the pretrained weights. If you're loading your model/predictor from an AllenNLP archive
it usually makes sense to set this to `False` (via the `overrides` parameter)
to avoid unnecessarily caching and loading the original pretrained weights,
since the archive will already contain all of the weights needed.
gradient_checkpointing: `bool`, optional (default = `None`)
Enable or disable gradient checkpointing.
tokenizer_kwargs: `Dict[str, Any]`, optional (default = `None`)
@@ -74,6 +85,7 @@ def __init__(
last_layer_only: bool = True,
override_weights_file: Optional[str] = None,
override_weights_strip_prefix: Optional[str] = None,
load_weights: bool = True,
gradient_checkpointing: Optional[bool] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
transformer_kwargs: Optional[Dict[str, Any]] = None,
@@ -86,6 +98,7 @@
True,
override_weights_file=override_weights_file,
override_weights_strip_prefix=override_weights_strip_prefix,
load_weights=load_weights,
**(transformer_kwargs or {}),
)

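The archive workflow mentioned in the docstring could look roughly like the sketch below; the dotted overrides key is hypothetical and must match how the embedder is nested in your own configuration:

```python
from allennlp.models.archival import load_archive

# Skip re-fetching the original pretrained weights when restoring a model
# whose archive already contains the fine-tuned parameters. The dotted key
# below is an assumption about a typical config layout -- adjust it to yours.
archive = load_archive(
    "model.tar.gz",
    overrides='{"model.text_field_embedder.token_embedders.tokens.load_weights": false}',
)
model = archive.model
```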
@@ -32,6 +32,17 @@ class PretrainedTransformerMismatchedEmbedder(TokenEmbedder):
When `True` (the default), only the final layer of the pretrained transformer is taken
for the embeddings. But if set to `False`, a scalar mix of all of the layers
is used.
override_weights_file: `Optional[str]`, optional (default = `None`)
If set, this specifies a file from which to load alternate weights that override the
weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created
with `torch.save()`.
override_weights_strip_prefix: `Optional[str]`, optional (default = `None`)
If set, strip the given prefix from the state dict when loading it.
load_weights: `bool`, optional (default = `True`)
Whether to load the pretrained weights. If you're loading your model/predictor from an AllenNLP archive
it usually makes sense to set this to `False` (via the `overrides` parameter)
to avoid unnecessarily caching and loading the original pretrained weights,
since the archive will already contain all of the weights needed.
gradient_checkpointing: `bool`, optional (default = `None`)
Enable or disable gradient checkpointing.
tokenizer_kwargs: `Dict[str, Any]`, optional (default = `None`)
@@ -56,6 +67,9 @@ def __init__(
max_length: int = None,
train_parameters: bool = True,
last_layer_only: bool = True,
override_weights_file: Optional[str] = None,
override_weights_strip_prefix: Optional[str] = None,
load_weights: bool = True,
gradient_checkpointing: Optional[bool] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
transformer_kwargs: Optional[Dict[str, Any]] = None,
@@ -68,6 +82,9 @@
max_length=max_length,
train_parameters=train_parameters,
last_layer_only=last_layer_only,
override_weights_file=override_weights_file,
override_weights_strip_prefix=override_weights_strip_prefix,
load_weights=load_weights,
gradient_checkpointing=gradient_checkpointing,
tokenizer_kwargs=tokenizer_kwargs,
transformer_kwargs=transformer_kwargs,
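And a hedged sketch for the mismatched variant, which simply forwards the new arguments to the `PretrainedTransformerEmbedder` it wraps:

```python
from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder

# override_weights_file, override_weights_strip_prefix, and load_weights are
# passed straight through to the underlying PretrainedTransformerEmbedder.
embedder = PretrainedTransformerMismatchedEmbedder(
    "bert-base-uncased",
    max_length=512,
    load_weights=False,
)
```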
32 changes: 22 additions & 10 deletions allennlp/modules/transformer/transformer_module.py
@@ -147,21 +147,32 @@ def get_relevant_module(
relevant_module: Optional[Union[str, List[str]]] = None,
source: str = "huggingface",
mapping: Optional[Dict[str, str]] = None,
load_weights: bool = True,
):
"""
Returns the relevant underlying module given a model name/object.
# Parameters:
pretrained_module: Name of the transformer model containing the layer,
or the actual layer (not the model object).
relevant_module: Name of the desired module. Defaults to cls._relevant_module.
source: Where the model came from. Default - huggingface.
mapping: Optional mapping that determines any differences in the module names
between the class modules and the input model's modules. Default - cls._huggingface_mapping
# Parameters
pretrained_module : `Union[str, torch.nn.Module]`
Name of the transformer model containing the layer,
or the actual layer (not the model object).
relevant_module : `Optional[Union[str, List[str]]]`, optional
Name of the desired module. Defaults to cls._relevant_module.
source : `str`, optional
Where the model came from. Default - huggingface.
mapping : `Dict[str, str]`, optional
Optional mapping that determines any differences in the module names
between the class modules and the input model's modules.
Default - cls._huggingface_mapping
load_weights : `bool`, optional
Whether or not to load the pretrained weights.
Default is `True`.
"""
if isinstance(pretrained_module, str):
pretrained_module = cached_transformers.get(pretrained_module, False)
pretrained_module = cached_transformers.get(
pretrained_module, False, load_weights=load_weights
)

relevant_module = relevant_module or cls._relevant_module

@@ -192,6 +203,7 @@ def from_pretrained_module(
pretrained_module: Union[str, torch.nn.Module],
source: str = "huggingface",
mapping: Optional[Dict[str, str]] = None,
load_weights: bool = True,
**kwargs,
):
"""
@@ -208,7 +220,7 @@
)

pretrained_module = cls.get_relevant_module(
pretrained_module, source=source, mapping=mapping
pretrained_module, source=source, mapping=mapping, load_weights=load_weights
)
final_kwargs = cls._get_input_arguments(pretrained_module, source, mapping)
final_kwargs.update(kwargs)
9 changes: 8 additions & 1 deletion allennlp/modules/transformer/transformer_stack.py
@@ -172,6 +172,7 @@ def from_pretrained_module( # type: ignore
num_hidden_layers: Optional[Union[int, range]] = None,
source="huggingface",
mapping: Optional[Dict[str, str]] = None,
load_weights: bool = True,
**kwargs,
):
final_kwargs = {}
@@ -185,4 +186,10 @@
else:
final_kwargs["num_hidden_layers"] = num_hidden_layers

return super().from_pretrained_module(pretrained_module, source, mapping, **final_kwargs)
return super().from_pretrained_module(
pretrained_module,
source=source,
mapping=mapping,
load_weights=load_weights,
**final_kwargs,
)
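
A hedged sketch exercising both the `transformer_module.py` change and this `TransformerStack` override (the model name and layer count are illustrative):

```python
from allennlp.modules.transformer import TransformerStack

# Builds a 4-layer encoder stack from the pretrained model's architecture;
# load_weights=False is forwarded through get_relevant_module() to
# cached_transformers.get(), so no pretrained weights are fetched.
stack = TransformerStack.from_pretrained_module(
    "bert-base-uncased",
    num_hidden_layers=4,
    load_weights=False,
)
```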
59 changes: 54 additions & 5 deletions tests/common/cached_transformers_test.py
@@ -10,6 +10,14 @@


class TestCachedTransformers(AllenNlpTestCase):
def setup_method(self):
super().setup_method()
cached_transformers._clear_caches()

def teardown_method(self):
super().teardown_method()
cached_transformers._clear_caches()

def test_get_missing_from_cache_local_files_only(self):
with pytest.raises((OSError, ValueError)):
cached_transformers.get(
@@ -19,17 +27,23 @@ def test_get_missing_from_cache_local_files_only(self):
local_files_only=True,
)

def clear_test_dir(self):
for f in os.listdir(str(self.TEST_DIR)):
os.remove(str(self.TEST_DIR) + "/" + f)
assert len(os.listdir(str(self.TEST_DIR))) == 0

def test_from_pretrained_avoids_weights_download_if_override_weights(self):
config = AutoConfig.from_pretrained("epwalsh/bert-xsmall-dummy", cache_dir=self.TEST_DIR)
# only download config because downloading pretrained weights in addition takes too long
transformer = AutoModel.from_config(
AutoConfig.from_pretrained("epwalsh/bert-xsmall-dummy", cache_dir=self.TEST_DIR)
)
transformer = AutoModel.from_config(config)

# clear cache directory
for f in os.listdir(str(self.TEST_DIR)):
os.remove(str(self.TEST_DIR) + "/" + f)
assert len(os.listdir(str(self.TEST_DIR))) == 0
self.clear_test_dir()

save_weights_path = str(self.TEST_DIR) + "/bert_weights.pth"
save_weights_path = str(self.TEST_DIR / "bert_weights.pth")
torch.save(transformer.state_dict(), save_weights_path)

override_transformer = cached_transformers.get(
@@ -44,7 +58,7 @@ def test_from_pretrained_avoids_weights_download_if_override_weights(self):
# so this assertion could fail in the future
json_fnames = [fname for fname in os.listdir(str(self.TEST_DIR)) if fname.endswith(".json")]
assert len(json_fnames) == 1
json_data = json.load(open(str(self.TEST_DIR) + "/" + json_fnames[0]))
json_data = json.load(open(str(self.TEST_DIR / json_fnames[0])))
assert (
json_data["url"]
== "https://huggingface.co/epwalsh/bert-xsmall-dummy/resolve/main/config.json"
@@ -58,6 +72,41 @@ def test_from_pretrained_avoids_weights_download_if_override_weights(self):
for p1, p2 in zip(transformer.parameters(), override_transformer.parameters()):
assert p1.data.ne(p2.data).sum() == 0

def test_from_pretrained_no_load_weights(self):
_ = cached_transformers.get(
"epwalsh/bert-xsmall-dummy", False, load_weights=False, cache_dir=self.TEST_DIR
)
# check that only three files were downloaded (filename.json, filename, filename.lock), for config.json
# if more than three files were downloaded, then model weights were also (incorrectly) downloaded
# NOTE: downloaded files are not explicitly detailed in Huggingface's public API,
# so this assertion could fail in the future
json_fnames = [fname for fname in os.listdir(str(self.TEST_DIR)) if fname.endswith(".json")]
assert len(json_fnames) == 1
json_data = json.load(open(str(self.TEST_DIR / json_fnames[0])))
assert (
json_data["url"]
== "https://huggingface.co/epwalsh/bert-xsmall-dummy/resolve/main/config.json"
)
resource_id = os.path.splitext(json_fnames[0])[0]
assert set(os.listdir(str(self.TEST_DIR))) == set(
[json_fnames[0], resource_id, resource_id + ".lock"]
)

def test_from_pretrained_no_load_weights_local_config(self):
config = AutoConfig.from_pretrained("epwalsh/bert-xsmall-dummy", cache_dir=self.TEST_DIR)
self.clear_test_dir()

# Save config to file.
local_config_path = str(self.TEST_DIR / "local_config.json")
config.to_json_file(local_config_path, use_diff=False)

# Now load the model from the local config.
_ = cached_transformers.get(
local_config_path, False, load_weights=False, cache_dir=self.TEST_DIR
)
# Make sure no other files were downloaded.
assert os.listdir(str(self.TEST_DIR)) == ["local_config.json"]

def test_get_tokenizer_missing_from_cache_local_files_only(self):
with pytest.raises((OSError, ValueError)):
cached_transformers.get_tokenizer(

