Commit

add use_running_average_bn arg for jax
priyakasimbeg committed Oct 17, 2024
1 parent b24812f · commit f574bf0
Showing 6 changed files with 63 additions and 31 deletions.
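Taken together, the change threads a single new keyword argument, use_running_average_bn, from each JAX workload's model_fn down to the underlying Flax model, so a caller can choose which batch-norm statistics are used for normalization independently of whether the running averages are updated. Below is a minimal caller-side sketch; the workload class name (CifarWorkload), the init_model_fn return convention, and the placeholder batch are assumptions for illustration, not part of this commit.

# Hypothetical usage sketch (assumptions noted above): a training-mode forward
# pass that normalizes with the stored running statistics instead of the
# current batch, and leaves those running statistics untouched.
import jax
import jax.numpy as jnp

from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.cifar.cifar_jax.workload import CifarWorkload

workload = CifarWorkload()
rng = jax.random.PRNGKey(0)
params, model_state = workload.init_model_fn(rng)  # assumed to return (params, state)

batch = {'inputs': jnp.zeros((8, 32, 32, 3))}  # placeholder CIFAR-shaped batch

logits, model_state = workload.model_fn(
    params,
    batch,
    model_state,
    spec.ForwardPassMode.TRAIN,
    rng,
    update_batch_norm=False,        # do not refresh the running averages
    use_running_average_bn=True)    # normalize with the stored running stats

Before this commit the two decisions were tied together through update_batch_norm alone, so this combination was not expressible.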
9 changes: 7 additions & 2 deletions algorithmic_efficiency/workloads/cifar/cifar_jax/models.py
@@ -28,11 +28,16 @@ class ResNet(nn.Module):
@nn.compact
def __call__(self,
x: spec.Tensor,
update_batch_norm: bool = True) -> spec.Tensor:
update_batch_norm: bool = True,
use_running_average_bn: bool = None) -> spec.Tensor:
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)

# Preserve default behavior for backwards compatibility
if use_running_average_bn is None:
use_running_average_bn = not update_batch_norm
norm = functools.partial(
nn.BatchNorm,
use_running_average=not update_batch_norm,
use_running_average=use_running_average_bn,
momentum=0.9,
epsilon=1e-5,
dtype=self.dtype)
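
The None default keeps existing callers unchanged: if the new flag is not supplied, the model falls back to the old coupling use_running_average = not update_batch_norm. A tiny standalone sketch of that resolution rule (the helper function is illustrative, not part of the commit):

from typing import Optional

def resolve_use_running_average(update_batch_norm: bool,
                                use_running_average_bn: Optional[bool] = None) -> bool:
  # Mirrors the fallback above: an explicit value wins; otherwise normalize
  # with running statistics exactly when batch norm is not being updated.
  if use_running_average_bn is None:
    use_running_average_bn = not update_batch_norm
  return use_running_average_bn

assert resolve_use_running_average(update_batch_norm=True) is False   # legacy train path
assert resolve_use_running_average(update_batch_norm=False) is True   # legacy eval path
assert resolve_use_running_average(True, use_running_average_bn=True) is True  # new override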
9 changes: 6 additions & 3 deletions algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py
@@ -110,7 +110,8 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
update_batch_norm: bool,
use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode
del rng
variables = {'params': params, **model_state}
@@ -119,14 +120,16 @@ def model_fn(
variables,
augmented_and_preprocessed_input_batch['inputs'],
update_batch_norm=update_batch_norm,
mutable=['batch_stats'])
mutable=['batch_stats'],
use_running_average_bn=use_running_average_bn)
return logits, new_model_state
else:
logits = self._model.apply(
variables,
augmented_and_preprocessed_input_batch['inputs'],
update_batch_norm=update_batch_norm,
mutable=False)
mutable=False,
use_running_average_bn=use_running_average_bn)
return logits, model_state

# Does NOT apply regularization, which is left to the submitter to do in
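In the workload, model_fn only forwards the new keyword into Module.apply; Flax passes extra keyword arguments through to the module's __call__, and mutable=['batch_stats'] controls whether an updated state collection is returned. A self-contained toy module illustrating the same pattern (the toy class, features, and shapes are assumptions, not the workload's ResNet):

import flax.linen as nn
import jax
import jax.numpy as jnp

class TinyBNModel(nn.Module):
  """Toy stand-in for the workload models, showing the two-flag interface."""

  @nn.compact
  def __call__(self, x, update_batch_norm=True, use_running_average_bn=None):
    if use_running_average_bn is None:  # same fallback as in the diff
      use_running_average_bn = not update_batch_norm
    x = nn.Dense(4)(x)
    return nn.BatchNorm(use_running_average=use_running_average_bn,
                        momentum=0.9, epsilon=1e-5)(x)

model = TinyBNModel()
x = jnp.ones((2, 3))
variables = model.init(jax.random.PRNGKey(0), x)

# Train-style call: batch_stats is mutable, so apply returns (output, new_state).
out, new_state = model.apply(variables, x,
                             update_batch_norm=True,
                             mutable=['batch_stats'])

# Eval-style call: read-only, normalize with the stored running averages.
out = model.apply(variables, x,
                  update_batch_norm=False,
                  use_running_average_bn=True,
                  mutable=False)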
[Next changed file]
@@ -84,11 +84,16 @@ class ResNet(nn.Module):
@nn.compact
def __call__(self,
x: spec.Tensor,
update_batch_norm: bool = True) -> spec.Tensor:
update_batch_norm: bool = True,
use_running_average_bn: Optional[bool] = None) -> spec.Tensor:
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)

# Preserve default behavior for backwards compatibility
if use_running_average_bn is None:
use_running_average_bn = not update_batch_norm
norm = functools.partial(
nn.BatchNorm,
use_running_average=not update_batch_norm,
use_running_average=use_running_average_bn,
momentum=0.9,
epsilon=1e-5,
dtype=self.dtype)
[Next changed file]
@@ -148,7 +148,8 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
update_batch_norm: bool,
use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode
del rng
variables = {'params': params, **model_state}
@@ -157,14 +158,16 @@ def model_fn(
variables,
augmented_and_preprocessed_input_batch['inputs'],
update_batch_norm=update_batch_norm,
mutable=['batch_stats'])
mutable=['batch_stats'],
use_running_average_bn=use_running_average_bn)
return logits, new_model_state
else:
logits = self._model.apply(
variables,
augmented_and_preprocessed_input_batch['inputs'],
update_batch_norm=update_batch_norm,
mutable=False)
mutable=False,
use_running_average_bn=use_running_average_bn)
return logits, model_state

# Does NOT apply regularization, which is left to the submitter to do in
[Next changed file]
@@ -454,15 +454,20 @@ def setup(self):
self.beta = self.param('bias', nn.initializers.zeros, dim, dtype)

@nn.compact
def __call__(self, inputs, input_paddings, train):
def __call__(self, inputs, input_paddings, update_batch_norm, use_running_average_bn):
rank = inputs.ndim
reduce_over_dims = list(range(0, rank - 1))

padding = jnp.expand_dims(input_paddings, -1)
momentum = self.config.batch_norm_momentum
epsilon = self.config.batch_norm_epsilon

if train:
if use_running_average_bn:
mean = self.ra_mean.value
var = self.ra_var.value

else:
# compute batch statistics
mask = 1.0 - padding
sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True)
count_v = jnp.sum(
@@ -477,17 +482,14 @@ def __call__(self, inputs, input_paddings, train):
keepdims=True)

var = sum_vv / count_v

self.ra_mean.value = momentum * \
self.ra_mean.value + (1 - momentum) * mean
self.ra_var.value = momentum * \
self.ra_var.value + (1 - momentum) * var
else:
mean = self.ra_mean.value
var = self.ra_var.value


if update_batch_norm:
self.ra_mean.value = momentum * \
self.ra_mean.value + (1 - momentum) * mean
self.ra_var.value = momentum * \
self.ra_var.value + (1 - momentum) * var

inv = (1 + self.gamma) / jnp.sqrt(var + epsilon)

bn_output = (inputs - mean) * inv + self.beta
bn_output *= 1.0 - padding
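
Unlike the image workloads, the Conformer uses a hand-rolled, padding-aware batch norm, so the two flags are genuinely independent here: use_running_average_bn picks the statistics used for normalization, and update_batch_norm decides whether the running averages are refreshed. A toy recreation of that control flow on plain arrays (the function, argument names, and default constants are illustrative, not the module's API):

import jax.numpy as jnp

def padded_batch_norm(inputs, paddings, ra_mean, ra_var, gamma, beta,
                      update_batch_norm, use_running_average_bn,
                      momentum=0.999, epsilon=0.001):
  # inputs: (batch, time, dim); paddings: (batch, time) with 1.0 on padded frames.
  mask = 1.0 - paddings[..., None]
  reduce_dims = tuple(range(inputs.ndim - 1))

  if use_running_average_bn:
    # Normalize with the stored running statistics.
    mean, var = ra_mean, ra_var
  else:
    # Normalize with statistics computed over the unpadded frames of this batch.
    count = jnp.sum(mask, axis=reduce_dims, keepdims=True)
    mean = jnp.sum(inputs * mask, axis=reduce_dims, keepdims=True) / count
    var = jnp.sum(((inputs - mean) ** 2) * mask, axis=reduce_dims, keepdims=True) / count

  if update_batch_norm:
    # Refresh the running averages regardless of which statistics were used.
    ra_mean = momentum * ra_mean + (1 - momentum) * mean
    ra_var = momentum * ra_var + (1 - momentum) * var

  out = (inputs - mean) * ((1 + gamma) / jnp.sqrt(var + epsilon)) + beta
  return out * mask, ra_mean, ra_var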

@@ -517,7 +519,7 @@ class ConvolutionBlock(nn.Module):
config: ConformerConfig

@nn.compact
def __call__(self, inputs, input_paddings, train):
def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average_bn):
config = self.config
inputs = LayerNorm(dim=config.encoder_dim)(inputs)

@@ -546,7 +548,7 @@ def __call__(self, inputs, input_paddings, train):
kernel_init=nn.initializers.xavier_uniform())(
inputs)

inputs = BatchNorm(config)(inputs, input_paddings, train)
inputs = BatchNorm(config)(inputs, input_paddings, update_batch_norm, use_running_average_bn)
if config.activation_function_name == 'swish':
activation_fn = nn.swish
elif config.activation_function_name == 'gelu':
@@ -586,7 +588,7 @@ class ConformerBlock(nn.Module):
config: ConformerConfig

@nn.compact
def __call__(self, inputs, input_paddings, train):
def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average):
config = self.config
padding_mask = jnp.expand_dims(1 - input_paddings, -1)

@@ -597,7 +599,7 @@ def __call__(self, inputs, input_paddings, train):
inputs, input_paddings, train)

inputs = inputs + \
ConvolutionBlock(config)(inputs, input_paddings, train)
ConvolutionBlock(config)(inputs, input_paddings, train, update_batch_norm, use_running_average)

inputs = inputs + 0.5 * FeedForwardModule(config=self.config)(
inputs, padding_mask, train)
@@ -629,12 +631,23 @@ def setup(self):
.use_dynamic_time_mask_max_frames)

@nn.compact
def __call__(self, inputs, input_paddings, train):
def __call__(self,
inputs,
input_paddings,
train,
update_batch_norm: Optional[bool] = None,
use_running_average_bn: Optional[bool] = None):
config = self.config

outputs = inputs
output_paddings = input_paddings

# Set BN args if not supplied for backwards compatibility
if update_batch_norm is None:
update_batch_norm = train
if use_running_average_bn is None:
use_running_average_bn = not train

# Compute normalized log mel spectrograms from input audio signal.
preprocessing_config = preprocessor.LibrispeechPreprocessingConfig()
outputs, output_paddings = preprocessor.MelFilterbankFrontend(
@@ -660,7 +673,7 @@ def __call__(self, inputs, input_paddings, train):

# Run the conformer encoder layers.
for _ in range(config.num_encoder_layers):
outputs = ConformerBlock(config)(outputs, output_paddings, train)
outputs = ConformerBlock(config)(outputs, output_paddings, train, update_batch_norm, use_running_average_bn)

outputs = LayerNorm(config.encoder_dim)(outputs)
# Run the decoder which in this case is a trivial projection layer.
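
At the encoder level both new arguments are optional and, when omitted, are derived from train, so existing calls behave exactly as before; passing them explicitly enables combinations that used to be impossible, such as a dropout-active forward pass that normalizes with the stored batch-norm statistics. A sketch of such a call, mirroring the apply signature used in the workload below (model, variables, inputs, input_paddings, and rng are assumed to exist):

# Dropout active (train=True), but normalize with the stored running BN
# statistics and leave them unchanged; previously `train` alone controlled both.
(logits, logit_paddings), new_state = model.apply(
    variables,
    inputs,
    input_paddings,
    train=True,
    update_batch_norm=False,
    use_running_average_bn=True,
    rngs={'dropout': rng},
    mutable=['batch_stats'])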
[Next changed file]
@@ -107,7 +107,8 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
update_batch_norm: bool,
use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
variables = {'params': params, **model_state}
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
is_train_mode = mode == spec.ForwardPassMode.TRAIN
@@ -118,15 +119,17 @@ def model_fn(
input_paddings,
train=True,
rngs={'dropout' : rng},
mutable=['batch_stats'])
mutable=['batch_stats'],
use_running_average_bn=use_running_average_bn)
return (logits, logit_paddings), new_model_state
else:
logits, logit_paddings = self._model.apply(
variables,
inputs,
input_paddings,
train=False,
mutable=False)
mutable=False,
use_running_average_bn=use_running_average_bn)
return (logits, logit_paddings), model_state

def _build_input_queue(