When validation is used together with training (the default with the boring model), sharded training crashes. Internally, SDP (ShardedDataParallel) relies on knowing the training state of the model. When we run the validation sanity check, we do not set eval mode on the SDP model itself, so it stays in train mode and waits for grads to be reduced.
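The mode mismatch described above can be illustrated with plain PyTorch modules. This is a minimal sketch, not the actual sharded code path: `Wrapper` is a hypothetical stand-in for the SDP wrapper, it is assumed to behave like an ordinary `torch.nn.Module`, and the gradient reduction itself is not modeled. It only shows that putting the wrapped module in eval mode does not change the wrapper's own `training` flag.

```python
import torch


class Wrapper(torch.nn.Module):
    """Hypothetical stand-in for a data-parallel wrapper such as SDP."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        return self.module(x)


inner = torch.nn.Linear(32, 2)
wrapper = Wrapper(inner)

# Switching only the wrapped module to eval mode does not propagate upward:
# the wrapper still reports that it is training.
inner.eval()
inner_only = (inner.training, wrapper.training)

# Calling eval() on the wrapper itself propagates down to its children.
wrapper.eval()
both = (inner.training, wrapper.training)

print("after inner.eval():  ", inner_only)
print("after wrapper.eval():", both)
```

A wrapper that decides whether to reduce gradients from its own `training` flag would therefore still behave as if it were training after only the inner module was switched to eval mode.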
import os

import torch
from torch.utils.data import Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    """
    >>> RandomDataset(size=10, length=20)  # doctest: +ELLIPSIS
    <...bug_report_model.RandomDataset object at ...>
    """

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    """
    >>> BoringModel()  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    BoringModel(
      (layer): Linear(...)
    )
    """

    def __init__(self):
        """
        Testing PL Module

        Use as follows:
        - subclass
        - modify the behavior for what you want

        class TestModel(BaseTestModel):
            def training_step(...):
                # do your own thing

        or:

        model = BaseTestModel()
        model.training_epoch_end = None
        """
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def step(self, x):
        x = self.layer(x)
        out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
        return out

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"x": loss}

    def test_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"y": loss}

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]


def test_run():
    # fake data
    train_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
    val_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
    test_data = torch.utils.data.DataLoader(RandomDataset(32, 64))

    # model
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        max_epochs=1,
        plugins='ddp_sharded',
        gpus=1,
        weights_summary=None,
    )
    trainer.fit(model, train_data, val_data)
    trainer.test(test_dataloaders=test_data)


if __name__ == '__main__':
    test_run()