Flax Remat for LongT5 (huggingface#17994)
* [Flax] Add remat (gradient checkpointing)

* fix variable naming in test

* flip: checkpoint using a method

* fix naming

* fix class naming

* apply PVP's suggestions from code review

* add gradient_checkpointing to examples

* Add gradient_checkpointing to run_mlm_flax

* Add remat to longt5

* Add gradient checkpointing test longt5

* Fix args errors

* Fix remaining tests

* Make fixup & quality fixes

* replace kwargs

* remove unnecessary kwargs

* Make fixup changes

* revert long_t5_flax changes

* Remove return_dict and copy to LongT5

* Remove test_gradient_checkpointing

Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
2 people authored and amyeroberts committed Aug 17, 2022
1 parent af92441 commit 2287492
Showing 4 changed files with 149 additions and 39 deletions.
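In short, this commit adds Flax gradient checkpointing (remat) support to LongT5 and exposes it in the Flax example scripts. A minimal sketch of the resulting user-facing call, assuming an existing LongT5 checkpoint (the model name below is illustrative and not part of this commit):

```python
# Minimal sketch: turning on gradient checkpointing for a Flax LongT5 model.
from transformers import FlaxLongT5ForConditionalGeneration

model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")

# Rebuilds the underlying Flax module with gradient_checkpointing=True, so each
# transformer block is wrapped in flax.linen remat: activations are recomputed
# in the backward pass instead of being stored, saving memory at the cost of compute.
model.enable_gradient_checkpointing()
```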
9 changes: 9 additions & 0 deletions examples/flax/language-modeling/run_mlm_flax.py
@@ -107,6 +107,12 @@ class TrainingArguments:
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
gradient_checkpointing: bool = field(
default=False,
metadata={
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)

def __post_init__(self):
if self.output_dir is not None:
@@ -640,6 +646,9 @@ def group_texts(examples):
dtype=getattr(jnp, model_args.dtype),
)

if training_args.gradient_checkpointing:
model.enable_gradient_checkpointing()

# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
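For reference, a self-contained sketch of how such a boolean dataclass field turns into a `--gradient_checkpointing` command-line flag when parsed with `HfArgumentParser`, as the Flax example scripts do (the dataclass name below is illustrative, standing in for the script's own `TrainingArguments`); the identical field and check are added to `run_summarization_flax.py` below:

```python
# Sketch: a bool field with default=False becomes an optional flag under HfArgumentParser.
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class SketchTrainingArguments:
    gradient_checkpointing: bool = field(
        default=False,
        metadata={"help": "Use gradient checkpointing to save memory at the expense of a slower backward pass."},
    )


parser = HfArgumentParser(SketchTrainingArguments)
# Passing the bare flag sets the field to True; omitting it leaves the default False.
(training_args,) = parser.parse_args_into_dataclasses(["--gradient_checkpointing"])
assert training_args.gradient_checkpointing is True

# In run_mlm_flax.py the flag is then checked once the model is created:
#     if training_args.gradient_checkpointing:
#         model.enable_gradient_checkpointing()
```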
9 changes: 9 additions & 0 deletions examples/flax/summarization/run_summarization_flax.py
@@ -121,6 +121,12 @@ class TrainingArguments:
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
gradient_checkpointing: bool = field(
default=False,
metadata={
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)

def __post_init__(self):
if self.output_dir is not None:
@@ -535,6 +541,9 @@ def main():
dtype=getattr(jnp, model_args.dtype),
)

if training_args.gradient_checkpointing:
model.enable_gradient_checkpointing()

if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

75 changes: 57 additions & 18 deletions src/transformers/models/longt5/modeling_flax_longt5.py
@@ -25,6 +25,7 @@
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
@@ -53,6 +54,8 @@
_CONFIG_FOR_DOC = "LongT5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"

remat = nn_partitioning.remat


# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
@@ -1356,7 +1359,6 @@ def __call__(
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
output_attentions=False,
return_dict=True,
deterministic=True,
init_cache=False,
):
@@ -1377,13 +1379,31 @@ def __call__(
class FlaxLongT5BlockCollection(nn.Module):
config: LongT5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False

def setup(self):
self.causal = self.config.causal
self.blocks = [
FlaxLongT5LayerCollection(self.config, has_relative_attention_bias=(i == 0), dtype=self.dtype, name=str(i))
for i in range(self.config.num_layers)
]
if self.gradient_checkpointing:
FlaxLongT5CheckpointLayer = remat(FlaxLongT5LayerCollection, static_argnums=(6, 7, 8))
self.blocks = [
FlaxLongT5CheckpointLayer(
self.config,
has_relative_attention_bias=(i == 0),
dtype=self.dtype,
name=str(i),
)
for i in range(self.config.num_layers)
]
else:
self.blocks = [
FlaxLongT5LayerCollection(
self.config,
has_relative_attention_bias=(i == 0),
dtype=self.dtype,
name=str(i),
)
for i in range(self.config.num_layers)
]

def __call__(
self,
@@ -1409,14 +1429,14 @@ def __call__(

layer_outputs = layer_module(
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
output_attentions=output_attentions,
deterministic=deterministic,
init_cache=init_cache,
attention_mask,
position_bias,
encoder_hidden_states,
encoder_attention_mask,
encoder_decoder_position_bias,
output_attentions,
deterministic,
init_cache,
)

hidden_states = layer_outputs[0]
@@ -1447,11 +1467,14 @@ class FlaxLongT5Stack(nn.Module):
config: LongT5Config
embed_tokens: nn.Embed
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False

def setup(self):
self.causal = self.config.causal

self.block = FlaxLongT5BlockCollection(self.config, dtype=self.dtype)
self.block = FlaxLongT5BlockCollection(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.final_layer_norm = FlaxLongT5LayerNorm(
self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
)
@@ -1989,6 +2012,7 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs
class FlaxLongT5Module(nn.Module):
config: LongT5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False

def _get_encoder_module(self):
return self.encoder
@@ -2005,12 +2029,22 @@ def setup(self):

encoder_config = copy.deepcopy(self.config)
encoder_config.causal = False
self.encoder = FlaxLongT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)
self.encoder = FlaxLongT5Stack(
encoder_config,
embed_tokens=self.shared,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)

decoder_config = copy.deepcopy(self.config)
decoder_config.causal = True
decoder_config.num_layers = self.config.num_decoder_layers
self.decoder = FlaxLongT5Stack(decoder_config, embed_tokens=self.shared, dtype=self.dtype)
self.decoder = FlaxLongT5Stack(
decoder_config,
embed_tokens=self.shared,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)

def __call__(
self,
@@ -2104,6 +2138,7 @@ class FlaxLongT5Model(FlaxLongT5PreTrainedModel):
class FlaxLongT5ForConditionalGenerationModule(nn.Module):
config: LongT5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False

def _get_encoder_module(self):
return self.encoder
@@ -2124,13 +2159,17 @@ def setup(self):
encoder_config.causal = False
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
self.encoder = FlaxLongT5Stack(encoder_config, self.shared, dtype=self.dtype)
self.encoder = FlaxLongT5Stack(
encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)

decoder_config = copy.deepcopy(self.config)
decoder_config.causal = True
decoder_config.is_encoder_decoder = False
decoder_config.num_layers = self.config.num_decoder_layers
self.decoder = FlaxLongT5Stack(decoder_config, self.shared, dtype=self.dtype)
self.decoder = FlaxLongT5Stack(
decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)

self.lm_head = nn.Dense(
self.config.vocab_size,
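A note on the two changes above: the blocks are wrapped with `remat` (here `flax.linen.partitioning.remat`) using `static_argnums=(6, 7, 8)`, which marks the plain Python booleans `output_attentions`, `deterministic` and `init_cache` as compile-time constants, and the wrapped blocks are then invoked with positional instead of keyword arguments because `static_argnums` indexes positional arguments of the checkpointed call. A minimal, self-contained sketch of the same pattern on a toy module (all names below are illustrative and not part of this commit):

```python
# Toy sketch of the remat pattern used in FlaxLongT5BlockCollection above.
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import partitioning as nn_partitioning

remat = nn_partitioning.remat


class ToyLayer(nn.Module):
    features: int = 16

    @nn.compact
    def __call__(self, hidden_states, deterministic: bool = True):
        hidden_states = nn.Dense(self.features)(hidden_states)
        return nn.Dropout(rate=0.1)(hidden_states, deterministic=deterministic)


class ToyStack(nn.Module):
    num_layers: int = 2
    gradient_checkpointing: bool = False

    def setup(self):
        if self.gradient_checkpointing:
            # `deterministic` is positional argument 1 of ToyLayer.__call__ (self excluded)
            # and is a plain Python bool, so it must be declared static.
            CheckpointedToyLayer = remat(ToyLayer, static_argnums=(1,))
            self.layers = [CheckpointedToyLayer(name=str(i)) for i in range(self.num_layers)]
        else:
            self.layers = [ToyLayer(name=str(i)) for i in range(self.num_layers)]

    def __call__(self, hidden_states, deterministic: bool = True):
        for layer in self.layers:
            # remat-wrapped modules are called positionally, mirroring the
            # keyword-to-positional change in the LongT5 block call above.
            hidden_states = layer(hidden_states, deterministic)
        return hidden_states


x = jnp.ones((2, 8))
stack = ToyStack(gradient_checkpointing=True)
params = stack.init(jax.random.PRNGKey(0), x)
# Forward pass; under jax.grad, each layer's activations would be rematerialized.
out = stack.apply(params, x)
```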