DistributedDataParallel error - uninitialized parameters #644
Oh... well, actually, I haven't tried multi-GPU quantised training yet. I assumed it would just work, since we're not really changing much other than dtypes. cc @sayakpaul
I'm guessing you can't test without quantisation to see?
Did this help, or isn't it possible at all?
There is a multi-GPU training example with FP8, but it uses
Everything torch does has such a worse interface than everything Hugging Face does; ao looks like it will work, but jesus lord, why is it so ugly lol
I tried the following and the same error happened at prepare()
Okay. This is helpful. Would you be able to turn the above into a fuller reproducer and provide your accelerate config and launch command? I will try to look into it tomorrow.
@sayakpaul any luck?
Same error here. Could you please share some ideas about multi-GPU quantised training? Maybe I can try to work on it.
This doesn't happen with LORA_TYPE=lycoris and fp8-quanto on 2x 3090s.
@sayakpaul I got you, here's a reproducer:
```python
import torch, accelerate
from diffusers import FluxTransformer2DModel
from optimum.quanto import quantize, qint8, freeze

weight_dtype = torch.bfloat16
accelerator = accelerate.Accelerator()
bfl_model = 'black-forest-labs/FLUX.1-dev'
transformer = FluxTransformer2DModel.from_pretrained(bfl_model, torch_dtype=torch.bfloat16, subfolder="transformer")

# you might need 'with accelerator.main_process_first()' if your server lacks system mem
print('quantizing')
quantize(transformer, qint8)  # swap supported layers for quanto modules with int8 weights
print('freezing')
freeze(transformer)  # quantize the weights in place

# dummy inputs for an optional forward pass (batch 1, 1024 image tokens, 512 text tokens)
tpacked_noisy_latents = torch.randn(1, 1024, 64, dtype=weight_dtype, device=accelerator.device)
tpooled_projections = torch.randn(1, 768, dtype=weight_dtype, device=accelerator.device)
ttimesteps = torch.randn(1, dtype=weight_dtype, device=accelerator.device)
tguidance = torch.randn(1, dtype=weight_dtype, device=accelerator.device)
tencoder_hidden_states = torch.randn(1, 512, 4096, dtype=weight_dtype, device=accelerator.device)
ttxt_ids = torch.randn(1, 512, 3, dtype=weight_dtype, device=accelerator.device)
# img_ids needs one row per image token, so its length must match hidden_states (1024)
timg_ids = torch.randn(1, 1024, 3, dtype=weight_dtype, device=accelerator.device)

# transformer.to(accelerator.device)  # the model must be on the same device as the inputs
# with torch.no_grad():
#     model_pred = transformer(
#         hidden_states=tpacked_noisy_latents,
#         timestep=ttimesteps,
#         guidance=tguidance,
#         pooled_projections=tpooled_projections,
#         encoder_hidden_states=tencoder_hidden_states,
#         txt_ids=ttxt_ids,
#         img_ids=timg_ids,
#         joint_attention_kwargs=None,
#         return_dict=False,
#     )

# on a multi-GPU launch, this raises the DDP "uninitialized parameters" error
transformer = accelerator.prepare(transformer)
```
Same issue here with:
```python
transformer = FluxTransformer2DModel.from_pretrained(bfl_model, torch_dtype=torch.bfloat16, subfolder="transformer")
if accelerator.is_main_process:
    # only the main process quantizes; the other ranks keep the bf16 weights
    print('quantizing')
    quantize(transformer, qint8)
    print('freezing')
    freeze(transformer)
print('waiting..')
accelerator.wait_for_everyone()
```
(#644) Temporarily block training on multi-GPU setups with quanto + PEFT; inform the user to use LyCORIS instead
For now, DDP works with LyCORIS. I will close this; eventually we will get an upstream fix when the upstream maintainers have time to focus on it again.
I'm using the Flux quickstart settings with fp8 quantization on 4x 3090s. The same settings work on a single 3090. Relevant settings:
```
TRAINING_NUM_PROCESSES=2
export ACCELERATE_EXTRA_ARGS="--multi_gpu"
```
The error is raised on the line:
```
results = accelerator.prepare(primary_model
```
```
RuntimeError: Modules with uninitialized parameters can't be used with `DistributedDataParallel`. Run a dummy forward pass to correctly initialize the modules
```
I have tried the DDP argument find_unused_parameters=True, and I printed the modules with requires_grad = True and grad = None, but there aren't any.
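Since checking requires_grad and grad turned up nothing, one more thing worth inspecting is which parameters actually pass an isinstance test against torch.nn.parameter.UninitializedParameter, the parameter type this DDP error refers to. The sketch below assumes DDP's check is a per-parameter isinstance test of that kind; the helper name find_uninitialized_params is hypothetical and is not part of accelerate, quanto, or the trainer. It also reports each flagged parameter's concrete type. It is demonstrated on nn.LazyLinear, whose parameters stay uninitialized until a first forward pass.

```python
import torch
from torch import nn
from torch.nn.parameter import UninitializedParameter


def find_uninitialized_params(module: nn.Module) -> list[tuple[str, str]]:
    """Return (name, concrete type name) for every parameter that passes
    an isinstance check against UninitializedParameter.

    Hypothetical helper: assumes DDP rejects exactly these parameters.
    """
    flagged = []
    for name, param in module.named_parameters():
        if isinstance(param, UninitializedParameter):
            flagged.append((name, type(param).__name__))
    return flagged


if __name__ == "__main__":
    lazy = nn.LazyLinear(4)
    # both weight and bias are still UninitializedParameter here
    print(find_uninitialized_params(lazy))
    # a dummy forward pass materializes the parameters
    lazy(torch.randn(2, 3))
    print(find_uninitialized_params(lazy))  # now empty
```

In the reproducer, calling this on transformer after freeze(transformer) and before accelerator.prepare(transformer) would show whether the flagged entries are quanto tensor types rather than genuinely lazy parameters. That would point at the isinstance check rather than at a missing forward pass.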