Update README and clean-up the code for SD3 timesteps
kohya-ss committed Nov 7, 2024
1 parent 588ea9e commit 5e86323
Showing 6 changed files with 30 additions and 19 deletions.
13 changes: 12 additions & 1 deletion README.md
@@ -14,6 +14,13 @@ The command to install PyTorch is as follows:

### Recent Updates

Nov 7, 2024:

- The distribution of timesteps during SD3/3.5 training has been adjusted. This applies to both fine-tuning and LoRA training. PR [#1768](https://github.com/kohya-ss/sd-scripts/pull/1768). Thanks to Dango233!
- Previously, timesteps closer to the noise side were sampled more often; sampling is now uniform by default. This may help when fine details are difficult to learn.
- Specifically, a double-shifting problem has been fixed. The default for `--weighting_scheme` has been changed to `uniform` (the previous default was `logit_normal`).
- A new option `--training_shift` has been added. The default is 1.0, which samples all timesteps uniformly. Values below 1.0 sample the side closer to the image more often, and values above 1.0 sample the side closer to noise more often (see the sketch after this list).
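
For illustration, a minimal sketch of the mapping (the same formula used in `library/sd3_train_utils.py` in this commit; the helper name is made up for this example):

```python
import torch

def apply_training_shift(u: torch.Tensor, shift: float) -> torch.Tensor:
    # u is sampled uniformly in [0, 1); shift == 1.0 leaves the distribution uniform.
    # shift > 1.0 pulls samples toward the noise side, shift < 1.0 toward the image side.
    return (u * shift) / (1 + (shift - 1) * u)

u = torch.rand(8)
print(apply_training_shift(u, 1.0))  # unchanged (uniform)
print(apply_training_shift(u, 3.0))  # skewed toward 1.0 (noisier timesteps)
print(apply_training_shift(u, 0.5))  # skewed toward 0.0 (timesteps closer to the image)
```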

Oct 31, 2024:

- Added support for SD3.5L/M training. See [SD3 training](#sd3-training) for details.
@@ -641,6 +648,7 @@ Here are the arguments. The arguments and sample settings are still experimental
- `--clip_l_dropout_rate`, `--clip_g_dropout_rate` and `--t5_dropout_rate` are the dropout rates for the embeddings of CLIP-L, CLIP-G, and T5XXL, described in the [SAI research paper](http://arxiv.org/pdf/2403.03206). The default is 0.0. For LoRA training, it seems better to set 0.0.
- `--pos_emb_random_crop_rate` is the rate of random cropping of positional embeddings, described in the [SD3.5M model card](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium). The default is 0. It seems better to set 0.0 for LoRA training.
- `--enable_scaled_pos_embed` enables scaled positional embeddings. The default is False. This option is an experimental feature for SD3.5M. Details are described below.
- `--training_shift` is the shift value for the training distribution of timesteps. The default is 1.0 (uniform distribution, no shift). Values below 1.0 sample the side closer to the image more often, and values above 1.0 sample the side closer to noise more often.
Other options are described below.
@@ -681,7 +689,10 @@ Other options are described below.
- Same as FLUX.1 for data preparation.
- If you train with multiple resolutions, you can enable the scaled positional embeddings with `--enable_scaled_pos_embed`. The default is False. __This option is an experimental feature.__
6. Weighting scheme and training shift:
- The weighting scheme is described in Section 3.1 of the [SD3 paper](https://arxiv.org/abs/2403.03206v1).
- The uniform distribution is the default. If you want to change the distribution, see `--help` for options.
- `--training_shift` is the shift value for the training distribution of timesteps (a condensed sketch of the sampling follows this list).
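
As an illustration only, a condensed sketch of how the shifted samples become timesteps and flow-matching sigmas (assuming the default `uniform` scheme; it paraphrases `get_noisy_model_input_and_timesteps` as changed in this commit, not the exact implementation):

```python
import torch

def sample_noisy_input(latents: torch.Tensor, noise: torch.Tensor, shift: float = 1.0,
                       t_min: int = 0, t_max: int = 1000):
    # uniform samples over [0, 1), reshaped by the training shift
    u = torch.rand(latents.shape[0], device=latents.device)
    u = (u * shift) / (1 + (shift - 1) * u)

    # map to discrete timesteps and flow-matching sigmas
    timesteps = (u * (t_max - t_min) + t_min).long()
    sigmas = (timesteps / 1000).view(-1, 1, 1, 1).to(latents.dtype)

    # linear interpolation between clean latents and pure noise
    noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
    return noisy_model_input, timesteps, sigmas
```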
Technical details of multi-resolution training for SD3.5M:
2 changes: 1 addition & 1 deletion library/config_util.py
@@ -526,7 +526,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
secondary_separator: {subset.secondary_separator}
enable_wildcard: {subset.enable_wildcard}
caption_dropout_rate: {subset.caption_dropout_rate}
- caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
+ caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}
2 changes: 1 addition & 1 deletion library/sd3_models.py
@@ -871,7 +871,7 @@ def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Opti
# remove pos_embed to free up memory up to 0.4 GB
self.pos_embed = None

- # remove duplcates and sort latent sizes in ascending order
+ # remove duplicates and sort latent sizes in ascending order
latent_sizes = list(set(latent_sizes))
latent_sizes = sorted(latent_sizes)

17 changes: 9 additions & 8 deletions library/sd3_train_utils.py
@@ -253,7 +253,7 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
" / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります",
)

- # Dependencies of Diffusers noise sampler has been removed for clearity.
+ # Dependencies of Diffusers noise sampler has been removed for clarity.
parser.add_argument(
"--weighting_scheme",
type=str,
@@ -285,7 +285,8 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
default=1.0,
help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。",
)



def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
if args.v_parameterization:
@@ -956,9 +957,10 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
return weighting


- def get_noisy_model_input_and_timesteps(
-     args, noise_scheduler, latents, noise, device, dtype
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ # endregion
+
+
+ def get_noisy_model_input_and_timesteps(args, latents, noise, device, dtype) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz = latents.shape[0]

# Sample a random timestep for each image
@@ -977,13 +979,12 @@ get_noisy_model_input_and_timesteps(
# weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details)
u = (u * shift) / (1 + (shift - 1) * u)

- indices = (u * (t_max-t_min) + t_min).long()
+ indices = (u * (t_max - t_min) + t_min).long()
timesteps = indices.to(device=device, dtype=dtype)

# sigmas according to flowmatching
sigmas = timesteps / 1000
- sigmas = sigmas.view(-1,1,1,1)
+ sigmas = sigmas.view(-1, 1, 1, 1)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents

return noisy_model_input, timesteps, sigmas

8 changes: 4 additions & 4 deletions sd3_train.py
@@ -811,8 +811,8 @@ def optimizer_hook(parameter: torch.Tensor):
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0

- noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
- noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ # noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
+ # noise_scheduler_copy = copy.deepcopy(noise_scheduler)

if accelerator.is_main_process:
init_kwargs = {}
@@ -940,11 +940,11 @@ def optimizer_hook(parameter: torch.Tensor):

# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
- bsz = latents.shape[0]
+ # bsz = latents.shape[0]

# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps(
-     args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
+     args, latents, noise, accelerator.device, weight_dtype
)

# debug: NaN check for all inputs
7 changes: 3 additions & 4 deletions sd3_train_network.py
@@ -275,9 +275,8 @@ def sample_images(self, accelerator, args, epoch, global_step, device, vae, toke
)

def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
- # shift 3.0 is the default value
- noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
- self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ # this scheduler is not used in training, but used to get num_train_timesteps etc.
+ noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift)
return noise_scheduler

def encode_images_to_latents(self, args, accelerator, vae, images):
@@ -304,7 +303,7 @@ def get_noise_pred_and_target(

# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps(
-     args, self.noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
+     args, latents, noise, accelerator.device, weight_dtype
)

# ensure the hidden state will require grad
