From 7c128bbdac0da1767c239174e91af6f327845372 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:56:17 +0800 Subject: [PATCH 01/22] Add fp8 for sd unet --- extensions-builtin/Lora/network.py | 2 +- extensions-builtin/Lora/network_full.py | 4 ++-- extensions-builtin/Lora/network_glora.py | 10 +++++----- extensions-builtin/Lora/network_hada.py | 12 ++++++------ extensions-builtin/Lora/network_ia3.py | 2 +- extensions-builtin/Lora/network_lokr.py | 18 +++++++++--------- extensions-builtin/Lora/network_lora.py | 6 +++--- extensions-builtin/Lora/network_norm.py | 4 ++-- extensions-builtin/Lora/networks.py | 6 +++--- modules/cmd_args.py | 1 + modules/sd_models.py | 3 +++ 11 files changed, 36 insertions(+), 32 deletions(-) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 6021fd8de0f..a62e5eff9e1 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -137,7 +137,7 @@ def calc_scale(self): def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): if self.bias is not None: updown = updown.reshape(self.bias.shape) - updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) + updown += self.bias.to(orig_weight.device, dtype=updown.dtype) updown = updown.reshape(output_shape) if len(output_shape) == 4: diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py index bf6930e96c0..f221c95f3b5 100644 --- a/extensions-builtin/Lora/network_full.py +++ b/extensions-builtin/Lora/network_full.py @@ -18,9 +18,9 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): def calc_updown(self, orig_weight): output_shape = self.weight.shape - updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype) + updown = self.weight.to(orig_weight.device) if self.ex_bias is not None: - ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype) + ex_bias = self.ex_bias.to(orig_weight.device) else: ex_bias = None diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py index 492d487078d..efe5c6814fa 100644 --- a/extensions-builtin/Lora/network_glora.py +++ b/extensions-builtin/Lora/network_glora.py @@ -22,12 +22,12 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.w2b = weights.w["b2.weight"] def calc_updown(self, orig_weight): - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) output_shape = [w1a.size(0), w1b.size(1)] - updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a)) + updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a)) return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py index 5fcb0695fbb..d95a0fd18e3 100644 --- a/extensions-builtin/Lora/network_hada.py +++ b/extensions-builtin/Lora/network_hada.py @@ -27,16 +27,16 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.t2 = weights.w.get("hada_t2") def calc_updown(self, orig_weight): - w1a = self.w1a.to(orig_weight.device, 
dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) output_shape = [w1a.size(0), w1b.size(1)] if self.t1 is not None: output_shape = [w1a.size(1), w1b.size(1)] - t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype) + t1 = self.t1.to(orig_weight.device) updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b) output_shape += t1.shape[2:] else: @@ -45,7 +45,7 @@ def calc_updown(self, orig_weight): updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape) if self.t2 is not None: - t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype) + t2 = self.t2.to(orig_weight.device) updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) else: updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape) diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py index 7edc4249791..96faeaf3ede 100644 --- a/extensions-builtin/Lora/network_ia3.py +++ b/extensions-builtin/Lora/network_ia3.py @@ -17,7 +17,7 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): self.on_input = weights.w["on_input"].item() def calc_updown(self, orig_weight): - w = self.w.to(orig_weight.device, dtype=orig_weight.dtype) + w = self.w.to(orig_weight.device) output_shape = [w.size(0), orig_weight.size(1)] if self.on_input: diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py index 340acdab3d0..fcdaeafd896 100644 --- a/extensions-builtin/Lora/network_lokr.py +++ b/extensions-builtin/Lora/network_lokr.py @@ -37,22 +37,22 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): def calc_updown(self, orig_weight): if self.w1 is not None: - w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype) + w1 = self.w1.to(orig_weight.device) else: - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) w1 = w1a @ w1b if self.w2 is not None: - w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype) + w2 = self.w2.to(orig_weight.device) elif self.t2 is None: - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) w2 = w2a @ w2b else: - t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + t2 = self.t2.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)] diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py index 26c0a72c237..4cc4029510f 100644 --- a/extensions-builtin/Lora/network_lora.py +++ b/extensions-builtin/Lora/network_lora.py @@ -61,13 +61,13 @@ def create_module(self, weights, key, none_ok=False): return module def calc_updown(self, orig_weight): - up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) - down = 
self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + up = self.up_model.weight.to(orig_weight.device) + down = self.down_model.weight.to(orig_weight.device) output_shape = [up.size(0), down.size(1)] if self.mid_model is not None: # cp-decomposition - mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + mid = self.mid_model.weight.to(orig_weight.device) updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid) output_shape += mid.shape[2:] else: diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py index ce450158068..d25afcbb928 100644 --- a/extensions-builtin/Lora/network_norm.py +++ b/extensions-builtin/Lora/network_norm.py @@ -18,10 +18,10 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): def calc_updown(self, orig_weight): output_shape = self.w_norm.shape - updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype) + updown = self.w_norm.to(orig_weight.device) if self.b_norm is not None: - ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype) + ex_bias = self.b_norm.to(orig_weight.device) else: ex_bias = None diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 60d8dec4cf5..8ea4ea6092d 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -381,12 +381,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn # inpainting model. zero pad updown to make channel[1] 4 to 9 updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) - self.weight += updown + self.weight.copy_((self.weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype)) if ex_bias is not None and hasattr(self, 'bias'): if self.bias is None: - self.bias = torch.nn.Parameter(ex_bias) + self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype) else: - self.bias += ex_bias + self.bias.copy_((self.bias.to(dtype=ex_bias.dtype) + ex_bias).to(dtype=self.bias.dtype)) except RuntimeError as e: logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 4e602a84270..0f14c71e4a8 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -118,3 +118,4 @@ parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) +parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False) diff --git a/modules/sd_models.py b/modules/sd_models.py index 3b6cdea18d4..3b8ff8209b3 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -391,6 +391,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") + if shared.cmd_opts.opt_unet_fp8_storage: + model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) + timer.record("apply fp8 unet") devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and 
devices.dtype_unet == torch.float16 From 5f9ddfa46f28ca2aa9e0bd832f6bbd67069be63e Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 19 Oct 2023 23:57:22 +0800 Subject: [PATCH 02/22] Add sdxl only arg --- modules/cmd_args.py | 1 + modules/sd_models.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 0f14c71e4a8..20bfb2c442b 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -119,3 +119,4 @@ parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False) +parser.add_argument("--opt-unet-fp8-storage-xl", action='store_true', help="use fp8 for SD UNet to save vram", default=False) diff --git a/modules/sd_models.py b/modules/sd_models.py index 3b8ff8209b3..08af128fc4e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -394,6 +394,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.cmd_opts.opt_unet_fp8_storage: model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") + elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: + model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) + timer.record("apply fp8 unet for sdxl") devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 From eaa9f5162fbca2ebcb2682eb861bc7e5510a2b66 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 24 Oct 2023 01:49:05 +0800 Subject: [PATCH 03/22] Add CPU fp8 support Since norm layers need fp32, I only convert the linear operation layers (Conv2d/Linear). The TE also has some PyTorch functions that do not support bf16 amp on CPU, so I add a condition to indicate whether the autocast is for the UNet.
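The idea reads roughly as follows. This is a minimal illustrative sketch, not part of the patch series itself: the names cast_linear_ops_to_fp8, unet_autocast, unet and fp8_enabled are made up for the example, and PyTorch >= 2.1.0 is assumed for torch.float8_e4m3fn. Only Linear/Conv2d weights are stored in fp8 while norm layers keep their original dtype, and on CPU the UNet forward runs under a bf16 autocast.

import contextlib

import torch


def cast_linear_ops_to_fp8(unet: torch.nn.Module) -> torch.nn.Module:
    # Store only matmul/conv weights in fp8; norm layers are skipped and keep
    # their original higher-precision dtype.
    for module in unet.modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            module.to(torch.float8_e4m3fn)
    return unet


def unet_autocast(device: torch.device, fp8_enabled: bool):
    # With fp8 storage on CPU, run the UNet forward under bf16 autocast;
    # otherwise fall back to a no-op context.
    if fp8_enabled and device.type == "cpu":
        return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
    return contextlib.nullcontext()

Later patches in this series extend the same selective cast to the text encoder's Linear layers and move the switch from command-line flags to a settings option.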
--- modules/devices.py | 6 +++++- modules/processing.py | 2 +- modules/sd_models.py | 20 ++++++++++++++++---- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 1d4eb563517..0cd2b55d279 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -71,6 +71,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") cpu: torch.device = torch.device("cpu") +fp8: bool = False device: torch.device = None device_interrogate: torch.device = None device_gfpgan: torch.device = None @@ -93,10 +94,13 @@ def cond_cast_float(input): nv_rng = None -def autocast(disable=False): +def autocast(disable=False, unet=False): if disable: return contextlib.nullcontext() + if unet and fp8 and device==cpu: + return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() diff --git a/modules/processing.py b/modules/processing.py index 40598f5cff4..2df8a7ea7a8 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -865,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): + with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(unet=True): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) if getattr(samples_ddim, 'already_decoded', False): diff --git a/modules/sd_models.py b/modules/sd_models.py index 08af128fc4e..c5fe57bfec0 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -391,12 +391,24 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") - if shared.cmd_opts.opt_unet_fp8_storage: + + if shared.cmd_opts.opt_unet_fp8_storage: + enable_fp8 = True + elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: + enable_fp8 = True + + if enable_fp8: + devices.fp8 = True + if devices.device == devices.cpu: + for module in model.model.diffusion_model.modules(): + if isinstance(module, torch.nn.Conv2d): + module.to(torch.float8_e4m3fn) + elif isinstance(module, torch.nn.Linear): + module.to(torch.float8_e4m3fn) + timer.record("apply fp8 unet for cpu") + else: model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") - elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: - model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) - timer.record("apply fp8 unet for sdxl") devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 From 9c1eba2af3a6f9cd6282b3a367656793cbe70c01 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 24 Oct 2023 02:11:27 +0800 Subject: [PATCH 04/22] Fix lint --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index c5fe57bfec0..44d4038b25b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -396,7 +396,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer enable_fp8 = True elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: enable_fp8 = 
True - + if enable_fp8: devices.fp8 = True if devices.device == devices.cpu: From 1df6c8bfec4715610d64684b6ad2fa38c76c1df6 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 11:36:43 +0800 Subject: [PATCH 05/22] fp8 for TE --- modules/sd_models.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index 44d4038b25b..69395294300 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -407,6 +407,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer module.to(torch.float8_e4m3fn) timer.record("apply fp8 unet for cpu") else: + if model.is_sdxl: + cond_stage = model.conditioner + else: + cond_stage = model.cond_stage_model + for module in cond_stage.modules(): + if isinstance(module, torch.nn.Linear): + module.to(torch.float8_e4m3fn) model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") From 4830b251366436ee8499c003fe87e46ddb4a4581 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 11:53:37 +0800 Subject: [PATCH 06/22] Fix alphas_cumprod dtype --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index 69395294300..23660454505 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -416,6 +416,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer module.to(torch.float8_e4m3fn) model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") + model.alphas_cumprod = model.alphas_cumprod.to(torch.float32) devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 From bf5067f50ca32cd4764638702e3cc38bca8bfd8b Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 12:54:28 +0800 Subject: [PATCH 07/22] Fix alphas cumprod --- modules/sd_models.py | 3 ++- modules/sd_models_xl.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 23660454505..7ed89a9c93f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -396,6 +396,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer enable_fp8 = True elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: enable_fp8 = True + else: + enable_fp8 = False if enable_fp8: devices.fp8 = True @@ -416,7 +418,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer module.to(torch.float8_e4m3fn) model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") - model.alphas_cumprod = model.alphas_cumprod.to(torch.float32) devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 0112332161f..11259a369de 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -93,7 +93,7 @@ def extend_sdxl(model): model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() - model.alphas_cumprod = 
torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) + model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32) model.conditioner.wrapped = torch.nn.Module() From dda067f64d3289cee3ffd65767126cb30ae73b13 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 19:53:22 +0800 Subject: [PATCH 08/22] ignore mps for fp8 --- modules/sd_models.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 7ed89a9c93f..ccb6afd2ade 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -392,7 +392,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") - if shared.cmd_opts.opt_unet_fp8_storage: + if devices.get_optimal_device_name() == "mps": + enable_fp8 = False + elif shared.cmd_opts.opt_unet_fp8_storage: enable_fp8 = True elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: enable_fp8 = True From 0beb131c7ffae6f756a6339206da311232a36970 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 20:07:37 +0800 Subject: [PATCH 09/22] change torch version --- modules/launch_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 8cdbafa50c8..636da679ecc 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -308,8 +308,8 @@ def requirements_met(requirements_file): def prepare_environment(): - torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") - torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") + torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121") + torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20') From d4d3134f6d2d232c7bcfa80900a362921e644976 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 28 Oct 2023 15:24:26 +0800 Subject: [PATCH 10/22] ManualCast for 10/16 series gpu --- modules/devices.py | 57 ++++++++++++++++++++++++++++++++++++++----- modules/processing.py | 2 +- modules/sd_models.py | 21 +++++++++------- 3 files changed, 64 insertions(+), 16 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 0cd2b55d279..c05f2b35e5e 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -16,6 +16,23 @@ def has_mps() -> bool: return mac_specific.has_mps +def cuda_no_autocast(device_id=None) -> bool: + if device_id is None: + device_id = get_cuda_device_id() + return ( + torch.cuda.get_device_capability(device_id) == (7, 5) + and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16") + ) + + +def get_cuda_device_id(): + return ( + int(shared.cmd_opts.device_id) + if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() + else 0 + ) or torch.cuda.current_device() + + def get_cuda_device_string(): if shared.cmd_opts.device_id is not None: return f"cuda:{shared.cmd_opts.device_id}" @@ -60,8 +77,7 @@ def enable_tf32(): # 
enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 - device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device() - if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"): + if cuda_no_autocast(): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -92,15 +108,44 @@ def cond_cast_float(input): nv_rng = None - - -def autocast(disable=False, unet=False): +patch_module_list = [ + torch.nn.Linear, + torch.nn.Conv2d, + torch.nn.MultiheadAttention, + torch.nn.GroupNorm, + torch.nn.LayerNorm, +] + +@contextlib.contextmanager +def manual_autocast(): + def manual_cast_forward(self, *args, **kwargs): + org_dtype = next(self.parameters()).dtype + self.to(dtype) + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + return result + for module_type in patch_module_list: + org_forward = module_type.forward + module_type.forward = manual_cast_forward + module_type.org_forward = org_forward + try: + yield None + finally: + for module_type in patch_module_list: + module_type.forward = module_type.org_forward + + +def autocast(disable=False): + print(fp8, dtype, shared.cmd_opts.precision, device) if disable: return contextlib.nullcontext() - if unet and fp8 and device==cpu: + if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) + if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): + return manual_autocast() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() diff --git a/modules/processing.py b/modules/processing.py index 2df8a7ea7a8..40598f5cff4 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -865,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(unet=True): + with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) if getattr(samples_ddim, 'already_decoded', False): diff --git a/modules/sd_models.py b/modules/sd_models.py index ccb6afd2ade..31bcb913a51 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -403,23 +403,26 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if enable_fp8: devices.fp8 = True + if model.is_sdxl: + cond_stage = model.conditioner + else: + cond_stage = model.cond_stage_model + + for module in cond_stage.modules(): + if isinstance(module, torch.nn.Linear): + module.to(torch.float8_e4m3fn) + if devices.device == devices.cpu: for module in model.model.diffusion_model.modules(): if isinstance(module, torch.nn.Conv2d): module.to(torch.float8_e4m3fn) elif isinstance(module, torch.nn.Linear): module.to(torch.float8_e4m3fn) - timer.record("apply fp8 unet for cpu") else: - if model.is_sdxl: - cond_stage = model.conditioner - else: - cond_stage = model.cond_stage_model - for module in cond_stage.modules(): - if isinstance(module, torch.nn.Linear): - module.to(torch.float8_e4m3fn) 
model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) - timer.record("apply fp8 unet") + timer.record("apply fp8") + else: + devices.fp8 = False devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 From ddc2a3499b8cd120b4a42358bcd33137ce1d1e75 Mon Sep 17 00:00:00 2001 From: KohakuBlueleaf Date: Sat, 28 Oct 2023 16:52:35 +0800 Subject: [PATCH 11/22] Add MPS manual cast --- modules/devices.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/devices.py b/modules/devices.py index c05f2b35e5e..d7c905c283a 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -121,6 +121,8 @@ def manual_autocast(): def manual_cast_forward(self, *args, **kwargs): org_dtype = next(self.parameters()).dtype self.to(dtype) + args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} result = self.org_forward(*args, **kwargs) self.to(org_dtype) return result @@ -136,7 +138,6 @@ def manual_cast_forward(self, *args, **kwargs): def autocast(disable=False): - print(fp8, dtype, shared.cmd_opts.precision, device) if disable: return contextlib.nullcontext() @@ -146,6 +147,9 @@ def autocast(disable=False): if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): return manual_autocast() + if has_mps() and shared.cmd_opts.precision != "full": + return manual_autocast() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() From 598da5cd4928618b166886d3485ce30ce3a43490 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:50:06 +0800 Subject: [PATCH 12/22] Use options instead of cmd_args --- modules/cmd_args.py | 2 -- modules/devices.py | 25 +++++++++------- modules/initialize_util.py | 1 + modules/sd_models.py | 61 ++++++++++++++++++++------------------ modules/shared_options.py | 1 + scripts/xyz_grid.py | 1 + 6 files changed, 49 insertions(+), 42 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 088d5deaac1..a9fb9bfa3ef 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -118,5 +118,3 @@ parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) -parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False) -parser.add_argument("--opt-unet-fp8-storage-xl", action='store_true', help="use fp8 for SD UNet to save vram", default=False) diff --git a/modules/devices.py b/modules/devices.py index d7c905c283a..03e7bdb7c2e 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -20,15 +20,15 @@ def cuda_no_autocast(device_id=None) -> bool: if device_id is None: device_id = get_cuda_device_id() return ( - torch.cuda.get_device_capability(device_id) == (7, 5) + torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16") ) def 
get_cuda_device_id(): return ( - int(shared.cmd_opts.device_id) - if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() + int(shared.cmd_opts.device_id) + if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0 ) or torch.cuda.current_device() @@ -116,16 +116,19 @@ def cond_cast_float(input): torch.nn.LayerNorm, ] + +def manual_cast_forward(self, *args, **kwargs): + org_dtype = next(self.parameters()).dtype + self.to(dtype) + args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + return result + + @contextlib.contextmanager def manual_autocast(): - def manual_cast_forward(self, *args, **kwargs): - org_dtype = next(self.parameters()).dtype - self.to(dtype) - args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] - kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} - result = self.org_forward(*args, **kwargs) - self.to(org_dtype) - return result for module_type in patch_module_list: org_forward = module_type.forward module_type.forward = manual_cast_forward diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 2e9b6d895f4..1b11ead61fd 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -177,6 +177,7 @@ def configure_opts_onchange(): shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) + shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) startup_timer.record("opts onchange") diff --git a/modules/sd_models.py b/modules/sd_models.py index a6c8b2fa7db..eb4914347b9 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -339,10 +339,28 @@ def __exit__(self, exc_type, exc_value, exc_traceback): SkipWritingToConfig.skip = self.previous +def check_fp8(model): + if model is None: + return None + if devices.get_optimal_device_name() == "mps": + enable_fp8 = False + elif shared.opts.fp8_storage == "Enable": + enable_fp8 = True + elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL": + enable_fp8 = True + else: + enable_fp8 = False + return enable_fp8 + + def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") + if not check_fp8(model) and devices.fp8: + # prevent model to load state dict in fp8 + model.half() + if not SkipWritingToConfig.skip: shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title @@ -395,34 +413,16 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") - if devices.get_optimal_device_name() == "mps": - enable_fp8 = False - elif shared.cmd_opts.opt_unet_fp8_storage: - enable_fp8 = True - elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: - enable_fp8 = True - else: - enable_fp8 = False - - if enable_fp8: + if check_fp8(model): devices.fp8 = True - if model.is_sdxl: - cond_stage = model.conditioner - else: - cond_stage = model.cond_stage_model - - for module in cond_stage.modules(): - if 
isinstance(module, torch.nn.Linear): + first_stage = model.first_stage_model + model.first_stage_model = None + for module in model.modules(): + if isinstance(module, torch.nn.Conv2d): module.to(torch.float8_e4m3fn) - - if devices.device == devices.cpu: - for module in model.model.diffusion_model.modules(): - if isinstance(module, torch.nn.Conv2d): - module.to(torch.float8_e4m3fn) - elif isinstance(module, torch.nn.Linear): - module.to(torch.float8_e4m3fn) - else: - model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) + elif isinstance(module, torch.nn.Linear): + module.to(torch.float8_e4m3fn) + model.first_stage_model = first_stage timer.record("apply fp8") else: devices.fp8 = False @@ -769,7 +769,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): return None -def reload_model_weights(sd_model=None, info=None): +def reload_model_weights(sd_model=None, info=None, forced_reload=False): checkpoint_info = info or select_checkpoint() timer = Timer() @@ -781,11 +781,14 @@ def reload_model_weights(sd_model=None, info=None): current_checkpoint_info = None else: current_checkpoint_info = sd_model.sd_checkpoint_info - if sd_model.sd_model_checkpoint == checkpoint_info.filename: + if check_fp8(sd_model) != devices.fp8: + # load from state dict again to prevent extra numerical errors + forced_reload = True + elif sd_model.sd_model_checkpoint == checkpoint_info.filename: return sd_model sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer) - if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: + if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: return sd_model if sd_model is not None: diff --git a/modules/shared_options.py b/modules/shared_options.py index f1003f2187b..d27f35e96fc 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -200,6 +200,7 @@ "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"), "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"), "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), + "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. 
Require pytorch>=2.1.0."), })) options_templates.update(options_section(('compatibility', "Compatibility"), { diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 0dc255bc43d..b2250c04dc0 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -270,6 +270,7 @@ def __init__(self, *args, **kwargs): AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)), AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')), AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]), + AxisOption("FP8 mode", str, apply_override("fp8_storage"), cost=0.9, choices=lambda: ["Disable", "Enable for SDXL", "Enable"]), ] From 890181e1d456b613bf60f6e8378dc68b39011af9 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:54:39 +0800 Subject: [PATCH 13/22] Update the xformers/torch versions --- modules/errors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/errors.py b/modules/errors.py index 8c339464d46..a3498c11f4b 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -93,8 +93,8 @@ def check_versions(): import torch import gradio - expected_torch_version = "2.0.0" - expected_xformers_version = "0.0.20" + expected_torch_version = "2.1.0" + expected_xformers_version = "0.0.22.post7" expected_gradio_version = "3.41.2" if version.parse(torch.__version__) < version.parse(expected_torch_version): From f383af2729ec2d1969200218577ab19dd78f7d48 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:56:23 +0800 Subject: [PATCH 14/22] update xformers/torch versions --- modules/launch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 636da679ecc..c225bbc180c 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -312,7 +312,7 @@ def prepare_environment(): torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") - xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20') + xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.22.post7') clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") From 043d2edcf6a543f236f1f3cb70ac72e7b3b357b6 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:56:31 +0800 Subject: [PATCH 15/22] Better naming --- modules/devices.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 03e7bdb7c2e..c19a7f402f1 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -128,7 +128,7 @@ def manual_cast_forward(self, *args, **kwargs): @contextlib.contextmanager -def manual_autocast(): +def manual_cast(): for module_type in patch_module_list: org_forward = module_type.forward module_type.forward = manual_cast_forward @@ -148,10 +148,10 @@ def 
autocast(disable=False): return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): - return manual_autocast() + return manual_cast() if has_mps() and shared.cmd_opts.precision != "full": - return manual_autocast() + return manual_cast() if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() From b2e039d07bed76350120ff448964c907a3b5e4a3 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Mon, 20 Nov 2023 14:05:32 +0800 Subject: [PATCH 16/22] Update webui-macos-env.sh --- webui-macos-env.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui-macos-env.sh b/webui-macos-env.sh index 24bc5c42615..db7e8b1a05b 100644 --- a/webui-macos-env.sh +++ b/webui-macos-env.sh @@ -11,7 +11,7 @@ fi export install_dir="$HOME" export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate" -export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2" +export TORCH_COMMAND="pip install torch==2.1.0 torchvision==0.16.0" export PYTORCH_ENABLE_MPS_FALLBACK=1 #################################################################### From 370a77f8e78e65a8a1339289d684cb43df142f70 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 21 Nov 2023 19:59:34 +0800 Subject: [PATCH 17/22] Option for using fp16 weight when apply lora --- extensions-builtin/Lora/networks.py | 16 ++++++++++++---- modules/initialize_util.py | 1 + modules/sd_models.py | 14 +++++++++++--- modules/shared_options.py | 1 + 4 files changed, 25 insertions(+), 7 deletions(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 0170dbfbd23..d22ed843868 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -388,18 +388,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn if module is not None and hasattr(self, 'weight'): try: with torch.no_grad(): - updown, ex_bias = module.calc_updown(self.weight) + if getattr(self, 'fp16_weight', None) is None: + weight = self.weight + bias = self.bias + else: + weight = self.fp16_weight.clone().to(self.weight.device) + bias = getattr(self, 'fp16_bias', None) + if bias is not None: + bias = bias.clone().to(self.bias.device) + updown, ex_bias = module.calc_updown(weight) - if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: + if len(weight.shape) == 4 and weight.shape[1] == 9: # inpainting model. 
zero pad updown to make channel[1] 4 to 9 updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) - self.weight.copy_((self.weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype)) + self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype)) if ex_bias is not None and hasattr(self, 'bias'): if self.bias is None: self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype) else: - self.bias.copy_((self.bias.to(dtype=ex_bias.dtype) + ex_bias).to(dtype=self.bias.dtype)) + self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype)) except RuntimeError as e: logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 1b11ead61fd..7fb1d8d5093 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -178,6 +178,7 @@ def configure_opts_onchange(): shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) + shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) startup_timer.record("opts onchange") diff --git a/modules/sd_models.py b/modules/sd_models.py index eb4914347b9..0a7777f1c82 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -413,14 +413,22 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") + for module in model.modules(): + if hasattr(module, 'fp16_weight'): + del module.fp16_weight + if hasattr(module, 'fp16_bias'): + del module.fp16_bias + if check_fp8(model): devices.fp8 = True first_stage = model.first_stage_model model.first_stage_model = None for module in model.modules(): - if isinstance(module, torch.nn.Conv2d): - module.to(torch.float8_e4m3fn) - elif isinstance(module, torch.nn.Linear): + if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): + if shared.opts.cache_fp16_weight: + module.fp16_weight = module.weight.clone().half() + if module.bias is not None: + module.fp16_bias = module.bias.clone().half() module.to(torch.float8_e4m3fn) model.first_stage_model = first_stage timer.record("apply fp8") diff --git a/modules/shared_options.py b/modules/shared_options.py index d27f35e96fc..eaa9f1352c7 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -201,6 +201,7 @@ "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"), "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."), + "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. 
Use more system ram."), })) options_templates.update(options_section(('compatibility', "Compatibility"), { From f5d719d1f1baa775d838aa75d9af1971bcc78e8f Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 22 Nov 2023 01:45:56 +0800 Subject: [PATCH 18/22] Add forced reload for fp16 cache --- modules/initialize_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 7fb1d8d5093..b6767138dd5 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -178,7 +178,7 @@ def configure_opts_onchange(): shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) - shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) + shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False) startup_timer.record("opts onchange") From 40ac134c553ac824d4a96666bba14d550300daa5 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 25 Nov 2023 12:35:09 +0800 Subject: [PATCH 19/22] Fix pre-fp8 --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 0a7777f1c82..90437c87ac1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -357,7 +357,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") - if not check_fp8(model) and devices.fp8: + if devices.fp8: # prevent model to load state dict in fp8 model.half() From 50a21cb09fe3e9ea2d4fe058e0484e192c8a86e3 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 2 Dec 2023 22:06:47 +0800 Subject: [PATCH 20/22] Ensure the cached weight will not be affected --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 4b8a9ae653a..dcf816b3a85 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -435,9 +435,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer for module in model.modules(): if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): if shared.opts.cache_fp16_weight: - module.fp16_weight = module.weight.clone().half() + module.fp16_weight = module.weight.data.clone().cpu().half() if module.bias is not None: - module.fp16_bias = module.bias.clone().half() + module.fp16_bias = module.bias.data.clone().cpu().half() module.to(torch.float8_e4m3fn) model.first_stage_model = first_stage timer.record("apply fp8") From 672dc4efa8e0da38426b121e7c7216d0a8e465fd Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 6 Dec 2023 15:16:10 +0800 Subject: [PATCH 21/22] Fix forced reload --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index dcf816b3a85..d0046f88c6d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -801,7 +801,7 @@ def 
reload_model_weights(sd_model=None, info=None, forced_reload=False): if check_fp8(sd_model) != devices.fp8: # load from state dict again to prevent extra numerical errors forced_reload = True - elif sd_model.sd_model_checkpoint == checkpoint_info.filename: + elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload: return sd_model sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer) From ea272152e0b50dbb2bd675ec020607f3d50c37d0 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 16 Dec 2023 15:08:08 +0800 Subject: [PATCH 22/22] Add FP8 settings into PNG info --- modules/generation_parameters_copypaste.py | 6 ++++++ modules/processing.py | 2 ++ 2 files changed, 8 insertions(+) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 4efe53e0c48..dbffe494998 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -314,6 +314,12 @@ def parse_generation_parameters(x: str): if "VAE Decoder" not in res: res["VAE Decoder"] = "Full" + if "FP8 weight" not in res: + res["FP8 weight"] = "Disable" + + if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable": + res["Cache FP16 weight for LoRA"] = False + skip = set(shared.opts.infotext_skip_pasting) res = {k: v for k, v in res.items() if k not in skip} diff --git a/modules/processing.py b/modules/processing.py index bea01ec6884..179f2c0fccf 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -688,6 +688,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Size": f"{p.width}x{p.height}", "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None, "Model": p.sd_model_name if opts.add_model_name_to_info else None, + "FP8 weight": opts.fp8_storage if devices.fp8 else None, + "Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None, "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None, "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None, "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),