DistributedDataParallel error - uninitialized parameters #644

Closed
OhioT opened this issue Aug 5, 2024 · 13 comments
Labels
upstream-bug (We can't do anything but wait.), wontfix (This will not be worked on)

Comments

@OhioT

OhioT commented Aug 5, 2024

I'm using Flux quickstart settings with fp8 quantization on 4x3090s. The same settings work on 1x3090.

TRAINING_NUM_PROCESSES=2
export ACCELERATE_EXTRA_ARGS="--multi_gpu"

on line: results = accelerator.prepare(primary_model
RuntimeError: Modules with uninitialized parameters can't be used with DistributedDataParallel. Run a dummy forward pass to correctly initialize the modules

I have tried the DDP argument find_unused_parameters=True and printing modules with requires_grad = True and grad = None, but there aren't any.
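
For anyone hitting the same error: as far as I can tell, DDP raises it when a parameter of the wrapped module is an instance of torch.nn.parameter.UninitializedParameter (the placeholder type used by lazy modules such as nn.LazyLinear). That is a different condition from requires_grad or grad being None, so the checks above would not necessarily surface it. A minimal diagnostic sketch, assuming that isinstance test is what triggers the error; find_uninitialized_parameters is a hypothetical helper name, not part of torch or accelerate, and the toy model only stands in for the real transformer:

```python
import torch
from torch import nn
from torch.nn.parameter import UninitializedParameter


def find_uninitialized_parameters(module: nn.Module) -> list[str]:
    """Return the names of parameters that are still lazy placeholders."""
    return [
        name
        for name, param in module.named_parameters()
        if isinstance(param, UninitializedParameter)
    ]


# Toy stand-in model: the LazyLinear layer has no concrete shape
# until it sees its first input.
toy = nn.Sequential(nn.Linear(3, 4), nn.LazyLinear(2))
before = find_uninitialized_parameters(toy)

with torch.no_grad():
    toy(torch.randn(1, 3))  # dummy forward pass materializes the lazy weights

after = find_uninitialized_parameters(toy)
print("before:", before)
print("after:", after)
```

Calling find_uninitialized_parameters(transformer) right before accelerator.prepare() should show which layers, if any, are tripping the check.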

@bghira
Owner

bghira commented Aug 5, 2024

oh... well.. actually, i haven't tried multigpu quantised training yet. i assumed it would just work, since we're not really messing with a whole lot other than dtypes. @sayakpaul cc

bghira added the bug, help wanted, good first issue, and regression labels Aug 5, 2024
@bghira
Owner

bghira commented Aug 5, 2024

i am guessing you can't test without quantisation to see?

@sayakpaul
Contributor

> Run a dummy forward pass to correctly initialize the modules

Did this help or isn't it possible at all?

@sayakpaul
Contributor

There is a multiGPU training example with FP8 but it uses ao:
https://github.com/pytorch/ao/blob/main/benchmarks/float8/bench_multi_gpu.py

@bghira
Owner

bghira commented Aug 5, 2024

everything torch does has such a worse interface than everything hugging face does - ao looks like it will work but jesus lord why is it so ugly lol

@OhioT
Author

OhioT commented Aug 5, 2024

> > Run a dummy forward pass to correctly initialize the modules
>
> Did this help or isn't it possible at all?

I tried the following and the same error happened at prepare()

tpacked_noisy_latents = torch.randn(1, 4320, 64,dtype=weight_dtype, device=accelerator.device)
tpooled_projections = torch.randn(1, 768,dtype=weight_dtype, device=accelerator.device)
ttimesteps = torch.randn(1,dtype=weight_dtype, device=accelerator.device)
tguidance = torch.randn(1,dtype=weight_dtype, device=accelerator.device)
tencoder_hidden_states = torch.randn(1, 512, 4096,dtype=weight_dtype, device=accelerator.device)
ttxt_ids = torch.randn(1, 512, 3,dtype=weight_dtype, device=accelerator.device)
timg_ids = torch.randn(1, 4320, 3,dtype=weight_dtype, device=accelerator.device)

with torch.no_grad():
    model_pred = transformer(
        hidden_states=tpacked_noisy_latents,
        timestep=ttimesteps,
        guidance=tguidance,
        pooled_projections=tpooled_projections,
        encoder_hidden_states=tencoder_hidden_states,
        txt_ids=ttxt_ids,
        img_ids=timg_ids,
        joint_attention_kwargs=None,
        return_dict=False,
    )
transformer = accelerator.prepare(transformer)

@sayakpaul
Contributor

Okay. This is helpful. Would you be able to turn the above into a fuller reproducer and provide your accelerate config and launch command?

Will try to look into it tomorrow.

@bghira
Owner

bghira commented Aug 18, 2024

@sayakpaul any luck?

@matabear-wyx

Same error here. Could you please share some ideas about what might be going wrong with multi-GPU quantised training? Maybe I can try to work on it.

@bghira
Owner

bghira commented Aug 21, 2024

this doesn't happen with LORA_TYPE=lycoris and fp8-quanto on 2x 3090

bghira added the upstream-bug label Aug 21, 2024
@bghira
Owner

bghira commented Aug 21, 2024

@sayakpaul i got u fam

accelerate launch --multi_gpu test.py

import torch, accelerate
from diffusers import FluxTransformer2DModel
from optimum.quanto import quantize, qint8, freeze
weight_dtype = torch.bfloat16

accelerator = accelerate.Accelerator()

bfl_model = 'black-forest-labs/FLUX.1-dev'
transformer = FluxTransformer2DModel.from_pretrained(bfl_model, torch_dtype=torch.bfloat16, subfolder="transformer")

# you might need 'with accelerator.main_process_first()' if your server lacks system mem
print('quantizing')
quantize(transformer, qint8)
print('freezing')
freeze(transformer)

tpacked_noisy_latents = torch.randn(1, 1024, 64,dtype=weight_dtype, device=accelerator.device)
tpooled_projections = torch.randn(1, 768,dtype=weight_dtype, device=accelerator.device)
ttimesteps = torch.randn(1,dtype=weight_dtype, device=accelerator.device)
tguidance = torch.randn(1,dtype=weight_dtype, device=accelerator.device)
tencoder_hidden_states = torch.randn(1, 512, 4096,dtype=weight_dtype, device=accelerator.device)
ttxt_ids = torch.randn(1, 512, 3,dtype=weight_dtype, device=accelerator.device)
timg_ids = torch.randn(1, 1024, 3,dtype=weight_dtype, device=accelerator.device)  # sequence length matches tpacked_noisy_latents

#with torch.no_grad():
#    model_pred = transformer(
#        hidden_states=tpacked_noisy_latents,
#        timestep=ttimesteps,
#        guidance=tguidance,
#        pooled_projections=tpooled_projections,
#        encoder_hidden_states=tencoder_hidden_states,
#        txt_ids=ttxt_ids,
#        img_ids=timg_ids,
#        joint_attention_kwargs=None,
#        return_dict=False,
#    )
transformer = accelerator.prepare(transformer)

@bghira
Owner

bghira commented Aug 21, 2024

same issue here,

transformer = FluxTransformer2DModel.from_pretrained(bfl_model, torch_dtype=torch.bfloat16, subfolder="transformer")
if accelerator.is_main_process:
    print('quantizing')
    quantize(transformer, qint8)
    print('freezing')
    freeze(transformer)
print('waiting..')
accelerator.wait_for_everyone()

bghira pushed a commit that referenced this issue Aug 21, 2024
…FT, inform user to go with lycoris instead
bghira added a commit that referenced this issue Aug 21, 2024
(#644) temporarily block training on multi-gpu setup with quanto + PEFT, inform user to go with lycoris instead
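
The commit message above describes a guard rather than a fix. A rough sketch of what such a guard could look like is below. It is purely illustrative: the function name, the argument names, the "quanto" substring test, and the "standard" value used for a non-Lycoris LoRA type are all assumptions, not the code from the referenced commit. Only "fp8-quanto" and "lycoris" come from this thread.

```python
def check_multi_gpu_quantised_lora(
    num_processes: int,
    base_model_precision: str,
    lora_type: str,
) -> None:
    """Fail early on the combination reported as broken in this issue.

    All names here are hypothetical; this is not the project's real API.
    """
    is_multi_gpu = num_processes > 1
    uses_quanto = "quanto" in base_model_precision.lower()
    uses_lycoris = lora_type.lower() == "lycoris"

    if is_multi_gpu and uses_quanto and not uses_lycoris:
        raise ValueError(
            "Multi-GPU training with quanto and a PEFT LoRA currently fails "
            "inside DistributedDataParallel. Use LORA_TYPE=lycoris instead."
        )


# The combination reported as working in this thread passes silently.
check_multi_gpu_quantised_lora(2, "fp8-quanto", "lycoris")

# The combination from the original report is rejected before training starts.
try:
    check_multi_gpu_quantised_lora(2, "fp8-quanto", "standard")
    blocked = False
except ValueError:
    blocked = True
print("blocked:", blocked)
```

Checking this at startup turns a late failure inside accelerator.prepare() into an immediate, actionable message.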
@bghira
Owner

bghira commented Aug 27, 2024

for now, DDP works with Lycoris. i will close this and eventually we will receive an upstream fix when there is time for them to focus on it again.

bghira closed this as not planned Aug 27, 2024
bghira added the wontfix label and removed the bug, help wanted, good first issue, and regression labels Aug 27, 2024