This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Add way to initialize SrlBert without pretrained BERT weights #257

Merged: 3 commits, May 2, 2021
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
@@ -66,7 +66,7 @@ jobs:
 - uses: actions/cache@v2
   with:
     path: ${{ env.pythonLocation }}
-    key: ${{ runner.os }}-pydeps-${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev-requirements.txt') }}
+    key: ${{ runner.os }}-pydeps-${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev-requirements.txt') }}-v2
Member Author: The cache was corrupted for some reason.

(Appending `-v2` to the cache key makes GitHub Actions create a fresh cache instead of restoring the corrupted one.)

 - name: Install requirements
   run: |
@@ -192,7 +192,7 @@ jobs:
 - uses: actions/cache@v2
   with:
     path: ${{ env.pythonLocation }}
-    key: ${{ runner.os }}-pydeps-${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev-requirements.txt') }}
+    key: ${{ runner.os }}-pydeps-${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev-requirements.txt') }}-v2

 - name: Install requirements
   run: |
@@ -336,7 +336,7 @@ jobs:
 - uses: actions/cache@v2
   with:
     path: ${{ env.pythonLocation }}
-    key: ${{ runner.os }}-pydeps-${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev-requirements.txt') }}
+    key: ${{ runner.os }}-pydeps-${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev-requirements.txt') }}-v2

 - name: Install requirements
   run: |
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 
 - Added tests for checklist suites for SQuAD-style reading comprehension models (`bidaf`), and textual entailment models (`decomposable_attention` and `esim`).
+- Added a way to initialize the `SrlBert` model without caching/loading pretrained transformer weights.
+  You need to set the `bert_model` parameter to the dictionary form of the corresponding `BertConfig` from HuggingFace.
+  See [PR #257](https://github.com/allenai/allennlp-models/pull/257) for more details.


 ## [v2.4.0](https://github.com/allenai/allennlp-models/releases/tag/v2.4.0) - 2021-04-22
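The changelog entry above mentions passing the dictionary form of a `BertConfig` as the `bert_model` parameter. A minimal sketch of producing that dict with HuggingFace `transformers` (not part of this PR; the model name `bert-base-uncased` is just an illustrative choice):

```python
from transformers import BertConfig

# Fetch (or read from the local HuggingFace cache) only the config;
# no model weights are downloaded.
config = BertConfig.from_pretrained("bert-base-uncased")

# `to_dict()` yields the plain-dict form that can be supplied as the
# `bert_model` parameter in place of a model name, for example inside
# an AllenNLP experiment configuration.
bert_model_dict = config.to_dict()
```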
28 changes: 25 additions & 3 deletions allennlp_models/structured_prediction/models/srl_bert.py
@@ -1,9 +1,11 @@
+import warnings
 from typing import Dict, List, Any, Union
 
 from overrides import overrides
 import torch
 from torch.nn.modules import Linear, Dropout
 import torch.nn.functional as F
+from transformers.models.bert.configuration_bert import BertConfig
 from transformers.models.bert.modeling_bert import BertModel
 
 from allennlp.data import TextFieldTensors, Vocabulary
@@ -31,14 +33,26 @@ class SrlBert(Model):

     vocab : `Vocabulary`, required
         A Vocabulary, required in order to compute sizes for input/output projections.
-    model : `Union[str, BertModel]`, required.
-        A string describing the BERT model to load or an already constructed BertModel.
+    bert_model : `Union[str, Dict[str, Any], BertModel]`, required.
+        A string describing the BERT model to load, a BERT config in the form of a dictionary,
+        or an already constructed BertModel.
+
+        !!! Note
+            If you pass a config `bert_model` (a dictionary), pretrained weights will
+            not be cached and loaded! This is ideal if you're loading this model from an
+            AllenNLP archive since the weights you need will already be included in the
+            archive, but not what you want if you're training.
+
     initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
         Used to initialize the model parameters.
+
     label_smoothing : `float`, optional (default = `0.0`)
         Whether or not to use label smoothing on the labels when computing cross entropy loss.
+
     ignore_span_metric : `bool`, optional (default = `False`)
         Whether to calculate span loss, which is irrelevant when predicting BIO for Open Information Extraction.
+
     srl_eval_path : `str`, optional (default=`DEFAULT_SRL_EVAL_PATH`)
         The path to the srl-eval.pl script. By default, will use the srl-eval.pl included with allennlp,
         which is located at allennlp/tools/srl-eval.pl . If `None`, srl-eval.pl is not used.
@@ -47,7 +61,7 @@ class SrlBert(Model):
     def __init__(
         self,
         vocab: Vocabulary,
-        bert_model: Union[str, BertModel],
+        bert_model: Union[str, Dict[str, Any], BertModel],
         embedding_dropout: float = 0.0,
         initializer: InitializerApplicator = InitializerApplicator(),
         label_smoothing: float = None,
@@ -59,6 +73,14 @@ def __init__(

         if isinstance(bert_model, str):
             self.bert_model = BertModel.from_pretrained(bert_model)
+        elif isinstance(bert_model, dict):
+            warnings.warn(
+                "Initializing BertModel without pretrained weights. This is fine if you're loading "
+                "from an AllenNLP archive, but not if you're training.",
+                UserWarning,
+            )
+            bert_config = BertConfig.from_dict(bert_model)
+            self.bert_model = BertModel(bert_config)
         else:
             self.bert_model = bert_model

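Putting the new branch in context, here is a small usage sketch (written for this summary, not taken from the PR) showing the three forms `bert_model` now accepts; the toy `Vocabulary` setup is only there to make the snippet self-contained:

```python
from allennlp.data import Vocabulary
from transformers.models.bert.configuration_bert import BertConfig

from allennlp_models.structured_prediction.models.srl_bert import SrlBert

# A toy vocabulary; a real setup builds this from training data or restores
# it from an archive.
vocab = Vocabulary()
vocab.add_tokens_to_namespace(["O", "B-ARG0", "I-ARG0"], namespace="labels")

# 1) A model name: downloads and caches pretrained weights (existing behavior).
model_from_name = SrlBert(vocab, bert_model="bert-base-uncased")

# 2) A config dict: builds a randomly initialized BertModel from the config
#    alone and emits the UserWarning added in this PR; no weights are fetched.
model_from_config = SrlBert(vocab, bert_model=BertConfig().to_dict())

# 3) An already constructed BertModel instance is used as-is.
model_from_instance = SrlBert(vocab, bert_model=model_from_name.bert_model)
```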