
Use HF Transformers output types (#5035)
* Use HF Transformers output types

* Update changelog

* Remove a comment that is no longer relevant
JohnGiorgi authored Mar 4, 2021
1 parent 0c36019 commit 96415b2
Showing 2 changed files with 4 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 - `histogram_interval` parameter is now deprecated in `TensorboardWriter`, please use `distribution_interval` instead.
 - Memory usage is not logged in tensorboard during training now. `ConsoleLoggerCallback` should be used instead.
+- Use attributes of `ModelOutputs` object in `PretrainedTransformerEmbedder` instead of indexing.

 ### Added
@@ -200,17 +200,12 @@ def forward(

         transformer_output = self.transformer_model(**parameters)
         if self._scalar_mix is not None:
-            # As far as I can tell, the hidden states will always be the last element
-            # in the output tuple as long as the model is not also configured to return
-            # attention scores.
-            # See, for example, the return value description for BERT:
-            # https://huggingface.co/transformers/model_doc/bert.html#transformers.BertModel.forward
-            # These hidden states will also include the embedding layer, which we don't
+            # The hidden states will also include the embedding layer, which we don't
             # include in the scalar mix. Hence the `[1:]` slicing.
-            hidden_states = transformer_output[-1][1:]
+            hidden_states = transformer_output.hidden_states[1:]
             embeddings = self._scalar_mix(hidden_states)
         else:
-            embeddings = transformer_output[0]
+            embeddings = transformer_output.last_hidden_state

         if fold_long_sequences:
             embeddings = self._unfold_long_sequences(
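A minimal sketch of why the new code slices `hidden_states[1:]`, assuming the model was configured with `output_hidden_states=True` (otherwise `hidden_states` is `None`); the checkpoint name is again illustrative only:

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Without output_hidden_states=True, output.hidden_states would be None.
model = AutoModel.from_pretrained("bert-base-uncased", output_hidden_states=True)

inputs = tokenizer("Scalar mix example.", return_tensors="pt")
with torch.no_grad():
    output = model(**inputs)

# hidden_states is a tuple with num_layers + 1 entries: entry 0 is the
# embedding-layer output, which the scalar mix above excludes via `[1:]`.
print(len(output.hidden_states))         # 13 for bert-base: embeddings + 12 layers
layer_states = output.hidden_states[1:]  # one tensor per transformer layer
print(layer_states[0].shape)             # (batch_size, seq_len, hidden_size)

# For BERT-style encoders, the final entry matches last_hidden_state.
assert torch.equal(output.hidden_states[-1], output.last_hidden_state)
```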
