Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster training with token downsampling #1151

Open
wants to merge 10 commits into
base: dev
Choose a base branch
from

Conversation

feffy380
Copy link
Contributor

@feffy380 feffy380 commented Mar 3, 2024

This PR adds support for training with token downsampling and replaces my token merge PR (#1146).

Token downsampling is a lossy optimization that significantly speeds up inference and training. It tries to avoid the quality loss of token merging by only downsampling K and V in the attention operation (Q is preserved) and replaces the expensive token similarity calculation with simple downsampling.
Applying the optimization during training seems to have less quality loss compared to inference, so I was able to increase the amount of downsampling a lot without negative effects.

(Note: I'm using an AMD GPU. Based on some user reports, the speedup is much less dramatic on Nvidia with xformers. ToDo might be simply closing the gap in my case and you might not achieve the same speedup reported below. I'd welcome benchmarks from Nvidia users)
With a downsampling factor of 2 and resolution of 768px I get a 2x speedup for SD1.x LoRA training.
Downsampling factor 4 with max_downsample=2 gave me an even bigger 3.2x speedup with basically no quality loss.
SDXL benefits less because its architecture is already more efficient, but I still saw about 1.3x speedup at 1024px with downsample factor 2 and 1.7x with factor 4.
The potential speedup is larger at higher resolutions.

This PR adds two new flags:

  • --todo_factor: (float) token downsampling factor > 1. The inputs of the unet's self-attention layers are scaled down by this factor. Recommend 2-4. Multiple values can be specified to override the factor for different depths.
  • --todo_max_depth: ([1, 2, 3, 4]) maximum depth to apply ToDo. Max for SDXL is 2. Recommend 1 or 2. Default is autodetected based on the number of values passed to todo_factor.

Sample usage:
--todo_factor 4 --todo_max_depth 2
is the same as
--todo_factor 4 4

The unet is patched when the model is loaded, so the optimization should automatically work with all training scripts, but I only tested train_network.py and sdxl_train_network.py.
The downsampling operation is implemented with pytorch hooks, so model saving should be unaffected.

Example:

Name Downsample factor s/it Speedup
feffy-v3.50 None 2.0 1x
feffy-todo2 2 1.0 2x
feffy-todo4_2 4 0.63 3.2x

image

Training details:

  • 7900 XTX with --mem_eff_attn
  • 768px resolution with bucketing
  • Batch size 4

@gesen2egee
Copy link
Contributor

Very impressive improvement in speed as well as maintaining training quality.

Additionally, when I use --todo_args "downsample_method=nearest-exact", a parameter parsing error occurs, but it does not affect anything since it is the default value

@feffy380
Copy link
Contributor Author

feffy380 commented Mar 4, 2024

@gesen2egee I see the problem and pushed a fix.

To be honest, --todo_args doesn't seem necessary. It's a leftover from the research code which had options for both ToDo and ToMe. Here it's only needed for downsample_method, and the default method already works so well that there isn't much reason to use a more expensive interpolation mode.

If nobody objects, I'm going to remove --todo_args

Actually, I think the arguments need an overhaul. I'm not sure the average user needs this much granularity when setting the downsampling factor.
The way ToMe does it is they apply the same merge ratio to everything and use max_downsample to control which layers are affected. I'll train another lora with this approach and if it works, we can simplify the arguments to a single --todo_factor and --todo_max_downsample=[1,2,4,8]

@feffy380 feffy380 marked this pull request as draft March 4, 2024 16:13
@feffy380
Copy link
Contributor Author

feffy380 commented Mar 4, 2024

Ok I think that settles it. I'll go for simplicity

feffy-todo4_4-comparison

@feffy380 feffy380 marked this pull request as ready for review March 4, 2024 18:21
@feffy380
Copy link
Contributor Author

--todo_factor now accepts multiple values if you want to override the downsampling factor for individual depths.

Example: --todo_factor 4 2 will use factor=4 for depth 1 and factor=2 for depth 2.

@kunibald413
Copy link

Hi, Very interesting, could you add that to the flux lora training aswell? or is it straight forward?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants