Skip to content

Commit

Permalink
Merge pull request #14046 from hidenorly/AddFP32FallbackSupportOnSdVa…
Browse files Browse the repository at this point in the history
…eApprox

Add FP32 fallback support on sd_vae_approx
  • Loading branch information
AUTOMATIC1111 committed Dec 2, 2023
2 parents 600036d + 81c0072 commit e12a26c
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions modules/mac_specific.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging

import torch
from torch import Tensor
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version
Expand Down Expand Up @@ -51,6 +52,17 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
return cumsum_func(input, *args, **kwargs)


# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:
try:
return orig_func(*args, **kwargs)
except RuntimeError as e:
if "not implemented for" in str(e) and "Half" in str(e):
input_tensor = args[0]
return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)
else:
print(f"An unexpected RuntimeError occurred: {str(e)}")

if has_mps:
if platform.mac_ver()[0].startswith("13.2."):
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
Expand All @@ -77,6 +89,9 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')

# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)

# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
if platform.processor() == 'i386':
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
Expand Down

0 comments on commit e12a26c

Please sign in to comment.