Skip to content

Commit

Permalink
Update base for Update on "[BE][6/n] replace large c4_mini datasets b…
Browse files Browse the repository at this point in the history
…y c4_test with the first 2K entries"


`c4_mini` is 100MB large and makes repo clone slow. Since we already have the original dataset `c4`, let's remove redundancy and only keey a minimal dataset for testing (even offline). For loss convergence testing, we can use the full `c4`.

`c4_test` (2K entries, <5MB size) is now put under `test/assets`, together with the test tokenizer. It can cover the first 10 iterations of debug model without repetition.

After this PR lands, we should do a history rewriting to remove `c4_mini` entirely from history, to avoid repo clone overhead.

[ghstack-poisoned]
  • Loading branch information
tianyu-l committed Aug 8, 2024
1 parent f58ca70 commit 4926495
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def main(job_config: JobConfig):
f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
)

# loss function to be shared by Pipeline Parallel and spmd training
# loss function to be shared by Pipeline Parallel and SPMD training
def loss_fn(pred, labels):
return torch.nn.functional.cross_entropy(
pred.flatten(0, 1), labels.flatten(0, 1)
Expand All @@ -150,7 +150,7 @@ def loss_fn(pred, labels):
# We need to iterate through model_parts to apply SPMD parallelisms, compilation,
# optimizer, and checkpointing
for m in model_parts:
# apply spmd-style PT-D techniques
# apply SPMD-style PT-D techniques
models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)

# In PP, we cannot call init_weights directly because some layers are missing.
Expand Down Expand Up @@ -269,7 +269,7 @@ def loss_fn(pred, labels):
optimizers.zero_grad()

if parallel_dims.pp_enabled:
# pipeline parallel forward / backward inside step() call
# Pipeline Parallel forward / backward inside step() call
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

with train_context():
Expand Down

0 comments on commit 4926495

Please sign in to comment.