
[BUG] train_text_to_image_lora.py not support Multi-nodes or Multi-gpus training. #4046

Closed
WindVChen opened this issue Jul 11, 2023 · 31 comments
Labels
stale Issues that haven't received updates

Comments

@WindVChen

In train_text_to_image_lora.py, I notice that the LoRA parameters are extracted into an AttnProcsLayers class:

518    lora_layers = AttnProcsLayers(unet.attn_processors)

And only lora_layers is wrapped by DistributedDataParallel in the following code:

670    lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            lora_layers, optimizer, train_dataloader, lr_scheduler
          )

In the training loop, however, lora_layers is never called explicitly; only the unet is used:

776    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

My question: when using multiple GPUs or multiple machines, will the gradients actually be averaged across all processes this way?

It is true that in each process the gradients are backpropagated into unet.attn_processors, and those gradients are shared with lora_layers, so the optimizer can update the weights. However, since the forward pass actually goes through unet.attn_processors rather than the wrapped lora_layers, can the gradients be correctly averaged? From here, it seems that a wrapped module has a different forward than the original module's forward.

I am not very familiar with the torch.nn.parallel.DistributedDataParallel wrapper, and I worry that the current code in train_text_to_image_lora.py will lead to different LoRA weights in different processes (if the gradients fail to be synchronized across processes).

Hope to find some help here, thank you.
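For reference, one way to check this is to compare a LoRA parameter across processes after a few optimizer steps. A minimal sketch (assuming the accelerator and lora_layers objects from the script, launched with more than one process):

import torch

# Hypothetical check: gather one LoRA parameter from every process and compare.
# If DDP were averaging gradients, all copies would stay identical.
param = next(accelerator.unwrap_model(lora_layers).parameters()).detach()
gathered = accelerator.gather(param.flatten()[None, :])  # shape: (num_processes, numel)
if accelerator.is_main_process:
    print("LoRA weights identical across ranks:",
          torch.allclose(gathered[0], gathered[-1]))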

@WindVChen WindVChen changed the title Question about DistributedDataParallel in train_text_to_image_lora.py. [BUG!] train_text_to_image_lora.py not support Multi-nodes or Multi-gpus training. Jul 12, 2023
@WindVChen WindVChen changed the title [BUG!] train_text_to_image_lora.py not support Multi-nodes or Multi-gpus training. [BUG] train_text_to_image_lora.py not support Multi-nodes or Multi-gpus training. Jul 12, 2023
@WindVChen
Author

After carefully printing out the gradients and weights in the different processes, it seems clear that the current LoRA training script does not work for multi-node or multi-GPU training: the gradients are not averaged across processes, which then leads to different LoRA weights in different processes.

Below is a printout from a 2-GPU run:

[screenshot: per-process printout of process/epoch/step, LoRA gradient, and LoRA weight]

The first three numbers denote process/epoch/step, and the two tensors that follow are the gradient and the weight of lora_layers.layers[31].to_out_lora.up (a linear layer). The gradients are obviously different across processes. Thus, after three steps, it turns into:

[screenshot: the LoRA weights now also differ between the two processes]

Obviously, the LoRA weights are different in different processes.

I think this is a bug, since the script does not document this limitation. People who use the script for multi-process training will effectively end up with single-process results.
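(The kind of printout above can be produced with a small debugging snippet along these lines; a hypothetical reconstruction, with epoch and step taken from the training loop.)

# Hypothetical debugging snippet, placed right after the backward call:
# print the gradient and weight of one LoRA linear layer on every process.
layer = accelerator.unwrap_model(lora_layers).layers[31].to_out_lora.up
print(accelerator.process_index, epoch, step,
      layer.weight.grad.flatten()[:5],   # differs across processes -> no averaging
      layer.weight.data.flatten()[:5])   # weights drift apart after a few steps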

@patrickvonplaten
Contributor

cc @williamberman @sayakpaul

@sayakpaul
Member

@WindVChen

Thanks for elaborating on this!

From a brief look, it seems like passing unet to the prepare step rather than just the LoRA layers might fix this.

Could you confirm this once?

@WindVChen
Author

Hi @sayakpaul ,

Yes. Passing unet to the prepare step quickly fixes it. However, it also brings some inconvenience: 1) the batch size has to be halved, or training runs out of memory; 2) the intermediate checkpoints are much larger (~3 GB compared to the original ~3 MB), since the whole unet structure is stored.

I wonder if there is a more elegant way to solve it without sacrificing memory and storage. I also tried wrapping each "LoRAAttnProcessor" manually, but found it required a lot of source code modification, so I gave up 😢.
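On the checkpoint-size part, it should still be possible to save only the LoRA weights even when the whole unet is passed to prepare, by unwrapping it before saving. A rough sketch (assuming the save_attn_procs method from the diffusers loaders mixin on the UNet, and the args/accelerator objects from the script):

# Save only the attention-processor (LoRA) weights instead of the full UNet.
if accelerator.is_main_process:
    unwrapped_unet = accelerator.unwrap_model(unet)
    unwrapped_unet.save_attn_procs(args.output_dir)  # small LoRA-only checkpoint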

@sayakpaul
Member

Thanks for sharing!

I also tried wrapping each "LoRAAttnProcessor" manually, but found it required a lot of source code modification, so I gave up

Could you expand a bit more on this?

@WindVChen
Author

Yes.

It seems the problem above happens because the training script uses the unwrapped unet for the forward pass, not the wrapped lora_layers. So my earlier idea was to wrap every LoRAAttnProcessor in the unet one by one (directly via accelerator.prepare) and then replace the original unwrapped processors in the unet with the wrapped ones. That way, perhaps the gradients would be synchronized across processes during training.

However, since I'm not quite familiar with torch.nn.parallel.DistributedDataParallel, I'm not sure whether this solution can work.

@sayakpaul
Member

Hmm. I will defer to what @williamberman has to point out here. Also cc'ing @muellerzr from the accelerate team.

@muellerzr
Contributor

When using Accelerate, any model that gets gradient/weight updates should be passed to .prepare. Do note that gradient accumulation currently won't work with multiple models; we're adding that very soon, but I hope that helps.

@WindVChen
Author

Hi @muellerzr ,

Could you expand a bit more on "gradient accumulation currently won't work with multiple models"? Are there any blogs or issues related to this? I am planning to use gradient accumulation with multiple models. 😅

@muellerzr
Contributor

No one has actually raised this "issue" yet. Once the PR is opened (probably today or tomorrow) I'll link to it, but basically multiple forward passes followed by multiple backward passes (one per model's loss, for instance) lead to some headaches in torch distributed.

@WindVChen
Author

OK, thanks.

Just to be sure: does that mean a GAN training script like the one below will cause bugs?

"Suppose gradient_accumulation is set to 2"
optimizer_gen = optim(generator.parameters())
optimizer_disc = optim(discriminator.parameters())
with accelerator.accumulate(generator):
        outputs = optimizer_gen(input)
        loss = loss_func(outputs)
        loss.backward()
        optimizer_gen.step()
        optimizer_gen.zero_grad()

        outputs = optimizer_disc(input)
        loss = loss_func(outputs)
        loss.backward()
        optimizer_disc.step()
        optimizer_disc.zero_grad()

@muellerzr
Contributor

Yes, because technically the discriminator should also be under accumulate; as written, the discriminator might still be updated every step. (Though again, we're working on support for that.)

@WindVChen
Author

Ah, I see. Maybe I can double-wrap the pipeline as a temporary fix before the PR lands? Like this:

with accelerator.accumulate(generator):
    with accelerator.accumulate(discriminator):
        outputs = generator(input)
        ...

@muellerzr
Contributor

Double wrapping also does not work, hence the need for a more complex solution. (see discussion here: huggingface/accelerate#1708). So just wait a day or two :)

@WindVChen
Author

OK, thanks a lot. Look forward to the solution. 😊

@eliphatfs
Contributor

I have a solution that is not elegant but works: wrapping everything in another network.

class SuperNet(torch.nn.ModuleDict):
    def forward(self, text_encoder, unet, batch, class_labels, noisy_model_input, timesteps):
        # Get the text embedding for conditioning
        encoder_hidden_states = encode_prompt(
            text_encoder,
            batch["input_ids"],
            None
        )
        # Predict the noise residual
        return unet(
            noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
        ).sample

Pass only this module to accelerate, and in the training loop call this module instead of the original code.

This originated as a solution for training the text encoder and unet simultaneously for DreamBooth. You can look here: huggingface/accelerate#668 (comment)
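The wiring would look roughly like the sketch below. The key point is that the trainable modules are registered inside the wrapper, only the wrapper goes through prepare, and only the wrapper is called in the loop (names are illustrative):

# Hypothetical usage of the SuperNet above.
super_net = SuperNet({"text_encoder": text_encoder, "unet": unet})
super_net, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    super_net, optimizer, train_dataloader, lr_scheduler
)

# In the training loop, call the DDP-wrapped super_net, not unet directly:
model_pred = super_net(text_encoder, unet, batch, class_labels,
                       noisy_model_input, timesteps)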

@WindVChen
Author

Hi @eliphatfs ,

Thanks for sharing. Just to confirm: the code given solves the gradient accumulation problem for multiple models under DDP, not the problem of running the LoRA code on multiple GPUs/nodes, right? (There are two problems mentioned in this issue 😂)

@eliphatfs
Contributor

eliphatfs commented Jul 13, 2023


You can do similar things for LoRA. This is from one of my custom pipelines, but it can be quickly adapted to any. Note that the base class is AttnProcsLayers, so you can use it in place of the original AttnProcsLayers and it still supports the inference loaders.

class SuperNet(AttnProcsLayers):
    def forward(self, image_encoder, unet, image, noisy_model_input, timesteps):
        encoded = image_encoder(image)
        return unet(
            noisy_model_input,
            timesteps,
            encoder_hidden_states=encoded.last_hidden_state,
            class_labels=noise_image_embeddings(encoded.image_embeds, 0)
        ).sample

I think the key is to make sure that:

  1. No mutable (trainable) parameters are accessed outside the DDP wrapper's forward, and all of them are parameters of the wrapped module.
  2. Only one DDP wrapper exists per process group (or accelerator instance).
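For the LoRA script in this issue, a minimal sketch of how such a wrapper could be wired in (illustrative; the forward signature here follows my custom pipeline above):

# Hypothetical wiring: the wrapper is built from unet.attn_processors, so the
# LoRA parameters are registered inside the single DDP-wrapped module.
super_net = SuperNet(unet.attn_processors)
super_net, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    super_net, optimizer, train_dataloader, lr_scheduler
)

# Training loop: call super_net so DDP's gradient synchronization actually runs.
model_pred = super_net(image_encoder, unet, image, noisy_model_input, timesteps)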

@WindVChen
Author

Thanks for the further description.

But as I understand it, this may still not solve the problem listed here? Any thoughts on that problem, or am I missing something?

@eliphatfs
Contributor

eliphatfs commented Jul 14, 2023

You are not passing the complete unet into accelerate; it is passed as an argument at forward time, so its parameters will not be stored in the checkpoint. You do have to make sure all optimized parameters are registered in the SuperNet, of course.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Aug 11, 2023
@eliphatfs
Contributor

Can anyone verify whether this has been fixed in accelerate?

@thuanz123
Contributor

thuanz123 commented Aug 16, 2023

Hi, so the quick fix for this is to include the unet as well as the lora_layers in the prepare step, right?

Edit: I mean the multi-node/multi-GPU training support for LoRA, not the gradient accumulation for multiple models.

@sayakpaul
Member

Cc: @muellerzr

@hkunzhe

hkunzhe commented Aug 21, 2023

Any updates?

@muellerzr
Contributor

This should be fixed by passing multiple models with accelerator.accumulate, yes @hkunzhe
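With that support in place, the earlier GAN example would look roughly like the sketch below (a sketch on a recent accelerate version; generator, discriminator, and the optimizers are the illustrative names from above):

# Both models are passed to accumulate, so neither steps early during accumulation.
with accelerator.accumulate(generator, discriminator):
    gen_loss = loss_func(generator(input))
    accelerator.backward(gen_loss)
    optimizer_gen.step()
    optimizer_gen.zero_grad()

    disc_loss = loss_func(discriminator(input))
    accelerator.backward(disc_loss)
    optimizer_disc.step()
    optimizer_disc.zero_grad()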

@hkunzhe

hkunzhe commented Aug 21, 2023

This should be fixed by passing multiple models with accelerator.accumulate, yes @hkunzhe

I got it!

@williamberman
Contributor

Hi, yes @WindVChen, this is correct. The issue is that accelerate's prepare works by wrapping the passed class in DDP, and then we're supposed to call the returned DDP class. Similarly, accelerate's mixed precision works by monkey-patching the forward method of the passed-in class.

Any script where we use the AttnProcsLayers class will not work properly with accelerate, because that class just holds the given parameters; it isn't actually used as part of the model.

I fixed this for the dreambooth lora script here: #3778

We should really remove the AttnProcsLayers class and always pass the top-level model to accelerator.prepare. I'm going to open an issue to document this better, but unfortunately I can't get to it right away, as these cross-training-script refactors are relatively involved.
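The direction of that fix, roughly sketched (not the exact diff from #3778): prepare the unet itself and give the optimizer only the trainable LoRA parameters, with the other names taken from the training script.

import torch

# Sketch: the model that actually runs the forward pass is the one DDP wraps.
lora_parameters = [p for p in unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(lora_parameters, lr=args.learning_rate)

unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    unet, optimizer, train_dataloader, lr_scheduler
)

# The training loop now goes through the DDP-wrapped unet, so gradients are synced.
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample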

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ai1361720220000

This should be fixed by passing multiple models with accelerator.accumulate, yes @hkunzhe

Hi @muellerzr ,
What should be done when it comes to accelerator.prepare()? Should all the models be passed into prepare at once, or is passing only one of them at a time okay? That is, accelerator.prepare(model1, model2, optimizer, train_dataloader, lr_scheduler) versus accelerator.prepare(model1, optimizer, train_dataloader, lr_scheduler) and accelerator.prepare(model2, optimizer, train_dataloader, lr_scheduler)?

@muellerzr
Contributor

All of them should be put into prepare, specifically all the ones that expect to have their gradients updated. Those same ones should then also be passed to accumulate.

You can send both into prepare at the same time.
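In code, roughly (illustrative names):

# One prepare call with both models, then both under accumulate.
model1, model2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model1, model2, optimizer, train_dataloader, lr_scheduler
)

with accelerator.accumulate(model1, model2):
    ...  # forward/backward/step for each model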
