
Deep speed #1139

Merged
kohya-ss merged 22 commits into dev from deep-speed on Mar 26, 2024
Conversation

kohya-ss
Owner

Original PR #1101

I think it is not necessary to set unet or text_encoder back from the result of prepare_deepspeed_model, because the models are not lists, so they are not changed inside the function.
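
A minimal illustration of this point (hypothetical function body, not the actual sd-scripts code): Python rebinds function parameters locally, so unless an argument is a mutable container that the function modifies in place, the caller's unet and text_encoder keep pointing at the same objects and only the returned wrapper matters.

# Hypothetical sketch, assuming a simplified signature for prepare_deepspeed_model.
import torch.nn as nn

def prepare_deepspeed_model(unet: nn.Module, text_encoder: nn.Module) -> nn.Module:
    # Bundle the modules into one container for a single DeepSpeed engine;
    # the caller's unet/text_encoder variables are left untouched.
    return nn.ModuleList([unet, text_encoder])

# Caller keeps only the returned wrapper; setting unet or text_encoder back
# from the result would be a no-op:
# ds_model = prepare_deepspeed_model(unet, text_encoder)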

@kohya-ss kohya-ss changed the base branch from main to dev February 27, 2024 12:34
@kohya-ss kohya-ss mentioned this pull request Feb 27, 2024
@BootsofLagrangian
Contributor

I tested the new branch with several settings. It seems that even if SD variants (Cascade, SD3, etc.) come out later, they will work well with the wrapping approach.

@storuky

storuky commented Mar 6, 2024

Hey @BootsofLagrangian
Have you encountered the "zero stage 2 requires an optimizer" error? If so, how did you fix that?

@BootsofLagrangian
Contributor

Hey @BootsofLagrangian Have you encountered the "zero stage 2 requires an optimizer" error? If so, how did you fix that?

Can you attach your bash script or toml config file?

@storuky

storuky commented Mar 8, 2024

@BootsofLagrangian
Here is my accelerate config:

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Here is how I run finetuning:

accelerate launch --gpu_ids="0,1"  --multi_gpu --num_processes=2  --num_cpu_threads_per_process=2 "./sdxl_train.py" \
  --ddp_timeout='1000' \
  --bucket_no_upscale \
  --bucket_reso_steps=64 \
  --cache_latents \
  --cache_latents_to_disk \
  --caption_extension=".txt" \
  --dataset_repeats="20" \
  --enable_bucket \
  --min_bucket_reso=64   \
  --max_bucket_reso=1024 \
  --in_json="/home/storuky/ml/train/meta_cap.json" \
  --gradient_checkpointing \
  --learning_rate="1.2e-06" \
  --learning_rate_te1="5e-07" \
  --learning_rate_te2="5e-07" \
  --logging_dir="/home/storuky/ml/train/log" \
  --lr_scheduler="constant" \
  --lr_scheduler_args \
  --lr_scheduler_type "CosineAnnealingLR" \
  --lr_scheduler_args "T_max=10" \
  --max_data_loader_n_workers="0" \
  --resolution="1024,1024" \
  --max_timestep=900 \
  --max_token_length=225 \
  --max_train_epochs=10 \
  --max_train_steps="979575" \
  --min_snr_gamma=5  \
  --min_timestep=100 \
  --mixed_precision="bf16" \
  --no_half_vae \
  --noise_offset=0.0375 \
  --adaptive_noise_scale=0.00375 \
  --optimizer_args  scale_parameter=False relative_step=False warmup_init=False weight_decay=0.01 \
  --optimizer_type="Adafactor" \
  --output_dir="/home/storuky/ml/out" \
  --output_name="TrainingModel"    \
  --pretrained_model_name_or_path="/home/storuky/ml/sd_xl_base_1.0.safetensors" \
  --save_every_n_epochs="1"  \
  --save_model_as=safetensors \
  --save_precision="bf16" \
  --save_state \
  --seed="1234" \
  --train_batch_size="1" \
  --train_data_dir="/home/storuky/ml/train/dataset" \
  --train_text_encoder \
  --v_pred_like_loss="0.5" \
  --xformers  \
  --deepspeed \
  --zero_stage 2 \
  --offload_optimizer_device cpu

@BootsofLagrangian
Contributor

@storuky

When you want to use CPU offloading with offload_optimizer_device=cpu, DeepSpeed will build and use CPUAdam, which is also a kind of Adam.

Can you change optimizer_type="Adafactor" to optimizer_type="AdamW" or another AdamW variant such as AdamW8bit?

When I used Adafactor, I got a different error; with AdamW there was no error.

@storuky

storuky commented Mar 9, 2024

@BootsofLagrangian Yeah, I tried AdamW as well but no luck so far...

Here is a full trace of the issue with AdamW as the optimizer (spoiler: it happens with any kind of offload_optimizer_device... none, nvme, cpu – it doesn't matter):

[2024-03-09 11:03:49,081] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed info: version=0.13.5, git-hash=unknown, git-branch=unknown
[2024-03-09 11:03:52,910] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: False
Traceback (most recent call last):
  File "/home/storuky/ml/train/sd-scripts/./sdxl_train.py", line 810, in <module>
    train(args)
  File "/home/storuky/ml/train/sd-scripts/./sdxl_train.py", line 415, in train
    ds_model = accelerator.prepare(ds_model)
  File "/home/storuky/ml/train/sd-scripts/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1209, in prepare
    result = self._prepare_deepspeed(*args)
  File "/home/storuky/ml/train/sd-scripts/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1582, in _prepare_deepspeed
    engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
  File "/home/storuky/ml/train/sd-scripts/venv/lib/python3.10/site-packages/deepspeed/__init__.py", line 176, in initialize
    engine = DeepSpeedEngine(args=args,
  File "/home/storuky/ml/train/sd-scripts/venv/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 312, in __init__
    self.optimizer = self._configure_zero_optimizer(optimizer=None)
  File "/home/storuky/ml/train/sd-scripts/venv/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1505, in _configure_zero_optimizer
    assert not isinstance(optimizer, DummyOptim), "zero stage {} requires an optimizer".format(zero_stage)
AssertionError: zero stage 2 requires an optimizer

[2024-03-09 11:04:00,019] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 407381) of binary: /home/storuky/ml/train/sd-scripts/venv/bin/python3
Traceback (most recent call last):
  File "/home/storuky/ml/train/sd-scripts/venv/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/storuky/ml/train/sd-scripts/venv/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/home/storuky/ml/train/sd-scripts/venv/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1008, in launch_command
    multi_gpu_launcher(args)
  File "/home/storuky/ml/train/sd-scripts/venv/lib/python3.10/site-packages/accelerate/commands/launch.py", line 666, in multi_gpu_launcher
    distrib_run.run(args)
  File "/home/storuky/ml/train/sd-scripts/venv/lib/python3.10/site-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/home/storuky/ml/train/sd-scripts/venv/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/storuky/ml/train/sd-scripts/venv/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
./sdxl_train.py FAILED

@storuky

storuky commented Mar 9, 2024

@BootsofLagrangian even if I copy your toml config from here, change only the paths, and run it as you described, I still get this error. I tried reconfiguring accelerate and reinstalling/installing other versions of DeepSpeed – no effect.

@storuky

storuky commented Mar 9, 2024

@BootsofLagrangian Ah, I just switched to your version and it's working! The issue is only with this branch.

@BootsofLagrangian
Contributor

@BootsofLagrangian Ah, I just switched to your version and it's working! The issue just with this branch.

Thanks for your report!

kohya-ss and others added 4 commits March 12, 2024 20:41
 - we have to prepare optimizer and ds_model at the same time.
 - pull/1139#issuecomment-1986790007

Signed-off-by: BootsofLagrangian <hard2251@yonsei.ac.kr>
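
The "zero stage 2 requires an optimizer" assertion earlier in this thread fires because DeepSpeed ends up being initialized without a real optimizer when the wrapped model is prepared on its own. A minimal sketch of what the commit note above describes, assuming the objects already built in train() (the exact sd-scripts wiring may differ):

# Hedged sketch: prepare the DeepSpeed-wrapped model together with the real
# optimizer (and dataloader/scheduler) in one call, so accelerate hands
# DeepSpeed an actual optimizer instead of a dummy one.
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    ds_model, optimizer, train_dataloader, lr_scheduler
)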
@kohya-ss kohya-ss marked this pull request as ready for review March 24, 2024 09:46
@Trojaner

Trojaner commented Mar 24, 2024

Edit: see the comment below for the reason (the missing dtype=weight_dtype); a sketch of that fix follows the config below.
I cannot get SD 1.5 LoRA training to work with bf16:

accelerate launch \
  --mixed_precision=bf16 \
  --num_processes=2 \
  --num_machines=1 \
  --multi_gpu \
  --main_process_ip=localhost \
  --main_process_port=29555 \
  --num_cpu_threads_per_process=2 \
  ./train_network.py \
    --config_file=/home/ml/checkpoints/sd15/config.toml

(...)
Traceback (most recent call last):
  File "/home/ml/sd-scripts/./train_network.py", line 1087, in <module>
    trainer.train(args)
  File "/home/ml/sd-scripts/./train_network.py", line 839, in train
    noise_pred = self.call_unet(
  File "/home/ml/sd-scripts/./train_network.py", line 130, in call_unet
    noise_pred = unet(noisy_latents, timesteps, text_conds).sample
  File "/home/ml/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ml/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ml/sd-scripts/library/original_unet.py", line 1582, in forward
    sample = self.conv_in(sample)
  File "/home/ml/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ml/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ml/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/ml/sd-scripts/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same

config.toml

pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
dataset_config = "/home/ml/checkpoints/sd15/dataset.toml"
xformers = true
deepspeed = true
zero_stage = 2
mixed_precision = "bf16"
save_precision = "bf16"
full_bf16 = true
no_half_vae = true
train_batch_size = 24
max_data_loader_n_workers = 4
persistent_data_loader_workers = true
optimizer_type = "AdamW8bit"
optimizer_args = [ "weight_decay=1e-1", ]
lr_scheduler = "constant"
max_train_steps = 78452
gradient_checkpointing = true
gradient_accumulation_steps = 16
learning_rate = 4e-5
unet_lr = 4e-5
text_encoder_lr = 2e-5
max_grad_norm = 1.0
max_token_length = 225
network_alpha = 64
network_dim = 128
network_module = "networks.lora"
cache_latents = true
cache_latents_to_disk = true

@kohya-ss kohya-ss merged commit ea05e3f into dev Mar 26, 2024
2 checks passed
@kohya-ss kohya-ss deleted the deep-speed branch March 26, 2024 10:34