Add FP32 fallback support on torch.nn.functional.interpolate
This retries interpolate in FP32 if the initial call fails.

Background: on some environments, such as Apple Silicon (M-series chip) macOS devices, the following error occurs:

```
"torch/nn/functional.py", line 3931, in interpolate
        return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    RuntimeError: "upsample_nearest2d_channels_last" not implemented for 'Half'
```

In this case, ```--no-half``` does not help, so this commit adds an FP32 fallback execution path to solve it.

Note that ```upsample_nearest2d``` is called from ```torch.nn.functional.interpolate```, and the fallback for torch.nn.functional.interpolate is necessary in:
- ```modules/sd_vae_approx.py```'s ```VAEApprox.forward```
- ```repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/openaimodel.py```'s ```Upsample.forward```
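
For illustration, the failure and the fix boil down to the following pattern; this is a minimal standalone sketch (the tensor shape and `scale_factor` are arbitrary example values, not taken from the commit):

```
import torch
import torch.nn.functional as F

# FP16 input on MPS; upsample_nearest2d may lack a Half kernel here
x = torch.randn(1, 4, 32, 32, dtype=torch.float16, device="mps")
try:
    y = F.interpolate(x, scale_factor=2, mode="nearest")
except RuntimeError as e:
    if "not implemented for" in str(e) and "Half" in str(e):
        # Retry in FP32, then cast the result back to the input dtype
        y = F.interpolate(x.to(torch.float32), scale_factor=2, mode="nearest").to(x.dtype)
    else:
        raise
```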
hidenorly committed Nov 28, 2023
1 parent 39eae9f commit a0096c5
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions modules/mac_specific.py
```
@@ -1,6 +1,8 @@
 import logging
 
 import torch
+from typing import Optional, List
+from torch import Tensor
 import platform
 from modules.sd_hijack_utils import CondFunc
 from packaging import version
@@ -51,6 +53,17 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
     return cumsum_func(input, *args, **kwargs)
 
 
+# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
+def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:
+    try:
+        return orig_func(*args, **kwargs)
+    except RuntimeError as e:
+        if "not implemented for" in str(e) and "Half" in str(e):
+            input_tensor = args[0]
+            return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)
+        else:
+            raise  # re-raise unexpected RuntimeErrors instead of swallowing them
+
 if has_mps:
     if platform.mac_ver()[0].startswith("13.2."):
         # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
@@ -77,6 +90,9 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
         # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
         CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
 
+    # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
+    CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)
+
     # MPS workaround for https://github.com/pytorch/pytorch/issues/92311
     if platform.processor() == 'i386':
         for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
```
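For context, `CondFunc` (from `modules/sd_hijack_utils`) replaces the function named by its first argument with a wrapper that routes calls through the substitute; a condition of `None` means the substitute always runs. A simplified sketch of that mechanism, for illustration only (the real implementation differs in details):

```
import importlib

def hijack(func_path, sub_func, cond_func):
    # Resolve e.g. 'torch.nn.functional.interpolate' to its module and attribute
    module_path, func_name = func_path.rsplit('.', 1)
    module = importlib.import_module(module_path)
    orig_func = getattr(module, func_name)

    def wrapper(*args, **kwargs):
        # With cond_func=None, as in this commit, the substitute always runs
        if cond_func is None or cond_func(orig_func, *args, **kwargs):
            return sub_func(orig_func, *args, **kwargs)
        return orig_func(*args, **kwargs)

    setattr(module, func_name, wrapper)
```

With this in place, every call to torch.nn.functional.interpolate goes through interpolate_with_fp32_fallback, which covers both call sites named in the commit message (VAEApprox.forward and Upsample.forward) without patching them individually.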
