
reintroduce timestep dependent shift as an option during flux training for dev and schnell, disabled by default #892

Merged
merged 1 commit into main on Aug 27, 2024

Conversation

bghira
Owner

@bghira bghira commented Aug 27, 2024

No description provided.

@bghira bghira merged commit de4cbf5 into main Aug 27, 2024
1 check passed
@mhirki
Contributor

mhirki commented Aug 27, 2024

I think the shift should be applied to the sigmas and not the timesteps derived from the sigmas. See sd3-ref implementation here:
https://github.com/Stability-AI/sd3-ref/blob/883b836841679d8791a5e346c861dd914fbb618d/sd3_impls.py#L33
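Applying the shift in sigma space (rather than to the timesteps derived from the sigmas) can be sketched as follows; this mirrors the form used in the sd3-ref `ModelSamplingDiscreteFlow` linked above, written here with plain floats for illustration:

```python
def shift_sigma(sigma: float, shift: float = 3.0) -> float:
    """Shift a flow-matching sigma toward the noisier end of the schedule.

    sigma is expected in (0, 1]; shift > 1 biases sampling toward high noise.
    Follows the sigma-space form from sd3-ref (illustrative sketch).
    """
    return shift * sigma / (1 + (shift - 1) * sigma)


# Endpoints are preserved: sigma = 0 and sigma = 1 map to themselves,
# while intermediate sigmas are pushed upward when shift > 1.
print(shift_sigma(0.5, shift=3.0))
```

Note the endpoints stay fixed, so only the interior of the schedule is redistributed; a shift applied to the derived timesteps instead would not generally have this property.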

@bghira
Owner Author

bghira commented Aug 27, 2024

i actually agree with you, and that's how i initially did it here. but i went with kohya's idea so that people get more consistent experiences. if you'd like, i suggest opening a PR with a modification that allows either way of doing it (defaulting to sigmas)

@mhirki
Contributor

mhirki commented Aug 27, 2024

In the official Flux inference implementation, the shift depends on the image resolution:
https://github.com/black-forest-labs/flux/blob/b4f689aaccd40de93429865793e84a734f4a6254/src/flux/sampling.py#L66

This is yet another alternative implementation, but of course we don't know whether they used this for training. Kohya's suggested value of e^1.15 matches an image sequence length of 4096 (i.e. a 1024x1024 image).

I'll do some experiments when I have time but that may have to wait until the weekend.
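The schedule in the linked sampling.py boils down to a linear map from image_seq_len to a parameter mu, followed by an exponential shift of each timestep. A minimal sketch (constants and function names follow the linked file; treat the exact signatures as an approximation of that code, not a verbatim copy):

```python
import math


def get_lin_function(x1: float = 256, y1: float = 0.5,
                     x2: float = 4096, y2: float = 1.15):
    # Linear interpolation: image_seq_len -> mu.
    # Anchors: 256 tokens -> mu = 0.5, 4096 tokens -> mu = 1.15.
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def time_shift(mu: float, sigma: float, t: float) -> float:
    # Shift timestep t in (0, 1] by a factor of exp(mu);
    # sigma = 1.0 reduces this to exp(mu) * t / (1 + (exp(mu) - 1) * t).
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


# For a 1024x1024 image (image_seq_len = 4096), mu = 1.15,
# so the effective shift factor is exp(1.15) ~ 3.158.
mu = get_lin_function()(4096)
print(math.exp(mu), time_shift(mu, 1.0, 0.5))
```

With sigma = 1.0 this is algebraically the same shape as the sigma-space shift from sd3-ref, with shift = exp(mu) chosen per resolution instead of fixed.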

@bghira
Owner Author

bghira commented Aug 27, 2024

ha! a month ago i noted that in the SD3 paper... timestep AND resolution-dependent shift. but we just do timestep. so that is the missing piece of the puzzle.. ok i'm not supposed to be working outside of a specific set of tasks this week... but.. this could arguably be viewed as a correctness issue. i will sigh deeply and take a look here in a moment

@mhirki
Contributor

mhirki commented Aug 28, 2024

I worked out how the shift is calculated from resolution. demo_st.py shows quite clearly how image_seq_len is calculated:

        # divide pixel space by 16**2 to account for latent space conversion
        timesteps = get_schedule(
            opts.num_steps,
            (x.shape[-1] * x.shape[-2]) // 4,
            shift=(not is_schnell),
        )

https://github.com/black-forest-labs/flux/blob/b4f689aaccd40de93429865793e84a734f4a6254/demo_st.py#L196

If pixel space is 256 * 256, then we have:
image_seq_len = 256 * 256 / 16**2 = 256
This gets mapped to a shift value of math.exp(0.5) = 1.6487.

So it's pixel space divided by 16**2, which means if pixel space is 1024 * 1024, then we have:
image_seq_len = 1024 * 1024 / 16**2 = 4096
And this gets mapped to a shift value of math.exp(1.15) = 3.1581.

The shift value does seem to go a bit wild for bigger resolutions:

>>> math.exp(get_lin_function()(1536 * 1536 / (16 ** 2)))
7.513239016971895
>>> math.exp(get_lin_function()(2048 * 2048 / (16 ** 2)))
25.279656970962883

I'm not sure if these values actually make sense. So maybe some experimentation is needed.
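The values above can be reproduced end to end with a small helper (`mu_from_resolution` is a hypothetical name introduced here for illustration; the constants are the ones quoted in this thread):

```python
import math


def mu_from_resolution(width: int, height: int,
                       x1: float = 256, y1: float = 0.5,
                       x2: float = 4096, y2: float = 1.15) -> float:
    """Map an image resolution to mu via image_seq_len.

    image_seq_len = pixels / 16**2: the VAE downsamples 8x per side and the
    latents are then packed 2x2, so each token covers a 16x16 pixel patch.
    """
    image_seq_len = (width * height) // (16 ** 2)
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return m * image_seq_len + b


# Reproduces the shift values discussed above:
# 256px -> ~1.6487, 1024px -> ~3.1582, then growing quickly beyond that.
for side in (256, 1024, 1536, 2048):
    print(side, math.exp(mu_from_resolution(side, side)))
```

Since mu is linear in token count but the shift is exp(mu), the shift grows exponentially with pixel area, which is why the 1536px and 2048px values look so large.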

The shift calculation is also implemented in Diffusers as calculate_shift:
https://github.com/huggingface/diffusers/blob/b5f591fea843cb4bf1932bd94d1db5d5eebe3298/src/diffusers/pipelines/flux/pipeline_flux.py#L67
The default value of max_shift there is 1.16, likely due to a typo, but it should use the correct value of 1.15 from the scheduler config.

@bghira bghira deleted the feature/flux-timestep-shift-reintroduction branch September 1, 2024 02:05