Commit

propagate changes from k2-fsa#525 to other librispeech recipes
marcoyang1998 committed Aug 17, 2022
1 parent 6694018 commit 4c6a2b7
Showing 8 changed files with 180 additions and 6 deletions.
6 changes: 4 additions & 2 deletions egs/librispeech/ASR/pruned_transducer_stateless/model.py
@@ -66,6 +66,7 @@ def forward(
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
reduction: str = "sum",
) -> torch.Tensor:
"""
Args:
@@ -95,6 +96,7 @@ def forward(
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert reduction in ("sum", "none"), reduction
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
@@ -136,7 +138,7 @@ def forward(
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
reduction=reduction,
return_grad=True,
)

@@ -163,7 +165,7 @@ def forward(
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
reduction=reduction,
)

return (simple_loss, pruned_loss)
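Note on the new `reduction` argument: it follows the usual PyTorch loss convention, so "none" yields a 1-D tensor with one loss value per utterance, and "sum" yields the scalar you would get by summing that tensor. A minimal sketch of that relationship in plain PyTorch (the `toy_transducer_loss` helper is a hypothetical stand-in for the k2 loss calls above, not the recipe's actual loss):

```python
import torch


def toy_transducer_loss(
    per_frame_loss: torch.Tensor, reduction: str = "sum"
) -> torch.Tensor:
    """Hypothetical stand-in for the k2 losses used in model.py.

    per_frame_loss: (batch, frames) tensor of per-frame losses.
    Returns a scalar if reduction == "sum", else a (batch,) tensor
    holding one loss value per utterance.
    """
    assert reduction in ("sum", "none"), reduction
    per_utterance = per_frame_loss.sum(dim=1)  # one value per utterance
    return per_utterance.sum() if reduction == "sum" else per_utterance


per_frame = torch.rand(4, 100)
per_utt = toy_transducer_loss(per_frame, reduction="none")  # shape (4,)
total = toy_transducer_loss(per_frame, reduction="sum")     # scalar
assert torch.isclose(per_utt.sum(), total)
```

Passing `reduction="none"` through to the k2 losses is what lets the training scripts below inspect and filter individual utterances before summing.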
33 changes: 33 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless/train.py
@@ -544,14 +544,47 @@ def compute_loss(
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
reduction="none",
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
is_finite = simple_loss_is_finite & pruned_loss_is_finite
if not torch.all(is_finite):
logging.info(
"Not all losses are finite!\n"
f"simple_loss: {simple_loss}\n"
f"pruned_loss: {pruned_loss}"
)
display_and_save_batch(batch, params=params, sp=sp)
simple_loss = simple_loss[simple_loss_is_finite]
pruned_loss = pruned_loss[pruned_loss_is_finite]

# If the batch contains at least 10 utterances AND
# either all simple_loss or all pruned_loss values are inf or nan,
# we stop the training process by raising an exception
if feature.size(0) >= 10:
if torch.all(~simple_loss_is_finite) or torch.all(
~pruned_loss_is_finite
):
raise ValueError(
"There are too many utterances in this batch "
"leading to inf or nan losses."
)

simple_loss = simple_loss.sum()
pruned_loss = pruned_loss.sum()

loss = params.simple_loss_scale * simple_loss + pruned_loss

assert loss.requires_grad == is_training

info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# info["frames"] is an approximate number for two reasons:
# (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
# (2) If some utterances in the batch lead to inf/nan loss, they
# are filtered out.
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
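Read on its own, the filtering this hunk adds to `compute_loss` does three things: mask out utterances whose loss is inf/nan, log and save the offending batch, and raise only when a batch of at least 10 utterances has no finite values left for one of the loss terms. A self-contained sketch of that logic under those assumptions (`filter_and_sum_losses` is a hypothetical helper; the recipe's logging and `display_and_save_batch` call are replaced by a plain `print`):

```python
from typing import Tuple

import torch


def filter_and_sum_losses(
    simple_loss: torch.Tensor,  # (batch,) per-utterance losses, reduction="none"
    pruned_loss: torch.Tensor,  # (batch,) per-utterance losses, reduction="none"
    batch_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    simple_ok = torch.isfinite(simple_loss)
    pruned_ok = torch.isfinite(pruned_loss)
    if not torch.all(simple_ok & pruned_ok):
        # The recipe logs the losses and calls display_and_save_batch() here.
        print("Not all losses are finite!")
    # Drop the non-finite entries so they do not poison the gradient.
    simple_loss = simple_loss[simple_ok]
    pruned_loss = pruned_loss[pruned_ok]
    # Abort only when a reasonably large batch produced no finite values
    # at all for either loss term.
    if batch_size >= 10 and (not simple_ok.any() or not pruned_ok.any()):
        raise ValueError(
            "There are too many utterances in this batch "
            "leading to inf or nan losses."
        )
    return simple_loss.sum(), pruned_loss.sum()
```

Summing after filtering means a batch with a few bad utterances still contributes a gradient from the good ones, which is also why `info["frames"]` above is only an approximate count.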
32 changes: 32 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -600,7 +600,35 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
is_finite = simple_loss_is_finite & pruned_loss_is_finite
if not torch.all(is_finite):
logging.info(
"Not all losses are finite!\n"
f"simple_loss: {simple_loss}\n"
f"pruned_loss: {pruned_loss}"
)
display_and_save_batch(batch, params=params, sp=sp)
simple_loss = simple_loss[simple_loss_is_finite]
pruned_loss = pruned_loss[pruned_loss_is_finite]

# If the batch contains at least 10 utterances AND
# either all simple_loss or all pruned_loss values are inf or nan,
# we stop the training process by raising an exception
if feature.size(0) >= 10:
if torch.all(~simple_loss_is_finite) or torch.all(
~pruned_loss_is_finite
):
raise ValueError(
"There are too many utterances in this batch "
"leading to inf or nan losses."
)

simple_loss = simple_loss.sum()
pruned_loss = pruned_loss.sum()
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
@@ -620,6 +648,10 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# info["frames"] is an approximate number for two reasons:
# (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
# (2) If some utterances in the batch lead to inf/nan loss, they
# are filtered out.
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
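On the `info["frames"]` comment that recurs in these train.py files: the exact number of encoder frames after subsampling is `((lens - 1) // 2 - 1) // 2` (per the comment itself), while the metric records the cheaper `feature_lens // params.subsampling_factor`, so the count drifts by a frame or two per utterance. A quick check of the gap, assuming a subsampling factor of 4 as in these recipes:

```python
import torch

feature_lens = torch.tensor([37, 100, 523, 1000])

exact = ((feature_lens - 1) // 2 - 1) // 2  # subsampled lengths per the comment
approx = feature_lens // 4                  # what info["frames"] records

print(exact.tolist())   # [8, 24, 130, 249]
print(approx.tolist())  # [9, 25, 130, 250]
```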
10 changes: 8 additions & 2 deletions egs/librispeech/ASR/pruned_transducer_stateless3/model.py
@@ -105,6 +105,7 @@ def forward(
am_scale: float = 0.0,
lm_scale: float = 0.0,
warmup: float = 1.0,
reduction: str = "sum",
) -> torch.Tensor:
"""
Args:
@@ -131,6 +132,10 @@ def forward(
warmup:
A value warmup >= 0 that determines which modules are active, values
warmup > 1 "are fully warmed up" and all modules will be active.
reduction:
"sum" to sum the losses over all utterances in the batch.
"none" to return the loss in a 1-D tensor for each utterance
in the batch.
Returns:
Return the transducer loss.
@@ -140,6 +145,7 @@ def forward(
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert reduction in ("sum", "none"), reduction
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
@@ -196,7 +202,7 @@ def forward(
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
reduction=reduction,
return_grad=True,
)

@@ -229,7 +235,7 @@ def forward(
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
reduction=reduction,
)

return (simple_loss, pruned_loss)
32 changes: 32 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless3/train.py
@@ -637,7 +637,35 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
is_finite = simple_loss_is_finite & pruned_loss_is_finite
if not torch.all(is_finite):
logging.info(
"Not all losses are finite!\n"
f"simple_loss: {simple_loss}\n"
f"pruned_loss: {pruned_loss}"
)
display_and_save_batch(batch, params=params, sp=sp)
simple_loss = simple_loss[simple_loss_is_finite]
pruned_loss = pruned_loss[pruned_loss_is_finite]

# If the batch contains at least 10 utterances AND
# either all simple_loss or all pruned_loss values are inf or nan,
# we stop the training process by raising an exception
if feature.size(0) >= 10:
if torch.all(~simple_loss_is_finite) or torch.all(
~pruned_loss_is_finite
):
raise ValueError(
"There are too many utterances in this batch "
"leading to inf or nan losses."
)

simple_loss = simple_loss.sum()
pruned_loss = pruned_loss.sum()
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
@@ -657,6 +685,10 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# info["frames"] is an approximate number for two reasons:
# (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
# (2) If some utterances in the batch lead to inf/nan loss, they
# are filtered out.
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
32 changes: 32 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless4/train.py
@@ -630,7 +630,35 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
is_finite = simple_loss_is_finite & pruned_loss_is_finite
if not torch.all(is_finite):
logging.info(
"Not all losses are finite!\n"
f"simple_loss: {simple_loss}\n"
f"pruned_loss: {pruned_loss}"
)
display_and_save_batch(batch, params=params, sp=sp)
simple_loss = simple_loss[simple_loss_is_finite]
pruned_loss = pruned_loss[pruned_loss_is_finite]

# If the batch contains at least 10 utterances AND
# either all simple_loss or all pruned_loss values are inf or nan,
# we stop the training process by raising an exception
if feature.size(0) >= 10:
if torch.all(~simple_loss_is_finite) or torch.all(
~pruned_loss_is_finite
):
raise ValueError(
"There are too many utterances in this batch "
"leading to inf or nan losses."
)

simple_loss = simple_loss.sum()
pruned_loss = pruned_loss.sum()
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
@@ -650,6 +678,10 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# info["frames"] is an approximate number for two reasons:
# (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
# (2) If some utterances in the batch lead to inf/nan loss, they
# are filtered out.
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
10 changes: 8 additions & 2 deletions egs/librispeech/ASR/pruned_transducer_stateless6/model.py
@@ -89,6 +89,7 @@ def forward(
am_scale: float = 0.0,
lm_scale: float = 0.0,
warmup: float = 1.0,
reduction: str = "sum",
codebook_indexes: torch.Tensor = None,
) -> torch.Tensor:
"""
@@ -113,6 +114,10 @@ def forward(
warmup:
A value warmup >= 0 that determines which modules are active, values
warmup > 1 "are fully warmed up" and all modules will be active.
reduction:
"sum" to sum the losses over all utterances in the batch.
"none" to return the loss in a 1-D tensor for each utterance
in the batch.
codebook_indexes:
codebook_indexes extracted from a teacher model.
Returns:
@@ -124,6 +129,7 @@ def forward(
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert reduction in ("sum", "none"), reduction
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
@@ -184,7 +190,7 @@ def forward(
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
reduction=reduction,
return_grad=True,
)

@@ -217,7 +223,7 @@ def forward(
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
reduction=reduction,
)

return (simple_loss, pruned_loss, codebook_loss)
31 changes: 31 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless6/train.py
@@ -631,8 +631,35 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
codebook_indexes=codebook_indexes,
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
is_finite = simple_loss_is_finite & pruned_loss_is_finite
if not torch.all(is_finite):
logging.info(
"Not all losses are finite!\n"
f"simple_loss: {simple_loss}\n"
f"pruned_loss: {pruned_loss}"
)
display_and_save_batch(batch, params=params, sp=sp)
simple_loss = simple_loss[simple_loss_is_finite]
pruned_loss = pruned_loss[pruned_loss_is_finite]
# If the batch contains at least 10 utterances AND
# either all simple_loss or all pruned_loss values are inf or nan,
# we stop the training process by raising an exception
if feature.size(0) >= 10:
if torch.all(~simple_loss_is_finite) or torch.all(
~pruned_loss_is_finite
):
raise ValueError(
"There are too many utterances in this batch "
"leading to inf or nan losses."
)

simple_loss = simple_loss.sum()
pruned_loss = pruned_loss.sum()
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
@@ -654,6 +681,10 @@ def compute_loss(

with warnings.catch_warnings():
warnings.simplefilter("ignore")
# info["frames"] is an approximate number for two reasons:
# (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
# (2) If some utterances in the batch lead to inf/nan loss, they
# are filtered out.
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
