documentation updates, apple pytorch 2.4 #595

Merged: 5 commits, Jul 25, 2024

77 changes: 38 additions & 39 deletions README.md
@@ -4,23 +4,20 @@

**SimpleTuner** is a repository dedicated to a set of experimental scripts designed for training optimization. The project is geared towards simplicity, with a focus on making the code easy to read and understand. This codebase serves as a shared academic exercise, and contributions are welcome.

- Multi-GPU training
- Aspect bucketing "just works"; fill a folder of images and let it rip
- Multiple datasets can be used in a single training session, each with a different base resolution.
- VRAM-saving techniques, such as pre-computing VAE and text encoder outputs
- Full featured fine-tuning support
- Bias training (BitFit)
- LoRA training support

## Table of Contents

- [Design Philosophy](#design-philosophy)
- [Tutorial](#tutorial)
- [Features](#features)
- [Hardware Requirements](#hardware-requirements)
- [SDXL](#sdxl)
- [Stable Diffusion 2.0/2.1](#stable-diffusion-2x)
- [PixArt Sigma](#pixart-sigma)
- [Stable Diffusion 2.0/2.1](#stable-diffusion-20--21)
- [Stable Diffusion 3.0](#stable-diffusion-3)
- [AuraFlow](#auraflow)
- [Kwai Kolors](#kwai-kolors)
- [Hardware Requirements](#hardware-requirements)
- [SDXL](#sdxl-1024px)
- [Stable Diffusion (Legacy)](#stable-diffusion-2x-768px)
- [AuraFlow v0.1](#auraflow-v01)
- [Scripts](#scripts)
- [Toolkit](#toolkit)
- [Setup](#setup)
@@ -44,22 +41,33 @@ For memory-constrained systems, see the [DeepSpeed document](/documentation/DEEPSPEED.md).

## Features

- Precomputed VAE (latents) outputs saved to storage, eliminating the need to invoke the VAE during training.
- Precomputed captions are run through the text encoder(s) and saved to storage to save on VRAM.
- Trainable on a 24G GPU, or even down to 16G at lower base resolutions.
- LoRA training for SDXL, SD3, and SD 2.x that uses less than 16G VRAM.
- Multi-GPU training
- Image and caption features (embeds) are cached to the hard drive in advance, so that training runs faster and consumes less memory (a minimal sketch of this caching follows this list)
- Aspect bucketing: support for a variety of image sizes and aspect ratios, enabling widescreen and portrait training.
- Refiner LoRA or full u-net training for SDXL
- Most models are trainable on a 24G GPU, or even down to 16G at lower base resolutions.
- LoRA training for PixArt, SDXL, SD3, and SD 2.x that uses less than 16G VRAM; AuraFlow uses less than 24G VRAM
- DeepSpeed integration allowing for [training SDXL's full u-net on 12G of VRAM](/documentation/DEEPSPEED.md), albeit very slowly.
- Optional EMA (Exponential moving average) weight network to counteract model overfitting and improve training stability. **Note:** This does not apply to LoRA.
- Support for a variety of image sizes and aspect ratios, enabling widescreen and portrait training.
- Train directly from an S3-compatible storage provider, eliminating the requirement for expensive local storage. (Tested with Cloudflare R2 and Wasabi S3)
- [DeepFloyd stage I and II full u-net or parameter-efficient fine-tuning](/documentation/DEEPFLOYD.md) via LoRA using 22G VRAM
- SDXL Refiner LoRA or full u-net training, including validation using img2img
- Full [ControlNet model training](/documentation/CONTROLNET.md) (not ControlLoRA or ControlLite)
- Full [ControlNet model training](/documentation/CONTROLNET.md) for SDXL and SD 1.x/2.x only (not ControlLoRA or ControlLite)
- Training [Mixture of Experts](/documentation/MIXTURE_OF_EXPERTS.md) for lightweight, high-quality diffusion models
- Webhook support for updating e.g. Discord channels with your training progress, validations, and errors
- Integration with the [Hugging Face Hub](https://huggingface.co) for seamless model upload and nice automatically-generated model cards.
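
To make the caching idea concrete, here is a minimal, hypothetical sketch of precomputing VAE latents and text-encoder outputs to disk with `diffusers` and `transformers`. The model ids, file naming, and helper function are illustrative assumptions; SimpleTuner's own cache layout differs.

```python
import torch
from diffusers import AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumed models for illustration; SimpleTuner selects these per-architecture.
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)

@torch.no_grad()
def cache_sample(pixel_values: torch.Tensor, caption: str, stem: str):
    # pixel_values: a normalized [-1, 1] image batch of shape (B, 3, H, W).
    # Encode the image once and store scaled latents, so the VAE never
    # has to run during the training loop.
    latents = vae.encode(pixel_values.to(device)).latent_dist.sample()
    torch.save((latents * vae.config.scaling_factor).cpu(), f"{stem}.latent.pt")

    # Encode the caption once and store the hidden states, keeping the
    # text encoder out of VRAM while training.
    tokens = tokenizer(caption, padding="max_length", truncation=True, return_tensors="pt")
    embeds = text_encoder(tokens.input_ids.to(device)).last_hidden_state
    torch.save(embeds.cpu(), f"{stem}.embed.pt")
```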

### Stable Diffusion 2.0/2.1

### PixArt Sigma

SimpleTuner has extensive training integration with PixArt Sigma - both the 600M & 900M models load without any fuss.

- Text encoder training is not supported, as T5 is enormous.
- LoRA and full tuning both work as expected
- ControlNet training is not yet supported
- [Two-stage PixArt](https://huggingface.co/ptx0/pixart-900m-1024-ft-v0.7-stage1) training support (see: [MIXTURE_OF_EXPERTS](/documentation/MIXTURE_OF_EXPERTS.md))

See the [PixArt Quickstart](/documentation/quickstart/SIGMA.md) guide to start training.
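
As a rough illustration of the LoRA path, the sketch below attaches an adapter to the Sigma transformer with `peft`. The rank, target modules, and checkpoint id are illustrative assumptions rather than SimpleTuner's exact wiring:

```python
import torch
from diffusers import PixArtTransformer2DModel
from peft import LoraConfig

transformer = PixArtTransformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
transformer.requires_grad_(False)  # full tuning would skip this freeze

# Illustrative LoRA hyperparameters targeting the attention projections.
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    init_lora_weights="gaussian",
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)
transformer.add_adapter(lora_config)
```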

### Stable Diffusion 2.0 & 2.1

Stable Diffusion 2.1 is known for difficulty during fine-tuning, but this doesn't have to be the case. Related features in SimpleTuner include:

@@ -70,17 +78,12 @@ Stable Diffusion 2.1 is known for difficulty during fine-tuning, but this doesn't have to be the case.

### Stable Diffusion 3

This model is very new and the current level of support for it in SimpleTuner is preliminary:

- LoRA and full finetuning are supported as usual.
- ControlNet is not yet implemented.
- Certain features such as segmented timestep selection and Compel long prompt weighting are not yet supported.
- Parameters have been optimised to get the best results, validated through from-scratch training of SD3 models

A few sharp edges could catch you off-guard, but for the most part, this initial pass at SD3 support is considered robust enough not to let you screw up too many parameters; it will often simply override bad values and replace them with more sensible ones.

Simply point your base model to a Stable Diffusion 3 checkpoint and set `STABLE_DIFFUSION_3=true` in your environment file.

> ⚠️ In the current source release of Diffusers, gradient checkpointing is broken for Stable Diffusion 3 models. This will result in much, much higher memory use.
See the [Stable Diffusion 3 Quickstart](/documentation/quickstart/SD3.md) to get going.
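
The checkpointing caveat above is easy to sanity-check directly. This hypothetical snippet (the model id assumes the gated `stabilityai` release) loads just the SD3 transformer and enables checkpointing, so you can watch whether memory use actually drops during a test step:

```python
import torch
from diffusers import SD3Transformer2DModel

transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
# If the Diffusers bug noted above is present, enabling this will not
# reduce activation memory as expected.
transformer.enable_gradient_checkpointing()
```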

### AuraFlow

@@ -90,18 +93,6 @@ Currently, AuraFlow v0.1 has limited support for SimpleTuner:
- All limitations that apply to Stable Diffusion 3 also apply to AuraFlow
- LoRA is currently the only viable method of AuraFlow training

This model is very large, and will require more resources to train than PixArt or SDXL.

AuraFlow has some distinct advantages that make it worth investigating over Stable Diffusion 3:

- It is the largest open text-to-image model with a truly open license
- It uses the SDXL 4ch VAE which arguably provides an easier learning objective over the 16ch VAE from Stable Diffusion 3
- Though small newspaper or book print text suffers at 4ch compression levels, the overall level of fine detail makes this approach viable (see the shape check after this list).
- It uses just a single text encoder versus Stable Diffusion's three text encoders
- AuraFlow leverages EleutherAI's **Pile-T5** which was trained on **twice as much data** with **fewer parameters** than Stable Diffusion 3, DeepFloyd, and PixArt's **T5-XXL v1.1**
- Pile-T5 has gone through less content prefiltering than OpenCLIP or T5 v1.1, and has "consumed more of the Internet" than T5 v1.1
- With a large data corpus, it has the potential for subtle semantic understanding of linguistic oddities and of more modern concepts, without finetuning the text encoder
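
The 4ch-versus-16ch tradeoff is easy to see concretely. A small sketch, assuming the SDXL VAE from `diffusers`, shows the compressed representation a trainer actually regresses against:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
image = torch.randn(1, 3, 1024, 1024)  # stand-in for a normalized RGB image

with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample()

# SDXL/AuraFlow latents: 4 channels at 1/8 resolution -> (1, 4, 128, 128).
# SD3's 16-channel VAE produces (1, 16, 128, 128) for the same input,
# a richer but harder regression target.
print(latents.shape)
```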

### Kwai Kolors

An SDXL-based model with ChatGLM (General Language Model) 6B as its text encoder, **doubling** the hidden dimension size and substantially increasing the level of local detail included in the prompt embeds.
@@ -137,6 +128,14 @@ Without EMA, more care must be taken not to drastically change the model leading
- NVIDIA RTX 4090 or better (24G, no EMA)
- NVIDIA RTX 4080 or better (LoRA only)
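
Since EMA appears both in the features list and in the note above, here is a minimal sketch of the update rule itself, assuming a plain PyTorch module pair and an illustrative decay value:

```python
import torch

@torch.no_grad()
def ema_update(ema_model: torch.nn.Module, model: torch.nn.Module, decay: float = 0.999):
    # Called once per optimizer step: the shadow weights move a small step
    # toward the live weights, i.e. ema = decay * ema + (1 - decay) * live.
    for ema_p, live_p in zip(ema_model.parameters(), model.parameters()):
        ema_p.lerp_(live_p.detach(), 1.0 - decay)
```

At validation or export time, the slowly-moving EMA weights are used in place of the raw weights, which is how they counteract overfitting and stabilize training.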

### AuraFlow v0.1

This model is very large; it will require more resources to train than any other, incurring a substantial hardware cost.

- Full tuning will OOM at a batch size of 1 on a single 80G GPU. A system with 8x A100-80G (SXM4) is a recommended minimum for FSDP (DeepSpeed ZeRO Stage 2) training; a configuration sketch follows this list.
- LoRA training will OOM at a batch size of 1 on a single 16G GPU. A single 24G GPU is required, with a 48G GPU being ideal.
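
For the multi-GPU recommendation above, one way to request ZeRO Stage 2 programmatically is through Accelerate's DeepSpeed plugin. This is a hedged sketch of such a configuration (values are illustrative, and it assumes `deepspeed` is installed and the script is started via `accelerate launch`), not SimpleTuner's own launcher:

```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

# ZeRO Stage 2 shards optimizer state and gradients across ranks,
# which is what makes full tuning of a model this large feasible.
deepspeed_plugin = DeepSpeedPlugin(
    zero_stage=2,
    gradient_accumulation_steps=1,
    gradient_clipping=1.0,
)
accelerator = Accelerator(
    mixed_precision="bf16",
    deepspeed_plugin=deepspeed_plugin,
)
# model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```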


## Scripts

- `ubuntu.sh` - This is a basic "installer" that makes it quick to deploy on a Vast.ai instance. It might not work for every single container image.
24 changes: 24 additions & 0 deletions helpers/arguments.py
@@ -1335,6 +1335,17 @@ def parse_args(input_args=None):
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--gradient_precision",
type=str,
choices=["unmodified", "fp32"],
default=None,
help=(
"One of the hallmark discoveries of the Llama 3.1 paper is numeric instability when calculating"
" gradients in bf16 precision. The default behaviour when gradient accumulation steps are enabled"
" is now to use fp32 gradients, which is slower, but provides more accurate updates."
),
)
parser.add_argument(
"--local_rank",
type=int,
@@ -1859,4 +1870,17 @@ def parse_args(input_args=None):
f"{'PixArt Sigma' if args.pixart_sigma else 'Stable Diffusion 3'} requires --max_grad_norm=0.01 to prevent model collapse. Overriding value. Set this value manually to disable this warning."
)
args.max_grad_norm = 0.01

    if args.gradient_accumulation_steps > 1:
        if args.gradient_precision == "unmodified" or args.gradient_precision is None:
            warning_log(
                "Gradient accumulation steps are enabled, but gradient precision is set to 'unmodified'."
                " This may lead to numeric instability. Consider setting --gradient_precision=fp32."
            )
        elif args.gradient_precision == "fp32":
            info_log(
                "Gradient accumulation steps are enabled, and gradient precision is set to 'fp32'."
            )
        args.gradient_precision = "fp32"

    return args
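
For context on what the fp32 path above buys you, this standalone sketch (with hypothetical `model`, `optimizer`, and `micro_batches` stand-ins) accumulates bf16 gradients into fp32 master buffers instead of summing them in bf16, which is the instability the help text refers to:

```python
import torch

def step_with_fp32_accumulation(model, optimizer, micro_batches):
    # fp32 master buffers: repeated additions happen here rather than in
    # the low-precision .grad tensors.
    buffers = [torch.zeros_like(p, dtype=torch.float32) for p in model.parameters()]

    for batch in micro_batches:
        loss = model(batch) / len(micro_batches)  # assumes model returns a scalar loss
        loss.backward()
        for buf, p in zip(buffers, model.parameters()):
            if p.grad is not None:
                buf.add_(p.grad.float())
                p.grad = None

    # Hand the accumulated fp32 gradients back for the optimizer step.
    for buf, p in zip(buffers, model.parameters()):
        p.grad = buf.to(p.dtype)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```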