A big improvement for dtype casting system with fp8 storage type and manual cast #14031

Merged: 33 commits from test-fp8 into dev, Dec 16, 2023

Commits
7c128bb  Add fp8 for sd unet (KohakuBlueleaf, Oct 19, 2023)
5f9ddfa  Add sdxl only arg (KohakuBlueleaf, Oct 19, 2023)
eaa9f51  Add CPU fp8 support (KohakuBlueleaf, Oct 23, 2023)
9c1eba2  Fix lint (KohakuBlueleaf, Oct 23, 2023)
1df6c8b  fp8 for TE (KohakuBlueleaf, Oct 25, 2023)
4830b25  Fix alphas_cumprod dtype (KohakuBlueleaf, Oct 25, 2023)
bf5067f  Fix alphas cumprod (KohakuBlueleaf, Oct 25, 2023)
dda067f  ignore mps for fp8 (KohakuBlueleaf, Oct 25, 2023)
0beb131  change torch version (KohakuBlueleaf, Oct 25, 2023)
d4d3134  ManualCast for 10/16 series gpu (KohakuBlueleaf, Oct 28, 2023)
ddc2a34  Add MPS manual cast (KohakuBlueleaf, Oct 28, 2023)
c3facab  Merge branch 'dev' into test-fp8 (KohakuBlueleaf, Nov 4, 2023)
cd12256  Merge branch 'dev' into test-fp8 (KohakuBlueleaf, Nov 16, 2023)
b60e108  Merge branch 'dev' into test-fp8 (KohakuBlueleaf, Nov 19, 2023)
598da5c  Use options instead of cmd_args (KohakuBlueleaf, Nov 19, 2023)
890181e  Update the xformers/torch versions (KohakuBlueleaf, Nov 19, 2023)
f383af2  update xformers/torch versions (KohakuBlueleaf, Nov 19, 2023)
043d2ed  Better naming (KohakuBlueleaf, Nov 19, 2023)
b2e039d  Update webui-macos-env.sh (KohakuBlueleaf, Nov 20, 2023)
370a77f  Option for using fp16 weight when apply lora (KohakuBlueleaf, Nov 21, 2023)
f5d719d  Add forced reload for fp16 cache (KohakuBlueleaf, Nov 21, 2023)
40ac134  Fix pre-fp8 (KohakuBlueleaf, Nov 25, 2023)
3d341eb  Merge branch 'dev' into test-fp8 (KohakuBlueleaf, Nov 26, 2023)
110485d  Merge branch 'dev' into test-fp8 (KohakuBlueleaf, Dec 2, 2023)
50a21cb  Ensure the cached weight will not be affected (KohakuBlueleaf, Dec 2, 2023)
9a15ae2  Merge branch 'dev' into test-fp8 (KohakuBlueleaf, Dec 3, 2023)
f5f8978  Merge branch 'dev' into test-fp8 (KohakuBlueleaf, Dec 4, 2023)
672dc4e  Fix forced reload (KohakuBlueleaf, Dec 6, 2023)
294ec5a  Merge branch 'dev' into test-fp8 (KohakuBlueleaf, Dec 6, 2023)
39ebd56  Merge remote-tracking branch 'origin/release_candidate' into test-fp8 (KohakuBlueleaf, Dec 7, 2023)
0fb34b5  Merge branch 'dev' into test-fp8 (KohakuBlueleaf, Dec 14, 2023)
ea27215  Add FP8 settings into PNG info (KohakuBlueleaf, Dec 16, 2023)
8edb914  Merge branch 'dev' into test-fp8 (AUTOMATIC1111, Dec 16, 2023)

Files changed

2 changes: 1 addition & 1 deletion extensions-builtin/Lora/network.py
@@ -137,7 +137,7 @@ def calc_scale(self):
     def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
         if self.bias is not None:
             updown = updown.reshape(self.bias.shape)
-            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+            updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
             updown = updown.reshape(output_shape)
 
         if len(output_shape) == 4:

4 changes: 2 additions & 2 deletions extensions-builtin/Lora/network_full.py
@@ -18,9 +18,9 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
 
     def calc_updown(self, orig_weight):
         output_shape = self.weight.shape
-        updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+        updown = self.weight.to(orig_weight.device)
         if self.ex_bias is not None:
-            ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
+            ex_bias = self.ex_bias.to(orig_weight.device)
         else:
             ex_bias = None
 

10 changes: 5 additions & 5 deletions extensions-builtin/Lora/network_glora.py
@@ -22,12 +22,12 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
         self.w2b = weights.w["b2.weight"]
 
     def calc_updown(self, orig_weight):
-        w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+        w1a = self.w1a.to(orig_weight.device)
+        w1b = self.w1b.to(orig_weight.device)
+        w2a = self.w2a.to(orig_weight.device)
+        w2b = self.w2b.to(orig_weight.device)
 
         output_shape = [w1a.size(0), w1b.size(1)]
-        updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
+        updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a))
 
         return self.finalize_updown(updown, orig_weight, output_shape)

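The pattern across these Lora network modules is consistent: `.to(orig_weight.device, dtype=orig_weight.dtype)` becomes `.to(orig_weight.device)`, because with FP8 storage the original weight may be torch.float8_e4m3fn, a dtype the LoRA factors should not be cast into. A minimal illustrative sketch, not code from the PR (shapes and names are made up; torch >= 2.1 assumed for float8 support):

import torch

orig_weight = torch.randn(320, 320).to(torch.float8_e4m3fn)  # base weight as stored under FP8
up = torch.randn(320, 4)      # LoRA factors keep their own (fp16/fp32) dtype
down = torch.randn(4, 320)

updown = up @ down                                # the delta is computed in the LoRA dtype
merged = orig_weight.to(updown.dtype) + updown    # the fp8 weight is upcast only at this point

This mirrors the glora change above, where only orig_weight receives an explicit .to(dtype=w1a.dtype) before entering the matmul.
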
12 changes: 6 additions & 6 deletions extensions-builtin/Lora/network_hada.py
@@ -27,16 +27,16 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
         self.t2 = weights.w.get("hada_t2")
 
     def calc_updown(self, orig_weight):
-        w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+        w1a = self.w1a.to(orig_weight.device)
+        w1b = self.w1b.to(orig_weight.device)
+        w2a = self.w2a.to(orig_weight.device)
+        w2b = self.w2b.to(orig_weight.device)
 
         output_shape = [w1a.size(0), w1b.size(1)]
 
         if self.t1 is not None:
             output_shape = [w1a.size(1), w1b.size(1)]
-            t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
+            t1 = self.t1.to(orig_weight.device)
             updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
             output_shape += t1.shape[2:]
         else:
@@ -45,7 +45,7 @@ def calc_updown(self, orig_weight):
             updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
 
         if self.t2 is not None:
-            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
+            t2 = self.t2.to(orig_weight.device)
             updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
         else:
             updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)

2 changes: 1 addition & 1 deletion extensions-builtin/Lora/network_ia3.py
@@ -17,7 +17,7 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
         self.on_input = weights.w["on_input"].item()
 
     def calc_updown(self, orig_weight):
-        w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)
+        w = self.w.to(orig_weight.device)
 
         output_shape = [w.size(0), orig_weight.size(1)]
         if self.on_input:

18 changes: 9 additions & 9 deletions extensions-builtin/Lora/network_lokr.py
@@ -37,22 +37,22 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
 
     def calc_updown(self, orig_weight):
         if self.w1 is not None:
-            w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)
+            w1 = self.w1.to(orig_weight.device)
         else:
-            w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
+            w1a = self.w1a.to(orig_weight.device)
+            w1b = self.w1b.to(orig_weight.device)
             w1 = w1a @ w1b
 
         if self.w2 is not None:
-            w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)
+            w2 = self.w2.to(orig_weight.device)
         elif self.t2 is None:
-            w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+            w2a = self.w2a.to(orig_weight.device)
+            w2b = self.w2b.to(orig_weight.device)
             w2 = w2a @ w2b
         else:
-            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+            t2 = self.t2.to(orig_weight.device)
+            w2a = self.w2a.to(orig_weight.device)
+            w2b = self.w2b.to(orig_weight.device)
             w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
 
         output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]

6 changes: 3 additions & 3 deletions extensions-builtin/Lora/network_lora.py
@@ -61,13 +61,13 @@ def create_module(self, weights, key, none_ok=False):
         return module
 
     def calc_updown(self, orig_weight):
-        up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
-        down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+        up = self.up_model.weight.to(orig_weight.device)
+        down = self.down_model.weight.to(orig_weight.device)
 
         output_shape = [up.size(0), down.size(1)]
         if self.mid_model is not None:
             # cp-decomposition
-            mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+            mid = self.mid_model.weight.to(orig_weight.device)
             updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
             output_shape += mid.shape[2:]
         else:

4 changes: 2 additions & 2 deletions extensions-builtin/Lora/network_norm.py
@@ -18,10 +18,10 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
 
     def calc_updown(self, orig_weight):
         output_shape = self.w_norm.shape
-        updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+        updown = self.w_norm.to(orig_weight.device)
 
         if self.b_norm is not None:
-            ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+            ex_bias = self.b_norm.to(orig_weight.device)
         else:
             ex_bias = None
 

18 changes: 13 additions & 5 deletions extensions-builtin/Lora/networks.py
@@ -389,18 +389,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
         if module is not None and hasattr(self, 'weight'):
             try:
                 with torch.no_grad():
-                    updown, ex_bias = module.calc_updown(self.weight)
+                    if getattr(self, 'fp16_weight', None) is None:
+                        weight = self.weight
+                        bias = self.bias
+                    else:
+                        weight = self.fp16_weight.clone().to(self.weight.device)
+                        bias = getattr(self, 'fp16_bias', None)
+                        if bias is not None:
+                            bias = bias.clone().to(self.bias.device)
+                    updown, ex_bias = module.calc_updown(weight)
 
-                    if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
+                    if len(weight.shape) == 4 and weight.shape[1] == 9:
                         # inpainting model. zero pad updown to make channel[1] 4 to 9
                         updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
 
-                    self.weight += updown
+                    self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
                     if ex_bias is not None and hasattr(self, 'bias'):
                         if self.bias is None:
-                            self.bias = torch.nn.Parameter(ex_bias)
+                            self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
                         else:
-                            self.bias += ex_bias
+                            self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
             except RuntimeError as e:
                 logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                 extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

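A condensed sketch of the merge step above, outside the webui code base. Assumptions: module is a plain torch.nn.Linear whose weight may be stored in fp8, the optional fp16 copy lives on module.fp16_weight as in this PR, and calc_updown stands in for the network module's real method:

import torch

def apply_lora_delta(module: torch.nn.Linear, calc_updown):
    # prefer the cached fp16 copy when present; otherwise use the (possibly fp8) live weight
    weight = getattr(module, "fp16_weight", None)
    weight = module.weight if weight is None else weight.clone().to(module.weight.device)

    with torch.no_grad():
        updown = calc_updown(weight)
        # in-place copy keeps the Parameter object and its storage dtype (fp8 or fp16) intact
        module.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=module.weight.dtype))

Compared with the old `self.weight += updown`, doing the addition in the delta's dtype and copying the result back avoids accumulating the update in the low-precision storage dtype.
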
60 changes: 58 additions & 2 deletions modules/devices.py
@@ -23,6 +23,23 @@ def has_mps() -> bool:
     return mac_specific.has_mps
 
 
+def cuda_no_autocast(device_id=None) -> bool:
+    if device_id is None:
+        device_id = get_cuda_device_id()
+    return (
+        torch.cuda.get_device_capability(device_id) == (7, 5)
+        and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
+    )
+
+
+def get_cuda_device_id():
+    return (
+        int(shared.cmd_opts.device_id)
+        if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
+        else 0
+    ) or torch.cuda.current_device()
+
+
 def get_cuda_device_string():
     if shared.cmd_opts.device_id is not None:
         return f"cuda:{shared.cmd_opts.device_id}"
@@ -73,8 +90,7 @@ def enable_tf32():
 
         # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
-        device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
-        if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
+        if cuda_no_autocast():
             torch.backends.cudnn.benchmark = True
 
         torch.backends.cuda.matmul.allow_tf32 = True
@@ -84,6 +100,7 @@ def enable_tf32():
 errors.run(enable_tf32, "Enabling TF32")
 
 cpu: torch.device = torch.device("cpu")
+fp8: bool = False
 device: torch.device = None
 device_interrogate: torch.device = None
 device_gfpgan: torch.device = None
@@ -104,12 +121,51 @@ def cond_cast_float(input):
 
 
 nv_rng = None
+patch_module_list = [
+    torch.nn.Linear,
+    torch.nn.Conv2d,
+    torch.nn.MultiheadAttention,
+    torch.nn.GroupNorm,
+    torch.nn.LayerNorm,
+]
+
+
+def manual_cast_forward(self, *args, **kwargs):
+    org_dtype = next(self.parameters()).dtype
+    self.to(dtype)
+    args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+    kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+    result = self.org_forward(*args, **kwargs)
+    self.to(org_dtype)
+    return result
+
+
+@contextlib.contextmanager
+def manual_cast():
+    for module_type in patch_module_list:
+        org_forward = module_type.forward
+        module_type.forward = manual_cast_forward
+        module_type.org_forward = org_forward
+    try:
+        yield None
+    finally:
+        for module_type in patch_module_list:
+            module_type.forward = module_type.org_forward
 
 
 def autocast(disable=False):
     if disable:
         return contextlib.nullcontext()
 
+    if fp8 and device==cpu:
+        return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
+
+    if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
+        return manual_cast()
+
+    if has_mps() and shared.cmd_opts.precision != "full":
+        return manual_cast()
+
     if dtype == torch.float32 or shared.cmd_opts.precision == "full":
         return contextlib.nullcontext()
 

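A usage sketch, hedged: this is not webui API documentation, just how devices.autocast() is meant to wrap a forward pass once manual_cast() has patched the listed layer types. On GTX 16xx-class GPUs, on MPS, or when fp8 storage forces dtype combinations torch.autocast cannot handle, each patched forward casts the module and its tensor arguments to devices.dtype for the call and restores the original dtype afterwards.

import torch
from modules import devices  # webui's devices module, as patched in the diff above

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.LayerNorm(16))
x = torch.randn(1, 16)

# in a running webui session, autocast() returns manual_cast() on the hardware above
# and a regular torch.autocast / nullcontext everywhere else
with devices.autocast():
    out = model(x)
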
6 changes: 6 additions & 0 deletions modules/generation_parameters_copypaste.py
@@ -314,6 +314,12 @@ def parse_generation_parameters(x: str):
     if "VAE Decoder" not in res:
         res["VAE Decoder"] = "Full"
 
+    if "FP8 weight" not in res:
+        res["FP8 weight"] = "Disable"
+
+    if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
+        res["Cache FP16 weight for LoRA"] = False
+
     skip = set(shared.opts.infotext_skip_pasting)
     res = {k: v for k, v in res.items() if k not in skip}
 

2 changes: 2 additions & 0 deletions modules/initialize_util.py
@@ -177,6 +177,8 @@ def configure_opts_onchange():
     shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
     shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
     shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
+    shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
+    shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)
     startup_timer.record("opts onchange")
 

2 changes: 2 additions & 0 deletions modules/processing.py
@@ -688,6 +688,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Size": f"{p.width}x{p.height}",
         "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
         "Model": p.sd_model_name if opts.add_model_name_to_info else None,
+        "FP8 weight": opts.fp8_storage if devices.fp8 else None,
+        "Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
         "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
         "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
         "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),

49 changes: 46 additions & 3 deletions modules/sd_models.py
@@ -348,10 +348,28 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
         SkipWritingToConfig.skip = self.previous
 
 
+def check_fp8(model):
+    if model is None:
+        return None
+    if devices.get_optimal_device_name() == "mps":
+        enable_fp8 = False
+    elif shared.opts.fp8_storage == "Enable":
+        enable_fp8 = True
+    elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
+        enable_fp8 = True
+    else:
+        enable_fp8 = False
+    return enable_fp8
+
+
 def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")
 
+    if devices.fp8:
+        # prevent model to load state dict in fp8
+        model.half()
+
     if not SkipWritingToConfig.skip:
         shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
 
@@ -404,6 +422,28 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         devices.dtype_unet = torch.float16
         timer.record("apply half()")
 
+    for module in model.modules():
+        if hasattr(module, 'fp16_weight'):
+            del module.fp16_weight
+        if hasattr(module, 'fp16_bias'):
+            del module.fp16_bias
+
+    if check_fp8(model):
+        devices.fp8 = True
+        first_stage = model.first_stage_model
+        model.first_stage_model = None
+        for module in model.modules():
+            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
+                if shared.opts.cache_fp16_weight:
+                    module.fp16_weight = module.weight.data.clone().cpu().half()
+                    if module.bias is not None:
+                        module.fp16_bias = module.bias.data.clone().cpu().half()
+                module.to(torch.float8_e4m3fn)
+        model.first_stage_model = first_stage
+        timer.record("apply fp8")
+    else:
+        devices.fp8 = False
+
     devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
 
     model.first_stage_model.to(devices.dtype_vae)
@@ -746,7 +786,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
         return None
 
 
-def reload_model_weights(sd_model=None, info=None):
+def reload_model_weights(sd_model=None, info=None, forced_reload=False):
     checkpoint_info = info or select_checkpoint()
 
     timer = Timer()
@@ -758,11 +798,14 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
         current_checkpoint_info = None
     else:
         current_checkpoint_info = sd_model.sd_checkpoint_info
-        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
+        if check_fp8(sd_model) != devices.fp8:
+            # load from state dict again to prevent extra numerical errors
+            forced_reload = True
+        elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:
             return sd_model
 
     sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
-    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+    if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
         return sd_model
 
     if sd_model is not None:

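The conversion loop above, restated as a standalone helper for readability (assumes plain PyTorch >= 2.1; the VAE/first-stage juggling and timers from the PR are omitted):

import torch

def to_fp8_storage(model: torch.nn.Module, cache_fp16: bool = False) -> torch.nn.Module:
    for module in model.modules():
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            if cache_fp16:
                # keep an fp16 copy on CPU so LoRA can later be applied against full-quality weights
                module.fp16_weight = module.weight.data.clone().cpu().half()
                if module.bias is not None:
                    module.fp16_bias = module.bias.data.clone().cpu().half()
            module.to(torch.float8_e4m3fn)  # storage only; compute still needs casting back up
    return model
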
2 changes: 1 addition & 1 deletion modules/sd_models_xl.py
@@ -93,7 +93,7 @@ def extend_sdxl(model):
     model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
 
     discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
-    model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
+    model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
 
     model.conditioner.wrapped = torch.nn.Module()
 

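The alphas_cumprod change pins the noise-schedule tensor to float32 regardless of the model dtype. A short illustration of why that likely matters (generic DDPM-style schedule, not the exact LegacyDDPMDiscretization values): a cumulative product of near-one terms loses precision quickly in half or fp8 precision, so the schedule is kept at full precision even when the weights are not.

import torch

betas = torch.linspace(1e-4, 2e-2, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas.double(), dim=0).to(torch.float32)  # keep fp32, never the model dtype
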
2 changes: 2 additions & 0 deletions modules/shared_options.py
@@ -206,6 +206,8 @@
     "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
     "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
     "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
+    "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
+    "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
 }))
 
 options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {