🐛 Bug

Deterministic mode is not set on all workers when the Trainer is set to deterministic=True.
To Reproduce
The script is divided into two parts, test.py and model.py, to show that torch.backends.cudnn.deterministic is set on every worker and that its initial value is False.
#############test.py#############
import torch
import pytorch_lightning as pl
from pytorch_lightning.strategies import Strategy
from pytorch_lightning import LightningModule, Trainer
from ray_lightning import RayStrategy
from model import BoringModel


def get_trainer(dir,
                strategy: Strategy,
                gpus=None,
                max_epochs: int = 1,
                limit_train_batches: int = 10,
                limit_val_batches: int = 10,
                **trainer_kwargs) -> Trainer:
    """Returns a Pytorch Lightning Trainer with the provided arguments."""
    trainer = pl.Trainer(
        default_root_dir=dir,
        gpus=gpus,
        strategy=strategy,
        max_epochs=max_epochs,
        deterministic=True,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        enable_progress_bar=False,
        **trainer_kwargs)
    return trainer


def train_test(trainer: Trainer, model: LightningModule):
    """Checks if training the provided model updates its weights."""
    initial_values = torch.tensor(
        [torch.sum(torch.abs(x)) for x in model.parameters()])
    trainer.fit(model)
    post_train_values = torch.tensor(
        [torch.sum(torch.abs(x)) for x in model.parameters()])

    assert trainer.state.finished, f"Trainer failed with {trainer.state}"
    # Check that the model is actually changed post-training.
    assert torch.norm(initial_values - post_train_values) > 0.1


def test_ray_train(tmpdir, num_workers):
    """Tests if training modifies model weights."""
    model = BoringModel()
    strategy = RayStrategy(num_workers=num_workers, use_gpu=True)
    trainer = get_trainer(tmpdir, strategy=strategy)
    train_test(trainer, model)


if __name__ == '__main__':
    test_ray_train("test", 1)
#############model.py#############
import torch
from torch.utils.data import Dataset
from pytorch_lightning import LightningModule

print("Deterministic:", torch.backends.cudnn.deterministic)
torch.backends.cudnn.deterministic = True
print("Deterministic:", torch.backends.cudnn.deterministic)


class RandomDataset(Dataset):
    def __init__(self, size: int, length: int):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index: int):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.val_epoch = 0

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # Arbitrary loss to have a loss that updates the model weights
        # during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction,
                                            torch.ones_like(prediction))

    def step(self, x):
        x = self(x)
        out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
        return out

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

    def train_dataloader(self):
        return torch.utils.data.DataLoader(RandomDataset(32, 64))
This leads to the following output
MODEL Deterministic: False
MODEL Deterministic: True
2022-09-15 09:39:20,204 INFO worker.py:1518 -- Started a local Ray instance.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
(RayExecutor pid=1819295) MODEL Deterministic: False
(RayExecutor pid=1819295) MODEL Deterministic: True
(RayExecutor pid=1819295) /scratch/markus.spanring/conda/envs/storch/lib/python3.9/site-packages/pytorch_lightning/utilities/warnings.py:53: LightningDeprecationWarning: pytorch_lightning.utilities.warnings.rank_zero_deprecation has been deprecated in v1.6 and will be removed in v1.8. Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead.
(RayExecutor pid=1819295) new_rank_zero_deprecation(
(RayExecutor pid=1819295) /scratch/markus.spanring/conda/envs/storch/lib/python3.9/site-packages/pytorch_lightning/utilities/warnings.py:58: LightningDeprecationWarning: ParallelStrategy.torch_distributed_backend was deprecated in v1.6 and will be removed in v1.8.
(RayExecutor pid=1819295) return new_rank_zero_deprecation(*args, **kwargs)
(RayExecutor pid=1819295) Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
(RayExecutor pid=1819295) ----------------------------------------------------------------------------------------------------
(RayExecutor pid=1819295) distributed_backend=nccl
(RayExecutor pid=1819295) All distributed processes registered. Starting with 1 processes
(RayExecutor pid=1819295) ----------------------------------------------------------------------------------------------------
(RayExecutor pid=1819295)
(RayExecutor pid=1819295) GPU available: True (cuda), used: True (Please ignore the previous info [GPU used: False]).
(RayExecutor pid=1819295) LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
(RayExecutor pid=1819295)
(RayExecutor pid=1819295) | Name | Type | Params
(RayExecutor pid=1819295) ---------------------------------
(RayExecutor pid=1819295) 0 | layer | Linear | 66
(RayExecutor pid=1819295) ---------------------------------
(RayExecutor pid=1819295) 66 Trainable params
(RayExecutor pid=1819295) 0 Non-trainable params
(RayExecutor pid=1819295) 66 Total params
(RayExecutor pid=1819295) 0.000 Total estimated model params size (MB)
From this output one can see that model.py is imported twice, once in the driver and once in the RayExecutor worker, and that torch.backends.cudnn.deterministic is always False at the beginning.
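The same effect can be shown without Lightning at all. The following minimal sketch (an illustration only, assuming just that ray and torch are installed; it is not ray_lightning code) checks the flag inside a Ray worker after setting it in the driver:

# Sketch: a module-level flag set in the driver does not propagate to a
# Ray worker process, which imports torch with its default state.
import ray
import torch

torch.backends.cudnn.deterministic = True  # set only in the driver process

@ray.remote
def cudnn_deterministic_in_worker():
    import torch  # the worker process has its own torch module state
    return torch.backends.cudnn.deterministic

if __name__ == "__main__":
    ray.init()
    print("driver:", torch.backends.cudnn.deterministic)               # True
    print("worker:", ray.get(cudnn_deterministic_in_worker.remote()))  # False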
Expected behavior
Deterministic mode is set on all workers when running in distributed mode.
I know that PTL uses torch.use_deterministic_algorithms to set the deterministic mode and that this does not set torch.backends.cudnn.deterministic. I stumbled over this behavior when I tried to compare the checkpoints of two DCGANs (lots of non-deterministic layers) that should have been identical. So I am positive that deterministic mode is not set on the remote worker.
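For reference, the two switches are independent, which a plain torch session shows (a minimal sketch, no Lightning or Ray involved):

# torch.use_deterministic_algorithms toggles its own global switch and
# leaves the cuDNN-specific flag untouched.
import torch

print(torch.backends.cudnn.deterministic)            # False (default)
torch.use_deterministic_algorithms(True)
print(torch.are_deterministic_algorithms_enabled())  # True
print(torch.backends.cudnn.deterministic)            # still False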
As a first workaround I have added it here.

I am not sure if this is enough or if the accelerator needs to be initialized properly on each worker.
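In the meantime, one possible user-side mitigation (a sketch only, not the change referenced above; DeterministicBoringModel is a hypothetical subclass of the BoringModel from model.py) is to set the flags in a hook that runs inside every worker process, such as LightningModule.setup():

import torch
from model import BoringModel  # BoringModel from the reproduction script above

class DeterministicBoringModel(BoringModel):
    def setup(self, stage=None):
        # setup() is executed in each (remote) worker process after it is
        # spawned, so the cuDNN flags are set where training actually runs.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False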
Environment