adding SD3 timestep-dependent shift for Flux training #894

Merged 5 commits on Aug 27, 2024
2 changes: 1 addition & 1 deletion documentation/DATALOADER.md
@@ -203,7 +203,7 @@ Images are not resized before cropping **unless** `maximum_image_size` and `targ
- This is equivalent to the commandline option `--skip_file_discovery`
- This is helpful if you have datasets you don't need the trainer to scan on every startup, e.g. their latents/embeds are already fully cached. This allows quicker startup and resumption of training.

- ### `preserve_data_cache_backend`
+ ### `preserve_data_backend_cache`

- You probably never want to set this; it is useful only for very large AWS datasets.
- Like `skip_file_discovery`, this option can be set to prevent unnecessary, lengthy and costly filesystem scans at startup.
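For readers setting this in the dataloader configuration rather than on the commandline, here is a hypothetical dataset entry showing where the renamed key would live, written as a Python dict for brevity. Every key and value other than `skip_file_discovery` and `preserve_data_backend_cache` themselves is illustrative, not taken from this PR:

```python
# Hypothetical dataloader entry; only the two option names come from
# DATALOADER.md above, the other keys and the values are illustrative.
dataset_entry = {
    "id": "aws-bulk-data",
    "skip_file_discovery": "vae,text",    # stages whose startup scan to skip (value illustrative)
    "preserve_data_backend_cache": True,  # keep the cached file list; avoids costly AWS rescans
}
```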
13 changes: 13 additions & 0 deletions helpers/arguments.py
@@ -140,6 +140,19 @@ def parse_args(input_args=None):
" which has improved results in short experiments. Thanks to @mhirki for the contribution."
),
)
parser.add_argument(
"--flux_schedule_shift",
type=float,
default=None,
help=(
"Shift the noise schedule. This is a value between 0 and ~4.0, where 0 disables the timestep-dependent shift,"
" and anything greater than 0 will shift the timestep sampling accordingly. The SD3 model was trained with"
" a shift value of 3. The value for Flux is unknown. Higher values result in less noisy timesteps sampled,"
" which results in a lower mean loss value, but not necessarily better results. Early reports indicate"
" that modification of this value can change how the contrast is learnt by the model, and whether fine"
" details are ignored or accentuated."
),
)
parser.add_argument(
"--flux_guidance_mode",
type=str,
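For intuition about what this flag controls, here is a minimal standalone sketch of the shift curve the PR applies, t' = shift * t / (1 + (shift - 1) * t), evaluated at a few normalized timesteps. The helper name is hypothetical and not part of this diff:

```python
# Hypothetical helper illustrating the SD3 timestep-dependent shift that
# --flux_schedule_shift controls; not part of this PR's diff.
def apply_schedule_shift(t: float, shift: float) -> float:
    """Remap a normalized timestep t in [0, 1]; shift == 1.0 is the identity."""
    return (t * shift) / (1 + (shift - 1) * t)

for shift in (1.0, 3.0):
    print(shift, [round(apply_schedule_shift(t, shift), 3) for t in (0.25, 0.5, 0.75)])
# 1.0 [0.25, 0.5, 0.75]   <- identity
# 3.0 [0.5, 0.75, 0.9]    <- every timestep is pushed toward 1.0
```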
1 change: 1 addition & 0 deletions helpers/training/custom_schedule.py
@@ -369,6 +369,7 @@ def __init__(
        self.T_mult = T_mult
        self.eta_min = eta_min
        self.T_cur = last_step
        self.last_step = last_step
        super().__init__(optimizer, last_step, verbose)

    def get_lr(self):
7 changes: 7 additions & 0 deletions train.py
@@ -1670,6 +1670,13 @@ def main():
                device=accelerator.device,
            )
            timesteps = sigmas * 1000.0
            if (
                args.flux_schedule_shift is not None
                and args.flux_schedule_shift > 0
            ):
                timesteps = (timesteps * args.flux_schedule_shift) / (
                    1 + (args.flux_schedule_shift - 1) * timesteps
                )
            sigmas = sigmas.view(-1, 1, 1, 1)
        else:
            # Sample a random timestep for each image, potentially biased by the timestep weights.
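A quick tensorized sketch of the same remapping, written against normalized sigmas in [0, 1] as in the SD3 formulation; names and shapes are illustrative, not code from this PR:

```python
import torch

def shift_schedule(values: torch.Tensor, shift: float) -> torch.Tensor:
    # SD3 timestep-dependent shift: v' = shift * v / (1 + (shift - 1) * v).
    # shift == 1.0 is the identity; larger values push each entry toward
    # the upper end of its range.
    return (values * shift) / (1 + (shift - 1) * values)

sigmas = torch.rand(4)                 # stand-in for the trainer's sampled sigmas
shifted = shift_schedule(sigmas, 3.0)  # SD3 was reportedly trained with shift = 3
timesteps = shifted * 1000.0           # scale to the 0-1000 conditioning range
print(sigmas, shifted)                 # shifted >= sigmas everywhere when shift > 1
```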