From 3400882a8099645ce4c797f57ac258f1e1424ffd Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Tue, 22 Oct 2024 12:21:36 -0600
Subject: [PATCH] Added preliminary support for SD3.5-large lora training

---
 .../examples/train_lora_sd35_large_24gb.yaml | 97 +++++++++++++++++++
 jobs/process/BaseSDTrainProcess.py           |  2 +
 toolkit/lora_special.py                      |  5 +-
 toolkit/stable_diffusion_model.py            | 88 ++++++++++++++---
 4 files changed, 176 insertions(+), 16 deletions(-)
 create mode 100644 config/examples/train_lora_sd35_large_24gb.yaml

diff --git a/config/examples/train_lora_sd35_large_24gb.yaml b/config/examples/train_lora_sd35_large_24gb.yaml
new file mode 100644
index 00000000..e1766c39
--- /dev/null
+++ b/config/examples/train_lora_sd35_large_24gb.yaml
@@ -0,0 +1,97 @@
---
# NOTE!! THIS IS CURRENTLY EXPERIMENTAL AND UNDER DEVELOPMENT. SOME THINGS WILL CHANGE
job: extension
config:
  # this name will be used for the output folder and file names
  name: "my_first_sd3l_lora_v1"
  process:
    - type: 'sd_trainer'
      # root folder to save training sessions/samples/weights
      training_folder: "output"
      # uncomment to see performance stats in the terminal every N steps
#      performance_log_every: 1000
      device: cuda:0
      # if a trigger word is specified, it will be added to captions of training data if it does not already exist
      # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
#      trigger_word: "p3r5on"
      network:
        type: "lora"
        linear: 16
        linear_alpha: 16
      save:
        dtype: float16  # precision to save
        save_every: 250  # save every this many steps
        max_step_saves_to_keep: 4  # how many intermittent saves to keep
        push_to_hub: false  # change this to true to push your trained model to Hugging Face.
        # You can either set up a HF_TOKEN env variable or you'll be prompted to log in
#        hf_repo_id: your-username/your-model-slug
#        hf_private: true  # whether the repo is private or public
      datasets:
        # datasets are a folder of images. captions need to be txt files with the same name as the image
        # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
        # images will automatically be resized and bucketed into the resolution specified
        # on windows, escape back slashes with another backslash so
        # "C:\\path\\to\\images\\folder"
        - folder_path: "/path/to/images/folder"
          caption_ext: "txt"
          caption_dropout_rate: 0.05  # will drop out the caption 5% of the time
          shuffle_tokens: false  # shuffle caption order, split by commas
          cache_latents_to_disk: true  # leave this true unless you know what you're doing
          resolution: [ 1024 ]
      train:
        batch_size: 1
        steps: 2000  # total number of steps to train; 500 - 4000 is a good range
        gradient_accumulation_steps: 1
        train_unet: true
        train_text_encoder: false  # May not fully work with SD3 yet
        gradient_checkpointing: true  # need this on unless you have a ton of vram
        noise_scheduler: "flowmatch"
        timestep_type: "linear"  # linear or sigmoid
        optimizer: "adamw8bit"
        lr: 1e-4
        # uncomment this to skip the pre-training sample
#        skip_first_sample: true
        # uncomment to completely disable sampling
#        disable_sampling: true
        # uncomment to use new bell curved weighting. Experimental but may produce better results
#        linear_timesteps: true

        # ema will smooth out learning, but could slow it down. Recommended to leave on.
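        # roughly: an EMA copy of the LoRA weights is kept and updated every step as
        # ema_weight = ema_decay * ema_weight + (1 - ema_decay) * current_weight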
        ema_config:
          use_ema: true
          ema_decay: 0.99

        # you will probably need bf16 for SD3 if your gpu supports it, other dtypes may not work correctly
        dtype: bf16
      model:
        # huggingface model name or path
        name_or_path: "stabilityai/stable-diffusion-3.5-large"
        is_v3: true
        quantize: true  # run 8bit mixed precision
      sample:
        sampler: "flowmatch"  # must match train.noise_scheduler
        sample_every: 250  # sample every this many steps
        width: 1024
        height: 1024
        prompts:
          # you can add [trigger] to the prompts here and it will be replaced with the trigger word
#          - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"
          - "woman with red hair, playing chess at the park, bomb going off in the background"
          - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
          - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
          - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
          - "a bear building a log cabin in the snow covered mountains"
          - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
          - "hipster man with a beard, building a chair, in a wood shop"
          - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
          - "a man holding a sign that says, 'this is a sign'"
          - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
        neg: ""
        seed: 42
        walk_seed: true
        guidance_scale: 4
        sample_steps: 25
# you can add any additional meta info here. [name] is replaced with config name at top
meta:
  name: "[name]"
  version: '1.0'

diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 43d45f73..e146cc2d 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1907,6 +1907,8 @@ def _generate_readme(self, repo_id: str) -> str:
             tags.append("stable-diffusion-xl")
         if self.model_config.is_flux:
             tags.append("flux")
+        if self.model_config.is_v3:
+            tags.append("sd3")
         if self.network_config:
             tags.extend(
                 [
diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index 9981ec8a..6c53439a 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -232,7 +232,7 @@ def __init__(
         self.peft_format = peft_format

         # always do peft for flux only for now
-        if self.is_flux:
+        if self.is_flux or self.is_v3:
             self.peft_format = True

         if self.peft_format:
@@ -326,6 +326,9 @@ def create_modules(
                     if self.transformer_only and self.is_flux and is_unet:
                         if "transformer_blocks" not in lora_name:
                             skip = True
+                    if self.transformer_only and self.is_v3 and is_unet:
+                        if "transformer_blocks" not in lora_name:
+                            skip = True

                     if (is_linear or is_conv2d) and not skip:
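The lora_special.py change above applies the same transformer-only filter to SD3 that already existed for Flux: LoRA modules are created in PEFT naming format, and when transformer_only is set, only layers whose module name contains "transformer_blocks" receive an adapter. Below is a rough, self-contained sketch of that selection rule; it is not part of the patch, and TinyMMDiT is a made-up stand-in for the real SD3 transformer.

import torch.nn as nn

class TinyMMDiT(nn.Module):
    # tiny stand-in that copies the naming pattern of the SD3 transformer
    def __init__(self):
        super().__init__()
        self.pos_embed = nn.Linear(8, 8)  # not inside transformer_blocks -> skipped
        self.transformer_blocks = nn.ModuleList(
            [nn.ModuleDict({"attn": nn.Linear(8, 8)}) for _ in range(2)]
        )

def lora_target_names(model: nn.Module, transformer_only: bool = True):
    # mirrors the skip logic added in create_modules above
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        if transformer_only and "transformer_blocks" not in name:
            continue
        yield name

print(list(lora_target_names(TinyMMDiT())))
# ['transformer_blocks.0.attn', 'transformer_blocks.1.attn']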
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 8fb9eefb..e85dc67e 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -47,7 +47,7 @@
     StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
     StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
     StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \
-    FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler
+    FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel
 import diffusers
 from diffusers import \
     AutoencoderKL, \
@@ -267,30 +267,84 @@ def load_model(self):
                pipln = self.custom_pipeline
            else:
                pipln = StableDiffusion3Pipeline
-
-            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-
-            model_id = "stabilityai/stable-diffusion-3-medium"
-            text_encoder3 = T5EncoderModel.from_pretrained(
-                model_id,
-                subfolder="text_encoder_3",
-                # quantization_config=quantization_config,
-                revision="refs/pr/26",
-                device_map="cuda"
+
+            print("Loading SD3 model")
+            # assume it is the large model
+            base_model_path = "stabilityai/stable-diffusion-3.5-large"
+            print("Loading transformer")
+            subfolder = 'transformer'
+            transformer_path = model_path
+            # check if the model path is a local folder
+            if os.path.exists(transformer_path):
+                subfolder = None
+                transformer_path = os.path.join(transformer_path, 'transformer')
+                # check if the path is a full checkpoint.
+                te_folder_path = os.path.join(model_path, 'text_encoder')
+                # if we have the te, this folder is a full checkpoint, use it as the base
+                if os.path.exists(te_folder_path):
+                    base_model_path = model_path
+            else:
+                # it is remote, use whatever path we were given
+                base_model_path = model_path
+
+            transformer = SD3Transformer2DModel.from_pretrained(
+                transformer_path,
+                subfolder=subfolder,
+                torch_dtype=dtype,
+            )
+            if not self.low_vram:
+                # for low vram, we leave it on the cpu. Quantizes slower, but allows training on the primary gpu
+                transformer.to(torch.device(self.quantize_device), dtype=dtype)
+            flush()
+
+            if self.model_config.lora_path is not None:
+                raise ValueError("LoRA is not supported for SD3 models currently")
+
+            if self.model_config.quantize:
+                quantization_type = qfloat8
+                print("Quantizing transformer")
+                quantize(transformer, weights=quantization_type)
+                freeze(transformer)
+                transformer.to(self.device_torch)
+            else:
+                transformer.to(self.device_torch, dtype=dtype)
+
+            scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
+            print("Loading vae")
+            vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
+            flush()
+
+            print("Loading t5")
+            tokenizer_3 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_3", torch_dtype=dtype)
+            text_encoder_3 = T5EncoderModel.from_pretrained(
+                base_model_path,
+                subfolder="text_encoder_3",
+                torch_dtype=dtype
             )
+
+            text_encoder_3.to(self.device_torch, dtype=dtype)
+            flush()
+
+            if self.model_config.quantize:
+                print("Quantizing T5")
+                quantize(text_encoder_3, weights=qfloat8)
+                freeze(text_encoder_3)
+                flush()
+
             # see if path exists
             if not os.path.exists(model_path) or os.path.isdir(model_path):
                 try:
                     # try to load with default diffusers
                     pipe = pipln.from_pretrained(
-                        model_path,
+                        base_model_path,
                         dtype=dtype,
                         device=self.device_torch,
-                        text_encoder_3=text_encoder3,
+                        tokenizer_3=tokenizer_3,
+                        text_encoder_3=text_encoder_3,
+                        transformer=transformer,
                         # variant="fp16",
                         use_safetensors=True,
-                        revision="refs/pr/26",
                         repo_type="model",
                         ignore_patterns=["*.md", "*..gitattributes"],
                         **load_args
@@ -302,9 +356,11 @@ def load_model(self):
                 else:
                     pipe = pipln.from_single_file(
                         model_path,
+                        transformer=transformer,
                         device=self.device_torch,
                         torch_dtype=self.torch_dtype,
-                        text_encoder_3=text_encoder3,
+                        tokenizer_3=tokenizer_3,
+                        text_encoder_3=text_encoder_3,
                         **load_args
                     )

@@ -1815,6 +1871,8 @@ def scale_model_input(model_input, timestep_tensor):
                    pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype),
                    **kwargs,
                ).sample
+                if isinstance(noise_pred, QTensor):
+                    noise_pred = noise_pred.dequantize()
            elif self.is_auraflow:
                # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
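The quantize, freeze, qfloat8, and QTensor pieces used throughout the patch come from optimum-quanto, and the final hunk dequantizes the transformer output before it is used as the noise prediction. Below is a minimal sketch of that pattern on a toy module, assuming optimum-quanto is installed; the toy model and shapes are made up for illustration and are not the real SD3 transformer.

import torch
from torch import nn
from optimum.quanto import QTensor, freeze, qfloat8, quantize

# stand-in for the SD3 transformer or the T5 text encoder
model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 8))

quantize(model, weights=qfloat8)  # swap weights for qfloat8 quantized versions
freeze(model)                     # materialize the quantized weights, dropping the fp originals

out = model(torch.randn(1, 64))
if isinstance(out, QTensor):      # depending on configuration, outputs can come back as QTensor
    out = out.dequantize()        # convert back to a plain torch.Tensor before using it downstream
print(out.shape, out.dtype)

In the patch the same flow is applied to the SD3 transformer and the T5 text encoder whenever quantize: true is set in the model config, with low_vram deciding whether the transformer stays on the CPU or is moved to the GPU before quantizing.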