Filter non-finite losses #525

Merged
merged 2 commits on Aug 17, 2022
10 changes: 8 additions & 2 deletions egs/librispeech/ASR/pruned_transducer_stateless2/model.py
@@ -78,6 +78,7 @@ def forward(
am_scale: float = 0.0,
lm_scale: float = 0.0,
warmup: float = 1.0,
reduction: str = "sum",
) -> torch.Tensor:
"""
Args:
@@ -101,6 +102,10 @@ def forward(
warmup:
A value warmup >= 0 that determines which modules are active, values
warmup > 1 "are fully warmed up" and all modules will be active.
reduction:
"sum" to sum the losses over all utterances in the batch.
"none" to return the loss in a 1-D tensor for each utterance
in the batch.
Returns:
Return the transducer loss.

@@ -110,6 +115,7 @@ def forward(
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert reduction in ("sum", "none"), reduction
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
@@ -155,7 +161,7 @@ def forward(
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
reduction=reduction,
return_grad=True,
)

@@ -188,7 +194,7 @@ def forward(
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
reduction=reduction,
)

return (simple_loss, pruned_loss)
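With reduction exposed on the model, the caller chooses between a single summed scalar (the previous behaviour) and one loss value per utterance. A minimal usage sketch, assuming model, feature, feature_lens and y are whatever the training loop already provides, with the remaining keyword arguments left at their defaults:

# reduction="none" returns per-utterance losses of shape (N,);
# reduction="sum" (the default) returns summed scalars as before.
simple_loss, pruned_loss = model(
    x=feature,
    x_lens=feature_lens,
    y=y,
    warmup=1.0,
    reduction="none",
)
# Each entry can now be inspected individually, e.g. to drop
# utterances whose loss is inf or nan before summing.
simple_loss = simple_loss[torch.isfinite(simple_loss)].sum()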
32 changes: 32 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless5/train.py
@@ -655,7 +655,35 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
is_finite = simple_loss_is_finite & pruned_loss_is_finite
if not torch.all(is_finite):
logging.info(
"Not all losses are finite!\n"
f"simple_loss: {simple_loss}\n"
f"pruned_loss: {pruned_loss}"
)
display_and_save_batch(batch, params=params, sp=sp)
simple_loss = simple_loss[simple_loss_is_finite]
Collaborator:
Just a heads up: I was facing the same logical indexing issue here that we had seen earlier, so I had to replace this with torch.where(). I suppose this might just be a CUDA 11.1 issue after all.
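For reference, a masking-free variant along those lines might look like the sketch below. It is not necessarily the exact replacement used, and it relies on the losses being reduced with .sum() afterwards, so zeroing the bad entries is equivalent to dropping them:

# Keep the tensor shape fixed and zero out non-finite entries instead
# of indexing with a boolean mask; the zeros vanish in the later .sum().
simple_loss = torch.where(
    simple_loss_is_finite, simple_loss, torch.zeros_like(simple_loss)
)
pruned_loss = torch.where(
    pruned_loss_is_finite, pruned_loss, torch.zeros_like(pruned_loss)
)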

pruned_loss = pruned_loss[pruned_loss_is_finite]

# If the batch contains at least 10 utterances AND
# all of the simple_loss values or all of the pruned_loss values
# are inf or nan, we stop the training process by raising an exception
if feature.size(0) >= 10:
if torch.all(~simple_loss_is_finite) or torch.all(
~pruned_loss_is_finite
):
raise ValueError(
"There are too many utterances in this batch "
"leading to inf or nan losses."
)

simple_loss = simple_loss.sum()
pruned_loss = pruned_loss.sum()
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
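The lines elided here combine the two losses. Purely as an illustration of the scale schedule the comment above describes (the exact thresholds and variable names are assumptions, not part of this diff):

# Down-weight pruned_loss until the model is warmed up, then ramp it
# to full weight; the thresholds below are illustrative.
pruned_loss_scale = 0.0 if warmup < 1.0 else (0.1 if warmup < 2.0 else 1.0)
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss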
@@ -675,6 +703,10 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# info["frames"] is an approximate number for two reasons:
# (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
# (2) If some utterances in the batch lead to inf/nan loss, they
# are filtered out.
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
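As a quick check on point (1), the shortcut feature_lens // params.subsampling_factor slightly over-counts frames for some lengths compared with the exact formula (assuming the usual subsampling factor of 4 in these recipes):

import torch

lens = torch.tensor([100, 101, 102, 103])
approx = lens // 4                    # tensor([25, 25, 25, 25])
exact = ((lens - 1) // 2 - 1) // 2    # tensor([24, 24, 24, 25])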