Fix loading sampler state dict.
csukuangfj committed Jun 14, 2022
1 parent 53f38c0 · commit fd7646d
Showing 4 changed files with 0 additions and 44 deletions.
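All four files receive the same change: the manual cur_batch_idx bookkeeping is dropped, because the checkpoints already save and restore the data sampler's own state (note the sampler=train_dl.sampler argument kept in each diff). Once the sampler state dict is loaded, iteration resumes at the first unseen batch on its own, so the removed `if batch_idx < cur_batch_idx: continue` guard would skip batches a second time. Below is a minimal sketch of the resume path this relies on; it is not the verbatim icefall code, and it assumes a lhotse-style sampler exposing state_dict()/load_state_dict() (names such as checkpoint_path and save_training_state are illustrative):

    import torch

    def save_training_state(checkpoint_path, model, sampler):
        # The sampler's state dict records how far the epoch has
        # advanced, so no separate cur_batch_idx counter is needed.
        torch.save(
            {"model": model.state_dict(), "sampler": sampler.state_dict()},
            checkpoint_path,
        )

    def resume_training_state(checkpoint_path, model, sampler):
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
        # Restoring the sampler fast-forwards the data stream itself;
        # the training loop must NOT additionally skip batches by index.
        sampler.load_state_dict(checkpoint["sampler"])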
egs/librispeech/ASR/pruned_transducer_stateless/train.py (0 additions, 11 deletions)
@@ -402,9 +402,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]

if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]

return saved_params


@@ -610,13 +607,7 @@ def maybe_log_weights(tag: str):
                 global_step=params.batch_idx_train,
             )
 
-    cur_batch_idx = params.get("cur_batch_idx", 0)
-
     for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx
-
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

@@ -664,7 +655,6 @@ def maybe_log_weights(tag: str):
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
-            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
@@ -674,7 +664,6 @@ def maybe_log_weights(tag: str):
                 sampler=train_dl.sampler,
                 rank=rank,
             )
-            del params.cur_batch_idx
             remove_checkpoints(
                 out_dir=params.exp_dir,
                 topk=params.keep_last_k,
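The save_checkpoint_with_global_batch_idx call kept in the diff still receives sampler=train_dl.sampler. A plausible sketch of how such a helper folds the sampler into the checkpoint file; this is a guess at the shape of the helper, not the real icefall implementation, which takes more arguments (model, optimizer, scaler, and so on):

    from pathlib import Path
    from typing import Any, Optional

    import torch

    def save_checkpoint_with_global_batch_idx_sketch(
        out_dir: Path,
        global_batch_idx: int,
        model: torch.nn.Module,
        sampler: Optional[Any] = None,
        rank: int = 0,
    ) -> None:
        if rank != 0:
            return  # only one process writes the checkpoint
        checkpoint = {
            "model": model.state_dict(),
            # Persist the sampler position next to the weights so a
            # later load can fast-forward the data stream automatically.
            "sampler": sampler.state_dict() if sampler is not None else None,
        }
        torch.save(checkpoint, out_dir / f"checkpoint-{global_batch_idx}.pt")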
egs/librispeech/ASR/pruned_transducer_stateless2/train.py (0 additions, 11 deletions)
@@ -449,9 +449,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]

if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]

return saved_params


@@ -661,13 +658,7 @@ def train_one_epoch(

     tot_loss = MetricsTracker()
 
-    cur_batch_idx = params.get("cur_batch_idx", 0)
-
     for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx
-
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

@@ -702,7 +693,6 @@ def train_one_epoch(
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
-            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
@@ -714,7 +704,6 @@ def train_one_epoch(
                 scaler=scaler,
                 rank=rank,
             )
-            del params.cur_batch_idx
             remove_checkpoints(
                 out_dir=params.exp_dir,
                 topk=params.keep_last_k,
egs/librispeech/ASR/pruned_transducer_stateless4/train.py (0 additions, 11 deletions)
@@ -471,9 +471,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]

if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]

return saved_params


@@ -694,13 +691,7 @@ def train_one_epoch(

     tot_loss = MetricsTracker()
 
-    cur_batch_idx = params.get("cur_batch_idx", 0)
-
     for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx
-
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

@@ -742,7 +733,6 @@ def train_one_epoch(
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
-            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
@@ -755,7 +745,6 @@ def train_one_epoch(
                 scaler=scaler,
                 rank=rank,
             )
-            del params.cur_batch_idx
             remove_checkpoints(
                 out_dir=params.exp_dir,
                 topk=params.keep_last_k,
egs/librispeech/ASR/pruned_transducer_stateless5/train.py (0 additions, 11 deletions)
@@ -512,9 +512,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]

if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]

return saved_params


@@ -735,13 +732,7 @@ def train_one_epoch(

     tot_loss = MetricsTracker()
 
-    cur_batch_idx = params.get("cur_batch_idx", 0)
-
     for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx
-
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

@@ -787,7 +778,6 @@ def train_one_epoch(
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
-            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
@@ -800,7 +790,6 @@ def train_one_epoch(
                 scaler=scaler,
                 rank=rank,
             )
-            del params.cur_batch_idx
             remove_checkpoints(
                 out_dir=params.exp_dir,
                 topk=params.keep_last_k,
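Why the removed guard was harmful once the sampler state is restored: enumerate(train_dl) numbers the resumed stream from zero, so comparing it against a stored cur_batch_idx discards a second run of batches. A self-contained toy that reproduces the effect (plain Python, no icefall or lhotse imports; ToySampler is a stand-in for a resumable sampler, not a real class):

    class ToySampler:
        """Stands in for a sampler whose state dict remembers how many
        batches were already yielded."""

        def __init__(self, num_batches):
            self.num_batches = num_batches
            self.consumed = 0

        def state_dict(self):
            return {"consumed": self.consumed}

        def load_state_dict(self, state):
            self.consumed = state["consumed"]

        def __iter__(self):
            for i in range(self.consumed, self.num_batches):
                self.consumed = i + 1
                yield i

    sampler = ToySampler(num_batches=300)
    for batch in sampler:
        if batch == 99:
            saved = sampler.state_dict()  # checkpoint after 100 batches
            break

    resumed = ToySampler(num_batches=300)
    resumed.load_state_dict(saved)  # sampler now starts at batch 100
    cur_batch_idx = 100  # what the removed code stored in the checkpoint
    seen = [
        b for batch_idx, b in enumerate(resumed) if batch_idx >= cur_batch_idx
    ]
    print(seen[0])  # 200, not 100: batches 100..199 were silently dropped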
