setting grad None after training to avoid memory leak #1

Open · wants to merge 1 commit into main

Conversation

kaloeffler

Hi,

During training I've observed a memory leak when training several models one after another. I've used your example code from the README to build a minimal reproduction, see below. I'm using Python 3.7 and PyTorch 1.13.1. Per training run I observe an increase of about 10 MiB in memory usage. This becomes a serious issue when training larger models, eventually resulting in an out-of-memory error.

The underlying leak is discussed in pytorch/pytorch#82528. Your code already has a begin method in the Optimizable class which sets param.grad to None. However, param.grad also needs to be set to None at the end of training a model to avoid leaking memory across several training runs. Hence, I suggest adding an end method that is called once at the end of each training run; a rough sketch of such a hook is shown after the reproduction script below.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MNIST_FullyConnected(nn.Module):
    """
    A fully-connected NN for the MNIST task. This is Optimizable but not itself
    an optimizer.
    """
    def __init__(self, num_inp, num_hid, num_out):
        super(MNIST_FullyConnected, self).__init__()
        self.layer1 = nn.Linear(num_inp, num_hid)
        self.layer2 = nn.Linear(num_hid, num_out)

    def initialize(self):
        nn.init.kaiming_uniform_(self.layer1.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.layer2.weight, a=math.sqrt(5))

    def forward(self, x):
        """Compute a prediction."""
        x = self.layer1(x)
        x = torch.tanh(x)
        x = self.layer2(x)
        x = torch.tanh(x)
        x = F.log_softmax(x, dim=1)
        return x

def train_model(model_wrapper, train_data):
    for i in range(1, EPOCHS + 1):
        running_loss = 0.0
        for j, (features_, labels_) in enumerate(train_data):
            model_wrapper.begin()  # call this before each step, enables gradient tracking on desired params
            features = torch.reshape(features_, (-1, 28 * 28)).to(DEVICE)
            labels = labels_.to(DEVICE)
            pred = model_wrapper.forward(features)
            loss = F.nll_loss(pred, labels)
            model_wrapper.zero_grad()
            loss.backward(create_graph=True)  # important! use create_graph=True
            model_wrapper.step()
            running_loss += loss.item() * features_.size(0)
        train_loss = running_loss / len(train_data.dataset)
        print("EPOCH: {}, TRAIN LOSS: {}".format(i, train_loss))

if __name__ == "__main__":
    import torchvision
    from gradient_descent_the_ultimate_optimizer import gdtuo

    BATCH_SIZE = 256
    EPOCHS = 2
    N_RUNS = 50
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    for i_run in range(N_RUNS):
        print(f"Run: {i_run}")
        mnist_train = torchvision.datasets.MNIST(
            "./data",
            train=True,
            download=True,
            transform=torchvision.transforms.ToTensor(),
        )
        dl_train = torch.utils.data.DataLoader(
            mnist_train, batch_size=BATCH_SIZE, shuffle=True
        )
        model = MNIST_FullyConnected(28 * 28, 128, 10).to(DEVICE)
        optim = gdtuo.Adam(optimizer=gdtuo.SGD(1e-5))

        mw = gdtuo.ModuleWrapper(model, optimizer=optim)
        mw.initialize()
        train_model(mw, dl_train)
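
For illustration, the proposed hook could mirror what begin already does, just without re-arming gradient tracking. This is only a sketch of the idea; the attribute name all_params_with_gradients is assumed here and may differ from the actual change:

# sketch of the proposed Optimizable.end() hook (illustrative only, not the exact diff)
def end(self):
    """Call once after a training run: set .grad back to None so the backward
    graphs retained by create_graph=True can be garbage collected."""
    for param in self.all_params_with_gradients:
        param.grad = None
    self.all_params_with_gradients.clear()

# intended usage, after the epoch loop in train_model():
#     model_wrapper.end()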

@kach (Owner) commented Jan 16, 2023

Hi there, thank you for writing! I'm not sure I fully understand where this leak comes from, but judging from your proposed change, you could equivalently just call mw.begin() at the end of each training loop. Could you try that and see if it works?

(We would like to keep this repo as a frozen archive of the code from the paper, and will not be updating it unless there are severe issues without plausible workarounds.)

Thanks!

@kaloeffler (Author)

Calling mw.begin() after a training run also solves the issue, since it sets param.grad = None. However, it is unintuitive to call a method named "begin" at the end of a training run, and the method does more than what is needed to reset the gradients so that gc can do its magic.

That said, I understand that you would like to avoid changing the code. I would still encourage you to add a hint to the README about calling the begin method at the end of a training run when training several models. This would help future users apply your code quickly and successfully, without wondering where the out-of-memory error comes from.
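
For example, the hint could boil down to a single extra line after the training loop; using the names from the reproduction script above:

mw = gdtuo.ModuleWrapper(model, optimizer=optim)
mw.initialize()
train_model(mw, dl_train)
mw.begin()  # workaround: resets param.grad to None so the retained graphs can be freed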

@LeonieTabea

Hi everyone,

Thank you, Kartik Chandra, for creating this great repository.

I have been using this code for a simple FNN and it works great; there seems to be no memory leak. However, when switching to a simple RNN, memory increases drastically after each training batch, and at some point I get an out-of-memory error.

As described above, this issue is caused by calling loss.backward(create_graph=True). I have tried both solutions described here: 1) setting param.grad = None at the end of training by calling the end() method proposed by kaloeffler, and 2) calling mw.begin() at the end of each training loop.

Neither solution worked in my case. It also seems that the large memory increase is caused by retaining the graph, since retain_graph is automatically set to True whenever create_graph=True is used.

Since I only see this issue when training an RNN model, and an RNN was trained with the ultimate optimizer in your paper, I was wondering whether you ran into something similar and found another fix for it.
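
In case it is related: in plain PyTorch, RNN memory typically grows across batches when the hidden state is carried over without detaching it, because the autograd graph then chains through all previous batches. Below is a minimal sketch of the detach pattern; it is not specific to gdtuo and all names are illustrative:

import torch
import torch.nn as nn

rnn = nn.RNN(input_size=8, hidden_size=16, batch_first=True)
opt = torch.optim.SGD(rnn.parameters(), lr=1e-3)
hidden = None
for _ in range(100):
    x = torch.randn(4, 10, 8)  # dummy batch: (batch, seq, features)
    out, hidden = rnn(x, hidden)
    loss = out.pow(2).mean()  # dummy loss
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()
    hidden = hidden.detach()  # break the graph link to previous batches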

Thank you!
