Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flax Remat for LongT5 #17994

Merged
merged 21 commits into from
Aug 14, 2022
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions examples/flax/language-modeling/run_mlm_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ class TrainingArguments:
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
gradient_checkpointing: bool = field(
default=False,
metadata={
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)

def __post_init__(self):
if self.output_dir is not None:
Expand Down Expand Up @@ -641,6 +647,9 @@ def group_texts(examples):
use_auth_token=True if model_args.use_auth_token else None,
)

if training_args.gradient_checkpointing:
model.enable_gradient_checkpointing()

# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
Expand Down
9 changes: 9 additions & 0 deletions examples/flax/summarization/run_summarization_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ class TrainingArguments:
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
gradient_checkpointing: bool = field(
default=False,
metadata={
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)

def __post_init__(self):
if self.output_dir is not None:
Expand Down Expand Up @@ -535,6 +541,9 @@ def main():
dtype=getattr(jnp, model_args.dtype),
)

if training_args.gradient_checkpointing:
model.enable_gradient_checkpointing()

if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

Expand Down
103 changes: 82 additions & 21 deletions src/transformers/models/t5/modeling_flax_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
Expand Down Expand Up @@ -53,6 +54,8 @@
_CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"

remat = nn_partitioning.remat


# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
Expand Down Expand Up @@ -478,7 +481,9 @@ def __call__(
self,
hidden_states,
attention_mask=None,
key_value_states=None,
KMFODA marked this conversation as resolved.
Show resolved Hide resolved
position_bias=None,
use_cache=False,
output_attentions=False,
deterministic=True,
init_cache=False,
Expand Down Expand Up @@ -607,6 +612,7 @@ class FlaxT5LayerCollection(nn.Module):
config: T5Config
has_relative_attention_bias: bool
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False
KMFODA marked this conversation as resolved.
Show resolved Hide resolved

def setup(self):
self.layer = FlaxT5Block(
Expand Down Expand Up @@ -642,13 +648,33 @@ def __call__(
class FlaxT5BlockCollection(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False

def setup(self):
self.causal = self.config.causal
self.blocks = [
FlaxT5LayerCollection(self.config, has_relative_attention_bias=(i == 0), dtype=self.dtype, name=str(i))
for i in range(self.config.num_layers)
]
if self.gradient_checkpointing:
FlaxT5CheckpointLayer = remat(FlaxT5LayerCollection, static_argnums=(6, 7, 8, 9))
self.blocks = [
FlaxT5CheckpointLayer(
self.config,
has_relative_attention_bias=(i == 0),
dtype=self.dtype,
name=str(i),
gradient_checkpointing=self.gradient_checkpointing,
)
for i in range(self.config.num_layers)
]
else:
self.blocks = [
FlaxT5LayerCollection(
self.config,
has_relative_attention_bias=(i == 0),
dtype=self.dtype,
name=str(i),
gradient_checkpointing=self.gradient_checkpointing,
)
for i in range(self.config.num_layers)
]

def __call__(
self,
Expand All @@ -658,6 +684,7 @@ def __call__(
encoder_attention_mask=None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
KMFODA marked this conversation as resolved.
Show resolved Hide resolved
deterministic: bool = True,
init_cache: bool = False,
):
Expand All @@ -674,14 +701,15 @@ def __call__(

layer_outputs = layer_module(
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
output_attentions=output_attentions,
deterministic=deterministic,
init_cache=init_cache,
attention_mask,
KMFODA marked this conversation as resolved.
Show resolved Hide resolved
position_bias,
encoder_hidden_states,
encoder_attention_mask,
encoder_decoder_position_bias,
output_attentions,
return_dict,
KMFODA marked this conversation as resolved.
Show resolved Hide resolved
deterministic,
init_cache,
)

hidden_states = layer_outputs[0]
Expand Down Expand Up @@ -711,11 +739,14 @@ class FlaxT5Stack(nn.Module):
config: T5Config
embed_tokens: nn.Embed
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False

def setup(self):
self.causal = self.config.causal

self.block = FlaxT5BlockCollection(self.config, dtype=self.dtype)
self.block = FlaxT5BlockCollection(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.final_layer_norm = FlaxT5LayerNorm(
self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
)
Expand Down Expand Up @@ -919,11 +950,19 @@ def __init__(
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
gradient_checkpointing: bool = False,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

def enable_gradient_checkpointing(self):
self._module = self.module_class(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=True,
)

def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
Expand Down Expand Up @@ -1248,6 +1287,7 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs
class FlaxT5Module(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False

def _get_encoder_module(self):
return self.encoder
Expand All @@ -1264,12 +1304,22 @@ def setup(self):

encoder_config = copy.deepcopy(self.config)
encoder_config.causal = False
self.encoder = FlaxT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)
self.encoder = FlaxT5Stack(
encoder_config,
embed_tokens=self.shared,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)

decoder_config = copy.deepcopy(self.config)
decoder_config.causal = True
decoder_config.num_layers = self.config.num_decoder_layers
self.decoder = FlaxT5Stack(decoder_config, embed_tokens=self.shared, dtype=self.dtype)
self.decoder = FlaxT5Stack(
decoder_config,
embed_tokens=self.shared,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)

def __call__(
self,
Expand All @@ -1280,7 +1330,7 @@ def __call__(
encoder_outputs=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
return_dict: bool = True,
KMFODA marked this conversation as resolved.
Show resolved Hide resolved
deterministic: bool = True,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
Expand Down Expand Up @@ -1364,6 +1414,7 @@ class FlaxT5Model(FlaxT5PreTrainedModel):
class FlaxT5EncoderModule(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False

def setup(self):
self.shared = nn.Embed(
Expand All @@ -1376,15 +1427,20 @@ def setup(self):
encoder_config.is_decoder = False
encoder_config.is_encoder_decoder = False
encoder_config.causal = False
self.encoder = FlaxT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)
self.encoder = FlaxT5Stack(
encoder_config,
embed_tokens=self.shared,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)

def __call__(
self,
input_ids=None,
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
return_dict: bool = True,
deterministic: bool = True,
):

Expand Down Expand Up @@ -1445,6 +1501,7 @@ def __call__(
class FlaxT5ForConditionalGenerationModule(nn.Module):
config: T5Config
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False

def _get_encoder_module(self):
return self.encoder
Expand All @@ -1465,13 +1522,17 @@ def setup(self):
encoder_config.causal = False
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
self.encoder = FlaxT5Stack(encoder_config, self.shared, dtype=self.dtype)
self.encoder = FlaxT5Stack(
encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)

decoder_config = copy.deepcopy(self.config)
decoder_config.causal = True
decoder_config.is_encoder_decoder = False
decoder_config.num_layers = self.config.num_decoder_layers
self.decoder = FlaxT5Stack(decoder_config, self.shared, dtype=self.dtype)
self.decoder = FlaxT5Stack(
decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)

self.lm_head = nn.Dense(
self.config.vocab_size,
Expand Down
24 changes: 24 additions & 0 deletions tests/models/t5/test_modeling_flax_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,30 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)

def test_gradient_checkpointing(self):
KMFODA marked this conversation as resolved.
Show resolved Hide resolved
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
# breakpoint()
model = model_class(config)
remat_model = model_class(config)
remat_model.enable_gradient_checkpointing()

outputs = model(**prepared_inputs_dict)
remat_outputs = remat_model(**prepared_inputs_dict)

# ensure that the dicts of outputs contain the same keys
self.assertEqual(outputs.keys(), remat_outputs.keys())

outputs = outputs.to_tuple()
remat_outputs = remat_outputs.to_tuple()

# ensure that the outputs remain precisely equal
for output, remat_output in zip(outputs, remat_outputs):
self.assertTrue((output == remat_output).all())

def test_model_v1_1(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
# check that gated gelu feed forward and different word embeddings work
Expand Down