[Flax] Add remat (gradient checkpointing) (huggingface#17843)
* [Flax] Add remat (gradient checkpointing)

* fix variable naming in test

* flip: checkpoint using a method

* fix naming

* fix class naming

* apply PVP's suggestions from code review

* make fix-copies

* fix big-bird, electra, roberta

* cookie-cutter

* fix flax big-bird

* move test to common
sanchit-gandhi authored and viclzhu committed Jul 18, 2022
1 parent 14af6d4 commit e6ae1f2
Showing 7 changed files with 414 additions and 96 deletions.
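For orientation, a minimal usage sketch of the new API (assuming a working Flax install and the bert-base-uncased checkpoint, neither of which is part of this diff):

import jax.numpy as jnp

from transformers import FlaxBertForMaskedLM

model = FlaxBertForMaskedLM.from_pretrained("bert-base-uncased")
model.enable_gradient_checkpointing()  # rebuilds the module with gradient_checkpointing=True

# Under jax.grad, each FlaxBertLayer is now rematerialized during the backward
# pass instead of keeping its intermediate activations, trading compute for memory.
input_ids = jnp.ones((1, 8), dtype="i4")
outputs = model(input_ids)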
3 changes: 3 additions & 0 deletions src/transformers/modeling_flax_utils.py
@@ -235,6 +235,9 @@ def __init__(
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
raise NotImplementedError(f"init method has to be implemented for {self}")

def enable_gradient_checkpointing(self):
raise NotImplementedError(f"gradient checkpointing method has to be implemented for {self}")

@classmethod
def _from_config(cls, config, **kwargs):
"""
118 changes: 98 additions & 20 deletions src/transformers/models/bert/modeling_flax_bert.py
@@ -23,6 +23,7 @@
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
@@ -56,6 +57,8 @@
_CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer"

remat = nn_partitioning.remat


@flax.struct.dataclass
class FlaxBertForPreTrainingOutput(ModelOutput):
@@ -544,11 +547,19 @@ def __call__(
class FlaxBertLayerCollection(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False

def setup(self):
self.layers = [
FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
]
if self.gradient_checkpointing:
FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7))
self.layers = [
FlaxBertCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_layers)
]
else:
self.layers = [
FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
]

def __call__(
self,
@@ -582,12 +593,12 @@ def __call__(
layer_outputs = layer(
hidden_states,
attention_mask,
layer_head_mask=head_mask[i] if head_mask is not None else None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
head_mask[i] if head_mask is not None else None,
encoder_hidden_states,
encoder_attention_mask,
init_cache,
deterministic,
output_attentions,
)

hidden_states = layer_outputs[0]
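A note on the call-site change above: flax.linen's remat identifies static (non-traced) arguments purely by position via static_argnums, so the checkpointed layer has to be invoked with positional rather than keyword arguments. Counting FlaxBertLayer.__call__'s parameters, indices 5, 6 and 7 appear to correspond to the boolean flags init_cache, deterministic and output_attentions; this echoes the FlaxBertLayerCollection.setup hunk earlier in this file:

# Positions in FlaxBertLayer.__call__ (after self):
#   0 hidden_states, 1 attention_mask, 2 layer_head_mask,
#   3 encoder_hidden_states, 4 encoder_attention_mask,
#   5 init_cache, 6 deterministic, 7 output_attentions  <- marked static
FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7))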
@@ -617,9 +628,14 @@ def __call__(
class FlaxBertEncoder(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False

def setup(self):
self.layer = FlaxBertLayerCollection(self.config, dtype=self.dtype)
self.layer = FlaxBertLayerCollection(
self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)

def __call__(
self,
@@ -756,11 +772,24 @@ def __init__(
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
gradient_checkpointing: bool = False,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
module = self.module_class(
config=config,
dtype=dtype,
gradient_checkpointing=gradient_checkpointing,
**kwargs,
)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

def enable_gradient_checkpointing(self):
self._module = self.module_class(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=True,
)

def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
@@ -925,10 +954,15 @@ class FlaxBertModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
add_pooling_layer: bool = True
gradient_checkpointing: bool = False

def setup(self):
self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype)
self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype)
self.encoder = FlaxBertEncoder(
self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)

def __call__(
@@ -1003,9 +1037,14 @@ class FlaxBertModel(FlaxBertPreTrainedModel):
class FlaxBertForPreTrainingModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
self.bert = FlaxBertModule(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype)

def __call__(
@@ -1099,9 +1138,15 @@ class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
class FlaxBertForMaskedLMModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self):
self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
self.bert = FlaxBertModule(
config=self.config,
add_pooling_layer=False,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)

def __call__(
@@ -1161,9 +1206,14 @@ class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
class FlaxBertForNextSentencePredictionModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
self.bert = FlaxBertModule(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype)

def __call__(
@@ -1248,9 +1298,14 @@ class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
class FlaxBertForSequenceClassificationModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
self.bert = FlaxBertModule(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
classifier_dropout = (
self.config.classifier_dropout
if self.config.classifier_dropout is not None
@@ -1324,9 +1379,14 @@ class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
class FlaxBertForMultipleChoiceModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
self.bert = FlaxBertModule(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(1, dtype=self.dtype)

@@ -1399,9 +1459,15 @@ class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
class FlaxBertForTokenClassificationModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
self.bert = FlaxBertModule(
config=self.config,
dtype=self.dtype,
add_pooling_layer=False,
gradient_checkpointing=self.gradient_checkpointing,
)
classifier_dropout = (
self.config.classifier_dropout
if self.config.classifier_dropout is not None
@@ -1468,9 +1534,15 @@ class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
class FlaxBertForQuestionAnsweringModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
self.bert = FlaxBertModule(
config=self.config,
dtype=self.dtype,
add_pooling_layer=False,
gradient_checkpointing=self.gradient_checkpointing,
)
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

def __call__(
@@ -1539,9 +1611,15 @@ class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
class FlaxBertForCausalLMModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self):
self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
self.bert = FlaxBertModule(
config=self.config,
add_pooling_layer=False,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)

def __call__(