[Flax] Add remat (gradient checkpointing) #17843

Merged (11 commits) on Jul 1, 2022
3 changes: 3 additions & 0 deletions src/transformers/modeling_flax_utils.py
@@ -235,6 +235,9 @@ def __init__(
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
raise NotImplementedError(f"init method has to be implemented for {self}")

def enable_gradient_checkpointing(self):
raise NotImplementedError(f"gradient checkpointing method has to be implemented for {self}")

@classmethod
def _from_config(cls, config, **kwargs):
"""
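The base class only declares the hook; a concrete model class is expected to override it. A minimal, schematic sketch of that contract, mirroring the FlaxBertPreTrainedModel override that appears later in this diff (the class name here is illustrative, and module_class is normally set to the model's flax.linen.Module):

from transformers.modeling_flax_utils import FlaxPreTrainedModel


class FlaxToyPreTrainedModel(FlaxPreTrainedModel):
    # Each concrete model points module_class at its flax.linen.Module.
    module_class = None  # e.g. FlaxToyModule

    def enable_gradient_checkpointing(self):
        # Rebuild the inner flax Module with checkpointing enabled;
        # self.config and self.dtype are provided by the base class.
        self._module = self.module_class(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=True,
        )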
118 changes: 98 additions & 20 deletions src/transformers/models/bert/modeling_flax_bert.py
@@ -23,6 +23,7 @@
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
@@ -56,6 +57,8 @@
_CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer"

remat = nn_partitioning.remat


@flax.struct.dataclass
class FlaxBertForPreTrainingOutput(ModelOutput):
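For context (not part of the diff): remat is the lifted, flax.linen-aware form of jax.checkpoint / jax.remat, which trades compute for memory by recomputing a block's intermediates during the backward pass instead of storing them. A self-contained plain-JAX sketch of that tradeoff (function and shapes are illustrative):

import jax
import jax.numpy as jnp


def block(x):
    # Each iteration produces intermediates that would normally be kept for the backward pass.
    for _ in range(4):
        x = jnp.tanh(x @ x)
    return x


# jax.checkpoint (alias jax.remat) discards those intermediates and recomputes them
# when gradients are taken, lowering peak activation memory at the cost of extra FLOPs.
checkpointed_block = jax.checkpoint(block)

grad_fn = jax.grad(lambda x: checkpointed_block(x).sum())
g = grad_fn(jnp.ones((16, 16)))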
@@ -544,11 +547,19 @@ def __call__(
class FlaxBertLayerCollection(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False

def setup(self):
self.layers = [
FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
]
if self.gradient_checkpointing:
FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7))
self.layers = [
FlaxBertCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_layers)
]
else:
self.layers = [
FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
]

def __call__(
self,
@@ -582,12 +593,12 @@ def __call__(
layer_outputs = layer(
hidden_states,
attention_mask,
layer_head_mask=head_mask[i] if head_mask is not None else None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
deterministic=deterministic,
output_attentions=output_attentions,
head_mask[i] if head_mask is not None else None,
encoder_hidden_states,
encoder_attention_mask,
init_cache,
deterministic,
output_attentions,
)

Comment on lines -585 to +601

Contributor Author: Note: remat does not support kwargs, hence the need to change to args

Contributor: ok!

hidden_states = layer_outputs[0]
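In the FlaxBertLayerCollection hunk above, the checkpointed layer is built with remat(FlaxBertLayer, static_argnums=(5, 6, 7)), and the reviewer note explains the call-site change: the lifted remat traces __call__ through positional arguments only, so keyword arguments cannot be used, and the Python-bool flags (init_cache, deterministic, output_attentions, which sit at positions 5-7 of the layer's arguments) are marked static. A standalone sketch of the same pattern on a toy layer (the layer, shapes and argument layout are illustrative, not from the PR):

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import partitioning as nn_partitioning

remat = nn_partitioning.remat


class ToyLayer(nn.Module):
    features: int = 16

    @nn.compact
    def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
        hidden_states = nn.Dense(self.features)(hidden_states)
        hidden_states = jnp.where(attention_mask[..., None], hidden_states, 0.0)
        if not deterministic:
            hidden_states = nn.Dropout(rate=0.1)(hidden_states, deterministic=False)
        return hidden_states


# Argument 2 of __call__ (deterministic) is a plain Python bool, so it is marked
# static; the tensor arguments at positions 0 and 1 stay traced.
ToyCheckpointLayer = remat(ToyLayer, static_argnums=(2,))

layer = ToyCheckpointLayer(features=16)
x = jnp.ones((2, 4, 16))
mask = jnp.ones((2, 4), dtype=bool)

# The checkpointed layer is invoked with positional arguments only.
params = layer.init(jax.random.PRNGKey(0), x, mask, True)
out = layer.apply(params, x, mask, True)

Because both branches of setup() produce layers with the same positional signature, the single positional call site in __call__ above works whether or not checkpointing is enabled.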
@@ -617,9 +628,14 @@ def __call__(
class FlaxBertEncoder(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False

def setup(self):
self.layer = FlaxBertLayerCollection(self.config, dtype=self.dtype)
self.layer = FlaxBertLayerCollection(
self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)

def __call__(
self,
@@ -756,11 +772,24 @@ def __init__(
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
gradient_checkpointing: bool = False,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
module = self.module_class(
config=config,
dtype=dtype,
gradient_checkpointing=gradient_checkpointing,
**kwargs,
)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

def enable_gradient_checkpointing(self):
self._module = self.module_class(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=True,
)

def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
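Together, the new constructor flag and the enable_gradient_checkpointing override give two ways to turn checkpointing on for models built from this class. A hedged usage sketch (the checkpoint name is just an example):

from transformers import BertConfig, FlaxBertModel

# Option 1: build the model with checkpointing from the start.
model = FlaxBertModel(BertConfig(), gradient_checkpointing=True)

# Option 2: load pretrained weights, then rebuild the module with remat'd layers.
model = FlaxBertModel.from_pretrained("bert-base-uncased")
model.enable_gradient_checkpointing()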
@@ -925,10 +954,15 @@ class FlaxBertModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
add_pooling_layer: bool = True
gradient_checkpointing: bool = False

def setup(self):
self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype)
self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype)
self.encoder = FlaxBertEncoder(
self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)

def __call__(
@@ -1003,9 +1037,14 @@ class FlaxBertModel(FlaxBertPreTrainedModel):
class FlaxBertForPreTrainingModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
self.bert = FlaxBertModule(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype)

def __call__(
@@ -1099,9 +1138,15 @@ class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
class FlaxBertForMaskedLMModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self):
self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
self.bert = FlaxBertModule(
config=self.config,
add_pooling_layer=False,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)

def __call__(
@@ -1161,9 +1206,14 @@ class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
class FlaxBertForNextSentencePredictionModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
self.bert = FlaxBertModule(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype)

def __call__(
@@ -1248,9 +1298,14 @@ class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
class FlaxBertForSequenceClassificationModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
self.bert = FlaxBertModule(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
classifier_dropout = (
self.config.classifier_dropout
if self.config.classifier_dropout is not None
@@ -1324,9 +1379,14 @@ class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
class FlaxBertForMultipleChoiceModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
self.bert = FlaxBertModule(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.classifier = nn.Dense(1, dtype=self.dtype)

@@ -1399,9 +1459,15 @@ class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
class FlaxBertForTokenClassificationModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
self.bert = FlaxBertModule(
config=self.config,
dtype=self.dtype,
add_pooling_layer=False,
gradient_checkpointing=self.gradient_checkpointing,
)
classifier_dropout = (
self.config.classifier_dropout
if self.config.classifier_dropout is not None
@@ -1468,9 +1534,15 @@ class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
class FlaxBertForQuestionAnsweringModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self):
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
self.bert = FlaxBertModule(
config=self.config,
dtype=self.dtype,
add_pooling_layer=False,
gradient_checkpointing=self.gradient_checkpointing,
)
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

def __call__(
@@ -1539,9 +1611,15 @@ class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
class FlaxBertForCausalLMModule(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self):
self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
self.bert = FlaxBertModule(
config=self.config,
add_pooling_layer=False,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)

def __call__(
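The remaining hunks repeat the same plumbing for every task module (pre-training, masked LM, next-sentence prediction, sequence/token classification, multiple choice, question answering, causal LM): each gains a gradient_checkpointing field and forwards it to FlaxBertModule. The saving itself materialises in the backward pass of a training step; a rough sketch under assumptions not taken from this PR (tokenizer usage, the optax loss, and treating every token as a label are illustrative):

import jax
import optax
from transformers import BertTokenizer, FlaxBertForMaskedLM

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = FlaxBertForMaskedLM.from_pretrained("bert-base-uncased")
model.enable_gradient_checkpointing()  # BERT layers are recomputed instead of stored

batch = tokenizer(["gradient checkpointing trades compute for memory"], return_tensors="np")
labels = batch["input_ids"]
dropout_rng = jax.random.PRNGKey(0)


def loss_fn(params):
    logits = model(
        batch["input_ids"],
        batch["attention_mask"],
        params=params,
        dropout_rng=dropout_rng,
        train=True,
    ).logits
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()


# Peak activation memory during this gradient is lower because each remat'd
# FlaxBertLayer is rematerialized in the backward pass.
grads = jax.grad(loss_fn)(model.params)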