Adding gradient_checkpointing to Flax Whisper
It uses `flax.linen.remat` and follows up on PRs huggingface#13657 and huggingface#17994.
  • Loading branch information
versae authored Apr 20, 2023
1 parent 6dc0a84 commit 1eefc67
Showing 1 changed file with 45 additions and 13 deletions.
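A minimal usage sketch (not part of the commit itself): once this change is in, gradient checkpointing can be switched on for a loaded Flax Whisper model via the new `enable_gradient_checkpointing` method. The checkpoint name comes from the file's `_CHECKPOINT_FOR_DOC`; whether it ships Flax weights is an assumption.

```python
# Hedged usage sketch: load the Flax Whisper model and enable activation
# recomputation via the method added in this commit.
# Assumes the checkpoint provides Flax weights (otherwise pass from_pt=True).
from transformers import FlaxWhisperForConditionalGeneration

model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.enable_gradient_checkpointing()  # rebuilds the module with remat-wrapped layers
```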
58 changes: 45 additions & 13 deletions src/transformers/models/whisper/modeling_flax_whisper.py
@@ -23,6 +23,7 @@
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
@@ -53,6 +54,8 @@
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
_CONFIG_FOR_DOC = "WhisperConfig"

remat = nn_partitioning.remat


WHISPER_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
@@ -391,12 +394,20 @@ def __call__(
class FlaxWhisperEncoderLayerCollection(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False

def setup(self):
self.layers = [
FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.encoder_layers)
]
if self.gradient_checkpointing:
FlaxWhisperEncoderCheckpointLayer = remat(FlaxWhisperEncoderLayer, static_argnums=(2, 3))
self.layers = [
FlaxWhisperEncoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.encoder_layers)
]
else:
self.layers = [
FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.encoder_layers)
]
self.layerdrop = self.config.encoder_layerdrop

def __call__(
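For context on the hunk above: `remat` (activation checkpointing) wraps a Flax module class so that its intermediate activations are recomputed during the backward pass instead of being stored, and `static_argnums` marks positional arguments of `__call__` (counted without `self`) that must stay static Python values under tracing; for `FlaxWhisperEncoderLayer` these are the `output_attentions` and `deterministic` booleans at positions 2 and 3. A minimal sketch of the same pattern on a hypothetical toy layer, not code from this file:

```python
import flax.linen as nn
from flax.linen import partitioning as nn_partitioning

remat = nn_partitioning.remat


class ToyLayer(nn.Module):  # hypothetical stand-in for FlaxWhisperEncoderLayer
    features: int

    @nn.compact
    def __call__(self, hidden_states, deterministic: bool = True):
        hidden_states = nn.Dense(self.features)(hidden_states)
        return nn.Dropout(rate=0.1)(hidden_states, deterministic=deterministic)


# `deterministic` is positional argument 1 of __call__ (self excluded), so it is
# declared static; the remat-wrapped class is used in setup() in place of the original.
ToyCheckpointLayer = remat(ToyLayer, static_argnums=(1,))
```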
@@ -535,12 +546,20 @@ def __call__(
class FlaxWhisperDecoderLayerCollection(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False

def setup(self):
self.layers = [
FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.decoder_layers)
]
if self.gradient_checkpointing:
FlaxWhisperDecoderCheckpointLayer = remat(FlaxWhisperDecoderLayer, static_argnums=(4, 5, 6))
self.layers = [
FlaxWhisperDecoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.decoder_layers)
]
else:
self.layers = [
FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.decoder_layers)
]
self.layerdrop = self.config.decoder_layerdrop

def __call__(
@@ -605,6 +624,7 @@ def __call__(
class FlaxWhisperEncoder(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self) -> None:
self.conv1 = nn.Conv(
@@ -628,6 +648,7 @@ def setup(self) -> None:
self.layers = FlaxWhisperEncoderLayerCollection(
self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.embed_positions = nn.Embed(self.config.max_source_positions, self.config.d_model, dtype=self.dtype)

@@ -689,12 +710,13 @@ def __call__(
class FlaxWhisperDecoder(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self) -> None:
self.embed_tokens = nn.Embed(self.config.vocab_size, self.config.d_model, dtype=self.dtype)
self.embed_positions = nn.Embed(self.config.max_target_positions, self.config.d_model, dtype=self.dtype)

self.layers = FlaxWhisperDecoderLayerCollection(self.config, dtype=self.dtype)
self.layers = FlaxWhisperDecoderLayerCollection(
self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)

self.dropout_layer = nn.Dropout(rate=self.config.dropout)

@@ -753,10 +775,11 @@ def __call__(
class FlaxWhisperModule(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self) -> None:
self.encoder = FlaxWhisperEncoder(self.config, dtype=self.dtype)
self.decoder = FlaxWhisperDecoder(self.config, dtype=self.dtype)
self.encoder = FlaxWhisperEncoder(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.decoder = FlaxWhisperDecoder(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)

def __call__(
self,
@@ -821,11 +844,19 @@ def __init__(
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
gradient_checkpointing: bool = False,
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

def enable_gradient_checkpointing(self):
self._module = self.module_class(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=True,
)

def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_features = jnp.zeros(input_shape, dtype="f4")
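As an alternative to calling `enable_gradient_checkpointing()`, the new `gradient_checkpointing` constructor argument shown in the hunk above can be set when the model object is built. A small sketch, assuming a default `WhisperConfig` purely for illustration:

```python
# Hedged sketch: request activation checkpointing at construction time
# instead of toggling it afterwards.
from transformers import FlaxWhisperForConditionalGeneration, WhisperConfig

config = WhisperConfig()  # assumption: small default config, randomly initialized
model = FlaxWhisperForConditionalGeneration(config, gradient_checkpointing=True)
```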
@@ -1137,9 +1168,10 @@ class FlaxWhisperModel(FlaxWhisperPreTrainedModel):
class FlaxWhisperForConditionalGenerationModule(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self) -> None:
self.model = FlaxWhisperModule(config=self.config, dtype=self.dtype)
self.model = FlaxWhisperModule(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.lm_head = nn.Dense(
self.config.vocab_size,
use_bias=False,
