Skip to content

Commit

Permalink
Fix DDP + SyncBN (Lightning-AI#6838)
Browse files Browse the repository at this point in the history
* Fix DDP + SyncBN

Ensure that the model is already on the correct GPU before applying the SyncBN conversion

* Fix order of SyncBN for ddp_spawn
  • Loading branch information
BloodAxe authored and kaushikb11 committed Apr 6, 2021
1 parent 31b2d2b commit 215a9c9
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 9 deletions.
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,12 +257,12 @@ def pre_dispatch(self):
self.dist.rank = self.global_rank
self.dist.device = self.root_device

-if self.sync_batchnorm:
-    self.model = self.configure_sync_batchnorm(self.model)
-
# move the model to the correct device
self.model_to_device()

+if self.sync_batchnorm:
+    self.model = self.configure_sync_batchnorm(self.model)
+
self.configure_ddp()

self.barrier()
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,12 @@ def new_process(self, process_idx, trainer, mp_queue):
self.dist.rank = self.global_rank
self.dist.device = self.root_device

-if self.sync_batchnorm:
-    self.model = self.configure_sync_batchnorm(self.model)
-
# move the model to the correct device
self.model_to_device()

+if self.sync_batchnorm:
+    self.model = self.configure_sync_batchnorm(self.model)
+
self.configure_ddp()

self.barrier()
Expand Down
3 changes: 0 additions & 3 deletions tests/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,8 +721,6 @@ def __len__(self):
assert has_len(dataloader)
assert has_iterable_dataset(dataloader)
trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
-with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
-    trainer.validate(model, val_dataloaders=[dataloader])
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
Expand All @@ -735,7 +733,6 @@ def __len__(self):
assert not has_len(dataloader)
assert has_iterable_dataset(dataloader)
trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
-trainer.validate(model, val_dataloaders=dataloader)
trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
trainer.test(model, test_dataloaders=dataloader)
trainer.predict(model, dataloaders=dataloader)
Expand Down

0 comments on commit 215a9c9

Please sign in to comment.