Faster training with token downsampling #1151
base: dev
Conversation
Very impressive improvement in speed while maintaining training quality. Additionally, when I use --todo_args "downsample_method=nearest-exact", a parameter parsing error occurs, but it does not affect anything since it is the default value.
@gesen2egee I see the problem and pushed a fix. To be honest, I think the arguments need an overhaul. I'm not sure the average user needs this much granularity when setting the downsampling factor.
Example:
Hi, very interesting. Could you add this to the flux LoRA training as well, or is it straightforward?
This PR adds support for training with token downsampling and replaces my token merge PR (#1146).
Token downsampling is a lossy optimization that significantly speeds up inference and training. It tries to avoid the quality loss of token merging by only downsampling K and V in the attention operation (Q is preserved) and replaces the expensive token similarity calculation with simple downsampling.
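To make the idea concrete, here is a minimal sketch of the K/V downsampling (my own illustration, not the code in this PR; it ignores multi-head reshaping and assumes the token sequence maps back to an h×w latent grid):

```python
import torch
import torch.nn.functional as F

def downsample_tokens(x: torch.Tensor, factor: float, h: int, w: int) -> torch.Tensor:
    # x: (batch, tokens, channels) with tokens == h * w; fold back to the latent grid,
    # downsample spatially, then flatten to a shorter token sequence
    b, n, c = x.shape
    x = x.permute(0, 2, 1).reshape(b, c, h, w)
    x = F.interpolate(x, scale_factor=1.0 / factor, mode="nearest-exact")
    return x.flatten(2).permute(0, 2, 1)

def todo_self_attention(to_q, to_k, to_v, hidden_states, h, w, factor=2.0):
    q = to_q(hidden_states)                               # Q keeps every token
    kv = downsample_tokens(hidden_states, factor, h, w)   # K/V are computed from fewer tokens
    k, v = to_k(kv), to_v(kv)
    # the attention output length follows Q, so the layer's output shape is unchanged
    return F.scaled_dot_product_attention(q, k, v)
```

Because only K and V shrink, the attention cost drops roughly with the square of the factor while the output keeps its full spatial resolution.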
Applying the optimization during training seems to cause less quality loss than applying it at inference, so I was able to increase the amount of downsampling considerably without negative effects.
(Note: I'm using an AMD GPU. Based on some user reports, the speedup is much less dramatic on Nvidia with xformers. ToDo might be simply closing the gap in my case and you might not achieve the same speedup reported below. I'd welcome benchmarks from Nvidia users)
With a downsampling factor of 2 and resolution of 768px I get a 2x speedup for SD1.x LoRA training.
Downsampling factor 4 with max_downsample=2 gave me an even bigger 3.2x speedup with basically no quality loss.
SDXL benefits less because its architecture is already more efficient, but I still saw about 1.3x speedup at 1024px with downsample factor 2 and 1.7x with factor 4.
The potential speedup is larger at higher resolutions.
This PR adds two new flags:

- `--todo_factor`: (float) token downsampling factor > 1. The inputs of the unet's self-attention layers are scaled down by this factor. Recommend 2-4. Multiple values can be specified to override the factor for different depths.
- `--todo_max_depth`: ([1, 2, 3, 4]) maximum depth to apply ToDo. Max for SDXL is 2. Recommend 1 or 2. Default is autodetected based on the number of values passed to `--todo_factor`.

Sample usage: `--todo_factor 4 --todo_max_depth 2` is the same as `--todo_factor 4 4`
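As a rough illustration of that equivalence (the helper below is hypothetical, not the PR's actual parsing code), a single factor plus an explicit max depth expands to one factor per depth:

```python
# Hypothetical helper showing how a factor list could expand to per-depth factors.
def expand_todo_factors(todo_factor, todo_max_depth=None):
    if todo_max_depth is None:
        todo_max_depth = len(todo_factor)        # depth autodetected from the list length
    if len(todo_factor) == 1:
        todo_factor = todo_factor * todo_max_depth
    return {depth: f for depth, f in enumerate(todo_factor, start=1)}

print(expand_todo_factors([4.0], todo_max_depth=2))   # {1: 4.0, 2: 4.0}
print(expand_todo_factors([4.0, 4.0]))                # {1: 4.0, 2: 4.0} -- same result
```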
The unet is patched when the model is loaded, so the optimization should automatically work with all training scripts, but I only tested train_network.py and sdxl_train_network.py.
The downsampling operation is implemented with PyTorch hooks, so model saving should be unaffected.
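A minimal sketch of such a hook-based patch, assuming forward pre-hooks on the K/V projections of each self-attention block (module names like `attn1.to_k`/`attn1.to_v` follow the usual SD unet layout but are my assumption, not the PR's exact implementation; per-depth factors and `--todo_max_depth` handling are omitted for brevity):

```python
import torch.nn as nn
import torch.nn.functional as F

def make_kv_downsample_hook(factor: float):
    def hook(module: nn.Linear, args):
        (x,) = args                                   # (batch, tokens, channels)
        b, n, c = x.shape
        h = w = int(n ** 0.5)                         # simplification: assumes a square latent grid
        x = x.permute(0, 2, 1).reshape(b, c, h, w)
        x = F.interpolate(x, scale_factor=1.0 / factor, mode="nearest-exact")
        return (x.flatten(2).permute(0, 2, 1),)       # fewer tokens feed the K/V projection
    return hook

def patch_unet_with_todo(unet: nn.Module, factor: float):
    handles = []
    for name, module in unet.named_modules():
        # only self-attention ("attn1") K and V projections are downsampled; Q is untouched
        if ("attn1.to_k" in name or "attn1.to_v" in name) and isinstance(module, nn.Linear):
            handles.append(module.register_forward_pre_hook(make_kv_downsample_hook(factor)))
    return handles  # hooks add no parameters, so the state_dict and saved checkpoints are unchanged
```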
Example:
Training details:
`--mem_eff_attn`