Commit b24812f: BN Fixes

adefazio committed Sep 5, 2024
1 parent bdece3b
Showing 3 changed files with 19 additions and 16 deletions.
14 changes: 8 additions & 6 deletions algorithmic_efficiency/pytorch_utils.py

@@ -57,7 +57,6 @@ def sync_ddp_time(time: float, device: torch.device) -> float:
   dist.all_reduce(time_tensor, op=dist.ReduceOp.MAX)
   return time_tensor.item()
 
-
 def update_batch_norm_fn(module: spec.ParameterContainer,
                          update_batch_norm: bool) -> None:
   bn_layers = (
@@ -67,10 +66,13 @@ def update_batch_norm_fn(module: spec.ParameterContainer,
   )
   if isinstance(module, bn_layers):
     if not update_batch_norm:
-      module.eval()
-      module.momentum_backup = module.momentum
+      if not hasattr(module, 'momentum_backup'):
+        module.momentum_backup = module.momentum
+
       # module.momentum can be float or torch.Tensor.
-      module.momentum = 0. * module.momentum_backup
+      if torch.is_tensor(module.momentum_backup):
+        module.momentum = torch.zeros_like(module.momentum_backup)
+      else:
+        module.momentum = 0.0
     elif hasattr(module, 'momentum_backup'):
       module.momentum = module.momentum_backup
-    module.track_running_stats = update_batch_norm
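Note: the reworked helper freezes running statistics by zeroing momentum
instead of calling module.eval(), and it now preserves whether momentum is a
float or a torch.Tensor (per the comment in the diff, both occur among the
wrapped BN layers). A minimal standalone sketch, not part of this commit,
showing that a zero momentum leaves a BatchNorm layer's running stats
untouched:

# Standalone sketch (not part of this commit): with momentum = 0, PyTorch
# updates running stats as (1 - 0) * running + 0 * batch, i.e. not at all,
# while training-mode normalization still uses batch statistics.
import torch

bn = torch.nn.BatchNorm1d(4)
before = bn.running_mean.clone()

# Back up and zero the momentum, preserving float vs. tensor type, in the
# same way as update_batch_norm_fn above.
if not hasattr(bn, 'momentum_backup'):
  bn.momentum_backup = bn.momentum
if torch.is_tensor(bn.momentum_backup):
  bn.momentum = torch.zeros_like(bn.momentum_backup)
else:
  bn.momentum = 0.0

bn.train()
bn(torch.randn(8, 4))  # training-mode forward pass

assert torch.allclose(bn.running_mean, before)  # running stats are frozen

bn.momentum = bn.momentum_backup  # restore the original momentum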
algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py

@@ -40,7 +40,7 @@ class ConformerConfig:
   time_masks_per_frame: float = 0.0
   use_dynamic_time_mask_max_frames: bool = True
   input_dropout_rate: float = 0.1
-  batch_norm_momentum: float = 0.999
+  batch_norm_momentum: float = 1 - 0.999
   batch_norm_epsilon: float = 0.001
   use_specaug: bool = True
   attention_temperature: float = 1.0
@@ -369,10 +369,11 @@ def forward(self, inputs, input_paddings):
       mean = (masked_inp).sum(dim=(0, 1)) / count
       var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count
 
-      self.running_mean = self.momentum * self.running_mean + (
-          1 - self.momentum) * mean.detach()
-      self.running_var = self.momentum * self.running_var + (
-          1 - self.momentum) * var.detach()
+      self.running_mean = (1 - self.momentum) * self.running_mean + (
+          self.momentum) * mean.detach()
+      self.running_var = (1 - self.momentum) * self.running_var + (
+          self.momentum) * var.detach()
+
     else:
       mean = self.running_mean
       var = self.running_var
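Note: the two hunks above are one logical change. The old forward pass
treated momentum as the weight on the previous running value (hence a
default of 0.999), while the new code follows the PyTorch convention in
which momentum weights the incoming batch statistic, so the default flips
to 1 - 0.999. A standalone sketch, not from this commit, checking that the
two parameterizations compute the same exponential moving average:

# Standalone sketch: decay-style (old) and PyTorch-style (new) momentum
# produce identical running statistics when decay == 1 - momentum.
import torch

torch.manual_seed(0)
batch_means = [torch.randn(3) for _ in range(5)]

decay = 0.999  # old convention: weight on the running value
old_style = torch.zeros(3)
for m in batch_means:
  old_style = decay * old_style + (1 - decay) * m

momentum = 1 - 0.999  # new convention: weight on the batch statistic
new_style = torch.zeros(3)
for m in batch_means:
  new_style = (1 - momentum) * new_style + momentum * m

assert torch.allclose(old_style, new_style)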
algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py

@@ -36,7 +36,7 @@ class DeepspeechConfig:
   time_mask_max_ratio: float = 0.05
   time_masks_per_frame: float = 0.0
   use_dynamic_time_mask_max_frames: bool = True
-  batch_norm_momentum: float = 0.999
+  batch_norm_momentum: float = 1 - 0.999
   batch_norm_epsilon: float = 0.001
   # If None, defaults to 0.1.
   input_dropout_rate: Optional[float] = 0.1
@@ -264,10 +264,10 @@ def forward(self, inputs, input_paddings):
       sum_ = dist_nn.all_reduce(sum_)
       var = sum_ / count
 
-      self.running_mean = self.momentum * self.running_mean + (
-          1 - self.momentum) * mean.detach()
-      self.running_var = self.momentum * self.running_var + (
-          1 - self.momentum) * var.detach()
+      self.running_mean = (1 - self.momentum) * self.running_mean + (
+          self.momentum) * mean.detach()
+      self.running_var = (1 - self.momentum) * self.running_var + (
+          self.momentum) * var.detach()
     else:
       mean = self.running_mean
       var = self.running_var
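Note: in the DeepSpeech hunk the per-device sums are all-reduced before the
mean and variance are formed, so every worker applies the same EMA update
to its running statistics. A hedged standalone sketch of that pattern (the
function name and shapes are illustrative, not the workload's actual code),
assuming a torch.distributed process group has already been initialized:

# Standalone sketch of the synchronized-statistics pattern shown above.
# Falls back to local statistics when torch.distributed is not initialized.
import torch
import torch.distributed as dist
import torch.distributed.nn as dist_nn  # differentiable all_reduce

def global_mean_var(x):
  """x: local activations of shape (batch, features)."""
  count = torch.tensor([float(x.shape[0])], device=x.device)
  sum_ = x.sum(dim=0)
  if dist.is_available() and dist.is_initialized():
    count = dist_nn.all_reduce(count)  # total example count across devices
    sum_ = dist_nn.all_reduce(sum_)    # global feature-wise sum
  mean = sum_ / count
  sq_sum = torch.square(x - mean).sum(dim=0)
  if dist.is_available() and dist.is_initialized():
    sq_sum = dist_nn.all_reduce(sq_sum)
  var = sq_sum / count  # global (biased) feature-wise variance
  return mean, var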
