Skip to content

Commit

Permalink
[SPMD] input sharding for both train & test data loading in imagenet …
Browse files Browse the repository at this point in the history
…example (#6515)
  • Loading branch information
yeounoh committed Feb 10, 2024
1 parent 0fa24a1 commit a5692c2
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions test/spmd/test_train_spmd_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,13 @@ def train_imagenet():
loader_prefetch_size=FLAGS.loader_prefetch_size,
device_prefetch_size=FLAGS.device_prefetch_size,
host_to_device_transfer_threads=FLAGS.host_to_device_transfer_threads)
test_loader = pl.MpDeviceLoader(
test_loader,
device,
input_sharding=xs.ShardingSpec(input_mesh, (0, 1, 2, 3)),
loader_prefetch_size=FLAGS.loader_prefetch_size,
device_prefetch_size=FLAGS.device_prefetch_size,
host_to_device_transfer_threads=FLAGS.host_to_device_transfer_threads)

writer = None
if xm.is_master_ordinal():
Expand Down

0 comments on commit a5692c2

Please sign in to comment.