Sharded Training not saving memory #6047

Comments
Hey @amogkam, thanks for the issue! Your sample also made me realize there was a bug in master with passing plugins that needed to be resolved. I've run your sample and confirmed that I can replicate the results. The results are not too surprising, considering the model is really small. The model is around 230k parameters (the number looks a bit fishy to me, but it's alright for now), and Sharded won't really help here: the optimizer states are tiny, so partitioning them gives a negligible improvement. Sharded really takes effect when the model is large (roughly 100M+ parameters). I ran tests with the embed dim bumped up to around 1024, which gives roughly 202M parameters and is still smallish. On 4 A100 GPUs (40GB VRAM), here are the numbers I get:
And if I really reduce the batch size (to 4, which makes the benefits of sharded even more obvious) and increase the model size, upping the embed dim to 2048 (807M parameters):
Sharded can take a larger batch size, but I wanted a like-for-like comparison with DDP. I'd suggest increasing the size of iGPT if possible! In the meantime I'll update the docs to make it clearer that these improvements require the model to be at least in the millions of parameters, and that the benefits scale as the model size increases (to an extent). Here is my code:

import time
import torch
from pl_bolts.datamodules import MNISTDataModule
from pl_bolts.models import ImageGPT
import pytorch_lightning as pl
from pytorch_lightning import Callback
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.plugins import DDPShardedPlugin
class CUDACallback(Callback):

    def on_train_epoch_start(self, trainer, pl_module):
        # Reset the memory use counter
        torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
        torch.cuda.synchronize(trainer.root_gpu)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        torch.cuda.synchronize(trainer.root_gpu)
        max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20
        epoch_time = time.time() - self.start_time

        max_memory = torch.tensor(max_memory, dtype=torch.int, device=trainer.root_gpu)
        epoch_time = torch.tensor(epoch_time, dtype=torch.int, device=trainer.root_gpu)

        # Average peak memory and epoch time across all ranks
        torch.distributed.all_reduce(max_memory, op=torch.distributed.ReduceOp.SUM)
        torch.distributed.all_reduce(epoch_time, op=torch.distributed.ReduceOp.SUM)
        world_size = torch.distributed.get_world_size()
        print(f"Average Epoch time: {epoch_time.item() / float(world_size):.2f} seconds")
        print(f"Average Peak memory {max_memory.item() / float(world_size):.2f}MiB")


dm = MNISTDataModule('.', batch_size=4)
model = ImageGPT(embed_dim=2048, layers=16, heads=4, vocab_size=32, num_pixels=28)

ddp_plugin = DDPPlugin(find_unused_parameters=True)
sharded_plugin = DDPShardedPlugin()

trainer = pl.Trainer(
    max_epochs=1, gpus=4, accelerator='ddp', precision=16,
    callbacks=[CUDACallback()], plugins=ddp_plugin  # Swap between the plugins
)

trainer.fit(model, dm)
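As a side note, a minimal way to sanity-check the parameter counts quoted above (this is just a sketch, assuming the same ImageGPT configuration as the script):

from pl_bolts.models import ImageGPT

# Larger configuration from the second benchmark (embed_dim=2048, ~807M parameters)
model = ImageGPT(embed_dim=2048, layers=16, heads=4, vocab_size=32, num_pixels=28)

# Sharding the optimizer states only pays off once the model is in the
# hundred-million-parameter range, so it's worth confirming the count first.
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_params / 1e6:.1f}M")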
Hey @miraodasilva, I see you're using Lightning 1.1. Could you try using Lightning master?
I downgraded to 1.1 since mixed precision is broken on 1.2, as reported here: #6077. I tried Lightning master but it also seems to break my code (which works on 1.1).
@SeanNaren what cloud provider are you using for GPUs? AWS doesn't have any instances with 40 GB GPU memory, does it?
@SeanNaren any updates on this? I'm getting the same behaviour (no memory saving) on 1.2.4. Thanks in advance.
I'm training ResNet50 on ImageNet here. No memory saving compared to regular DDP training.
Circling back to this issue: if people could share their setup + model size, this would help drastically.
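For anyone reporting back, a minimal sketch of how to gather that information in one place (the helper below is illustrative, not an existing Lightning utility):

import torch
import pytorch_lightning as pl

def report_setup(model):
    # Print library versions, visible GPUs, and the model's parameter count,
    # which is the main factor in whether sharded training saves memory.
    print(f"PyTorch Lightning: {pl.__version__}")
    print(f"PyTorch: {torch.__version__}")
    if torch.cuda.is_available():
        print(f"GPUs: {torch.cuda.device_count()} x {torch.cuda.get_device_name(0)}")
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {n_params / 1e6:.1f}M")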
🐛 Bug
I am training an ImageGPT model, but I am not seeing any reduction in GPU memory usage when training with DDP sharded vs. without.
Environment info:
PTL: v1.1.8
PTL Bolts: v0.3.0
PyTorch: v1.7.1
Python: v3.7.7
CUDA: 10.2
Single AWS p3.8xlarge instance (4 Tesla V100 GPUs).
The code is very simple:
When I remove the DDPShardedPlugin, I am seeing GPU memory usage of ~13MiB. When I include the plugin, I am still seeing exactly the same GPU memory usage. I would expect the per-device memory usage with the plugin to be less.
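For clarity, the comparison being described is roughly the following (a sketch assuming the same ImageGPT/MNIST setup as the script earlier in the thread, not the exact code from the report):

import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPShardedPlugin

# Baseline run: plain DDP, no sharding
trainer = pl.Trainer(max_epochs=1, gpus=4, accelerator='ddp', precision=16)

# Sharded run: identical settings, only the plugin differs
sharded_trainer = pl.Trainer(
    max_epochs=1, gpus=4, accelerator='ddp', precision=16,
    plugins=DDPShardedPlugin(),
)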
@SeanNaren any idea on what's going on here? Thanks a lot for the help.