Refactor RobertaModel base class (fixes #2186)
myleott committed May 27, 2020
1 parent 95294bf commit 307df56
Showing 1 changed file with 28 additions and 19 deletions.
47 changes: 28 additions & 19 deletions fairseq/models/roberta/model.py
@@ -14,8 +14,8 @@

 from fairseq import utils
 from fairseq.models import (
-    FairseqDecoder,
-    FairseqLanguageModel,
+    FairseqEncoder,
+    FairseqEncoderModel,
     register_model,
     register_model_architecture,
 )
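The import swap above pairs with the class-declaration change in the next hunk: the model moves from fairseq's language-model base to its encoder-model base. A minimal sketch of the contract difference, using simplified stand-in classes (the real ones live in fairseq/models/fairseq_model.py and carry more machinery): FairseqEncoderModel holds its network as self.encoder, while FairseqLanguageModel holds it as self.decoder.

    # Simplified stand-ins, not the real fairseq classes; they only
    # illustrate which attribute name each base class expects.
    import torch.nn as nn

    class EncoderModelSketch(nn.Module):
        def __init__(self, encoder):
            super().__init__()
            self.encoder = encoder  # FairseqEncoderModel convention

        def forward(self, src_tokens, **kwargs):
            return self.encoder(src_tokens, **kwargs)

    class LanguageModelSketch(nn.Module):
        def __init__(self, decoder):
            super().__init__()
            self.decoder = decoder  # FairseqLanguageModel convention

        def forward(self, src_tokens, **kwargs):
            return self.decoder(src_tokens, **kwargs)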
@@ -33,7 +33,7 @@


 @register_model('roberta')
-class RobertaModel(FairseqLanguageModel):
+class RobertaModel(FairseqEncoderModel):
 
     @classmethod
     def hub_models(cls):
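For orientation, @register_model('roberta') is what makes the class reachable from fairseq's CLI, and hub_models() lists the checkpoints exposed through torch.hub. A typical loading call against a local checkpoint (the path is a placeholder for a downloaded roberta.base):

    from fairseq.models.roberta import RobertaModel

    # Placeholder path; point it at an extracted roberta.base download.
    roberta = RobertaModel.from_pretrained('/path/to/roberta.base', checkpoint_file='model.pt')
    roberta.eval()  # disable dropout so extracted features are deterministic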
@@ -116,12 +116,20 @@ def forward(self, src_tokens, features_only=False, return_all_hiddens=False, classification_head_name=None, **kwargs):
         if classification_head_name is not None:
             features_only = True
 
-        x, extra = self.decoder(src_tokens, features_only, return_all_hiddens, **kwargs)
+        x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs)
 
         if classification_head_name is not None:
             x = self.classification_heads[classification_head_name](x)
         return x, extra
 
+    def get_normalized_probs(self, net_output, log_probs, sample=None):
+        """Get normalized probabilities (or log probs) from a net's output."""
+        logits = net_output[0].float()
+        if log_probs:
+            return F.log_softmax(logits, dim=-1)
+        else:
+            return F.softmax(logits, dim=-1)
+
     def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
         """Register a classification head."""
         if name in self.classification_heads:
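The new get_normalized_probs override is plain softmax bookkeeping over the first element of the network output. A self-contained sketch of what it computes, using toy logits rather than a real checkpoint:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 5)  # toy (batch, num_classes) scores
    net_output = (logits, {})   # mimics the (x, extra) tuple returned by forward()

    probs = F.softmax(net_output[0].float(), dim=-1)
    log_probs = F.log_softmax(net_output[0].float(), dim=-1)

    # Each row of probs sums to 1, and log_probs is its elementwise log.
    assert torch.allclose(probs.sum(dim=-1), torch.ones(2), atol=1e-6)
    assert torch.allclose(log_probs.exp(), probs, atol=1e-6)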
@@ -163,13 +171,23 @@ def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', **kwargs):
         return RobertaHubInterface(x['args'], x['task'], x['models'][0])
 
     def upgrade_state_dict_named(self, state_dict, name):
-        super().upgrade_state_dict_named(state_dict, name)
-
         prefix = name + '.' if name != '' else ''
-        current_head_names = [] if not hasattr(self, 'classification_heads') else \
-            self.classification_heads.keys()
 
+        # rename decoder -> encoder before upgrading children modules
+        for k in list(state_dict.keys()):
+            if k.startswith(prefix + 'decoder'):
+                new_k = prefix + 'encoder' + k[len(prefix + 'decoder'):]
+                state_dict[new_k] = state_dict[k]
+                del state_dict[k]
+
+        # upgrade children modules
+        super().upgrade_state_dict_named(state_dict, name)
+
+        # Handle new classification heads present in the state dict.
+        current_head_names = (
+            [] if not hasattr(self, 'classification_heads')
+            else self.classification_heads.keys()
+        )
         keys_to_delete = []
         for k in state_dict.keys():
             if not k.startswith(prefix + 'classification_heads.'):
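The rename pass is what keeps pre-refactor checkpoints loadable: every 'decoder.*' key is rewritten to 'encoder.*' before the children upgrade their own entries. A toy run of that loop, with illustrative key names and strings standing in for weight tensors:

    from collections import OrderedDict

    # Key names are illustrative; values stand in for tensors.
    state_dict = OrderedDict([
        ('decoder.sentence_encoder.embed_tokens.weight', 'w0'),
        ('decoder.lm_head.dense.weight', 'w1'),
    ])

    prefix = ''  # name == '' for a top-level model
    for k in list(state_dict.keys()):
        if k.startswith(prefix + 'decoder'):
            new_k = prefix + 'encoder' + k[len(prefix + 'decoder'):]
            state_dict[new_k] = state_dict[k]
            del state_dict[k]

    print(list(state_dict.keys()))
    # ['encoder.sentence_encoder.embed_tokens.weight', 'encoder.lm_head.dense.weight']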
@@ -261,24 +279,15 @@ def forward(self, features, **kwargs):
         return x
 
 
-class RobertaEncoder(FairseqDecoder):
-    """RoBERTa encoder.
-
-    Implements the :class:`~fairseq.models.FairseqDecoder` interface required
-    by :class:`~fairseq.models.FairseqLanguageModel`.
-    """
+class RobertaEncoder(FairseqEncoder):
+    """RoBERTa encoder."""
 
     def __init__(self, args, dictionary):
         super().__init__(dictionary)
         self.args = args
 
-        # RoBERTa is a sentence encoder model, so users will intuitively trim
-        # encoder layers. However, the implementation uses the fairseq decoder,
-        # so we fix here.
         if args.encoder_layers_to_keep:
             args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
-            args.decoder_layers_to_keep = args.encoder_layers_to_keep
-            args.encoder_layers_to_keep = None
 
         self.sentence_encoder = TransformerSentenceEncoder(
             padding_idx=dictionary.pad(),
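Since RobertaEncoder is now a real FairseqEncoder, --encoder-layers-to-keep applies directly and the old workaround of mirroring it onto the decoder_* args disappears. A sketch of the surviving trimming logic, with a stand-in args namespace:

    from argparse import Namespace

    # Stand-in for the parsed training args; values are illustrative.
    args = Namespace(encoder_layers_to_keep='0,5,11', encoder_layers=12)

    if args.encoder_layers_to_keep:
        args.encoder_layers = len(args.encoder_layers_to_keep.split(','))

    print(args.encoder_layers)  # 3, with no decoder_* mirroring required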
