Support wrapping multiple models in Accelerator.accumulate() #1708

Merged
merged 5 commits into huggingface:main on Jul 25, 2023

Conversation

yuxinyuan
Contributor

This PR updates Accelerator.accumulate() so that users can pass multiple models into the context manager to skip gradient sync (by calling no_sync()) for all of them.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jul 12, 2023

The documentation is not available anymore as the PR was closed or merged.

Contributor

@pacman100 pacman100 left a comment

This PR assumes all models have the same number of gradient accumulation steps, which might not always be the case.

@yuxinyuan
Contributor Author

@pacman100 But can accelerate handle different gradient accumulation steps for different models without this PR?

My hope is that this PR doesn't break any existing behavior, and that when I train a GAN I can write something like with accelerator.accumulate(gen, disc):. Right now, doing with accelerator.accumulate(gen), accelerator.accumulate(disc): does not give the correct result.
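
To make it concrete, the kind of loop I want to write is roughly the sketch below (the names gen, disc, the two optimizers, the loss helpers, latent_dim, and dataloader are placeholders, and everything is assumed to have gone through accelerator.prepare()):

for real in dataloader:
    with accelerator.accumulate(gen, disc):  # skip grad sync for both models between sync steps
        noise = torch.randn(real.size(0), latent_dim, device=real.device)
        fake = gen(noise)

        # discriminator update
        d_loss = d_loss_fn(disc(real), disc(fake.detach()))
        accelerator.backward(d_loss)
        d_optimizer.step()
        d_optimizer.zero_grad()

        # generator update
        g_loss = g_loss_fn(disc(fake))
        accelerator.backward(g_loss)
        g_optimizer.step()
        g_optimizer.zero_grad()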

@muellerzr
Collaborator

While this PR is good, I recall more is needed for this to work. Let me dig through my stuff.

@muellerzr
Collaborator

The issue here is that the gradient synchronization will be off. We were waiting on feedback from the torch team on whether our solution relies on a "bug" or a feature, and since they haven't responded we'll treat it as a feature and I'll open a PR today. So sadly no, it's not as simple as you would hope.

@yuxinyuan
Contributor Author

@muellerzr I don't quite get the point about "bug or feature" (maybe I need more background), but isn't disabling gradient synchronization exactly what one needs for gradient accumulation? At least that is the case for my use case.

@muellerzr
Collaborator

muellerzr commented Jul 13, 2023

@yuxinyuan please see #1726

@yuxinyuan
Contributor Author

@muellerzr I went through #1726 and I'm a little confused by the example use case.

Right now, the example use case looks like this:

with accelerator.no_sync(ddp_model):
    loss_a = ddp_model(input_a)
    loss_b = ddp_model(input_b)
accelerator.backward(loss_a)  # No synchronization across processes, only accumulate gradients
with accelerator.trigger_sync_in_backward(ddp_model):
    accelerator.backward(loss_b)  # Synchronization across all processes

What is the difference between using trigger_sync_in_backward and the following:

loss_a = ddp_model(input_a)
loss_b = ddp_model(input_b)
accelerator.backward(loss_a + loss_b)

I guess, in general, no_sync and trigger_sync_in_backward are analogous to torch.no_grad() and torch.enable_grad(). But can you maybe give a more practical use case for trigger_sync_in_backward?

And regarding this PR, I don't think it conflicts with existing features. It just gives more flexibility to accumulate()

@yuxinyuan
Contributor Author

I also tried the script below. It seems to me that trigger_sync_in_backward can be used within no_sync. In a sense, they really are to gradient synchronization what torch.no_grad() and torch.enable_grad() are to gradient computation.

import accelerate
import torch


def print_rank_by_rank(*args, **kwargs):
    for rank in range(accelerator.num_processes):
        if rank == accelerator.process_index:
            print(*args, **kwargs)
        accelerator.wait_for_everyone()


if __name__ == "__main__":
    accelerator = accelerate.Accelerator(gradient_accumulation_steps=2)
    accelerator.print(accelerate.__version__)
    accelerator.print(torch.__version__)

    accelerate.utils.set_seed(1234)
    model = torch.nn.Linear(3, 3, bias=False)
    model = accelerator.prepare(model)

    generator = torch.Generator().manual_seed(1234 + accelerator.process_index)
    input_a = torch.randn(2, 3, generator=generator).to(accelerator.device)
    input_b = torch.randn(2, 3, generator=generator).to(accelerator.device)
    target_a = torch.randn(2, 3, generator=generator).to(accelerator.device)
    target_b = torch.randn(2, 3, generator=generator).to(accelerator.device)

    loss_a = torch.nn.functional.mse_loss(model(input_a), target_a)
    loss_b = torch.nn.functional.mse_loss(model(input_b), target_b)
    accelerator.backward(loss_a + loss_b)  # baseline: DDP syncs gradients across processes here

    for param in model.parameters():
        print_rank_by_rank(param.grad)

    model.zero_grad()

    with accelerator.no_sync(model):
        loss_a = torch.nn.functional.mse_loss(model(input_a), target_a)
        loss_b = torch.nn.functional.mse_loss(model(input_b), target_b)

        accelerator.backward(loss_a)  # inside no_sync: grads accumulate locally, no cross-process sync
        for param in model.parameters():
            print_rank_by_rank(param.grad)

        with accelerator.trigger_sync_in_backward(model):
            accelerator.backward(loss_b)  # forces cross-process grad sync on this backward, even under no_sync
        for param in model.parameters():
            print_rank_by_rank(param.grad)

@muellerzr
Collaborator

muellerzr commented Jul 14, 2023

It can be, but it really shouldn't be, I don't think. The difference is that no_sync ensures the GPUs never synchronize in a multi-GPU setup, which saves time. This is especially needed in cases such as gradient accumulation, where syncing on every step wastes time, especially on multi-node setups. See my article discussing it: https://muellerzr.github.io/blog/gradient_accumulation.html

So it's more than no_grad, because gradients are still calculated.
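
For intuition, a simplified sketch (not Accelerate's actual implementation) of what no_sync amounts to for a DDP-wrapped model:

import contextlib
from torch.nn.parallel import DistributedDataParallel as DDP

def no_sync_sketch(model):
    # Gradients are still computed and accumulated into each param.grad locally;
    # only the all-reduce that DDP normally runs during backward() is skipped.
    if isinstance(model, DDP):
        return model.no_sync()
    return contextlib.nullcontext()  # nothing to skip for a plain module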

@yuxinyuan
Contributor Author


@muellerzr I might not have been clear enough here. What I mean is that accelerate.no_sync and accelerate.trigger_sync_in_backward are analogous to torch.no_grad() and torch.enable_grad() in the sense that accelerate.no_sync prevents DDP from syncing grads, while accelerate.trigger_sync_in_backward forces DDP to sync grads.

I understand the importance of disabling synchronization during gradient accumulation. I just cannot picture a scenario where we actually need accelerate.trigger_sync_in_backward.

Anyway, back to this PR. What are the current shortcomings or limitations in your opinion?

@TalhaUusuf

I am also facing this issue while training multiple ControlNets. Accelerate is not able to accumulate multiple models. Is there any workaround until this problem is solved?

Collaborator

@muellerzr muellerzr left a comment

Thanks @yuxinyuan, thinking on it more, let's go with this as well. A few nits in terms of variable names, otherwise this looks great to me 🤗

src/accelerate/accelerator.py: 2 review comments (outdated, resolved)
yuxinyuan and others added 3 commits on July 25, 2023 (co-authored by Zach Mueller <muellerzr@gmail.com>)
@muellerzr muellerzr requested a review from sgugger July 25, 2023 16:15
@muellerzr
Collaborator

Awesome! Sending to @sgugger for a final review :)

@sgugger sgugger merged commit 6e70e79 into huggingface:main Jul 25, 2023
@sgugger
Collaborator

sgugger commented Jul 25, 2023

Thanks for your contribution!

@eliphatfs

> The issue here is that the gradient synchronization will be off. We were waiting on feedback from the torch team on whether our solution relies on a "bug" or a feature, and since they haven't responded we'll treat it as a feature and I'll open a PR today. So sadly no, it's not as simple as you would hope.

Sorry, but I don't quite get the outcome: does accelerate now support accumulation over multiple models on multiple GPUs with correct gradient synchronization?

@amanikiruga

@muellerzr @yuxinyuan so, as asked above, "does accelerate support accumulation on multiple models on multiple GPUs with correct gradient synchronization now?" And more specifically, if I am training a diffusion model and a text encoder, can I just do accelerator.accumulate(unet, text_encoder)?

@bghira

bghira commented Oct 8, 2023

As I'm preparing to fix this issue in the diffusers example training scripts, I would also like to know whether the results are considered "correct".

@muellerzr
Collaborator

Yes, that is all you need
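
E.g. a rough sketch of the loop (compute_loss, optimizer, and dataloader are placeholders, and everything is assumed to have been passed through accelerator.prepare()):

for batch in dataloader:
    with accelerator.accumulate(unet, text_encoder):  # grad sync skipped for both until the sync step
        loss = compute_loss(unet, text_encoder, batch)  # placeholder loss computation
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()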
