Meet error when using xformers and doing loss backward #535

Closed
eeyrw opened this issue Nov 21, 2022 · 9 comments

@eeyrw

eeyrw commented Nov 21, 2022

🐛 Bug

Associated issue: huggingface/diffusers#1314
I get an error when I enable xformers for the UNet and run the backward pass:

Traceback (most recent call last):
  File "f:/diffusers-test/vae_expr.py", line 66, in <module>
    loss.backward()
  File "C:\Users\uuu\.virtualenvs\stable-diffusion\lib\site-packages\torch\_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "C:\Users\uuu\.virtualenvs\stable-diffusion\lib\site-packages\torch\autograd\__init__.py", line 175, in backward
    allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
  File "C:\Users\uuu\.virtualenvs\stable-diffusion\lib\site-packages\torch\autograd\function.py", line 253, in apply   
    return user_fn(self, *args)
  File "f:\xformers\xformers\ops\memory_efficient_attention.py", line 414, in backward
    causal=ctx.causal,
  File "C:\Users\uuu\.virtualenvs\stable-diffusion\lib\site-packages\torch\_ops.py", line 143, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: p.gQ_strideM() == grad_q.stride(1) INTERNAL ASSERT FAILED at "F:\\xformers\\xformers\\components\\attention\\csrc\\cuda\\mem_eff_attention\\attention_backward_generic.cu":181, please report a bug to PyTorch.
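
The assertion above compares a stride the kernel expects (`p.gQ_strideM()`) with `grad_q.stride(1)`. The thread does not say what caused the mismatch, so the following is only a sketch of what a stride is and how the stride along axis 1 changes when axes are permuted. It uses plain Python rather than PyTorch; the helper names and the example shape are hypothetical and not taken from xformers.

```python
from typing import Sequence, Tuple


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major (C-contiguous) strides, in elements, for the given shape."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def permuted_strides(shape: Sequence[int], order: Sequence[int]) -> Tuple[int, ...]:
    """Strides of a view made by permuting the axes of a contiguous tensor."""
    base = contiguous_strides(shape)
    return tuple(base[axis] for axis in order)


# Hypothetical gradient laid out as (batch, seq_len, heads, head_dim).
shape = (2, 4096, 8, 40)
print(contiguous_strides(shape))              # (1310720, 320, 40, 1)
# Swapping axes 1 and 2 without copying changes the stride along axis 1.
print(permuted_strides(shape, (0, 2, 1, 3)))  # (1310720, 40, 320, 1)
```

A kernel that assumes one of these layouts and receives the other would see a different value at index 1 of the strides, which is the kind of condition an assert like this one checks.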

To Reproduce

Steps to reproduce the behavior:

import argparse
import logging
import math
import os
import random
from pathlib import Path
from typing import Iterable, Optional

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

pretrained_model_name_or_path = r'F:\diffusers-weight'
# Load models and create wrapper for stable diffusion
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
noise_scheduler = DDPMScheduler.from_config(pretrained_model_name_or_path, subfolder="scheduler")

# Freeze vae and text_encoder
vae.requires_grad_(False)
text_encoder.requires_grad_(False)


weight_dtype = torch.bfloat16
# Move text_encoder, vae and unet to the GPU.
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
text_encoder.to('cuda', dtype=weight_dtype)
vae.to('cuda', dtype=weight_dtype)
unet.to('cuda', dtype=weight_dtype)
unet.set_use_memory_efficient_attention_xformers(True)

# Convert images to latent space
images = torch.randn(1,3,512,512).to('cuda', dtype=weight_dtype)
latents = vae.encode(images).latent_dist.sample()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
inputs = tokenizer('Terwt dsfs gsdgs sg"', max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
input_ids = [inputs["input_ids"]]
padded_tokens = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt")
input_ids = padded_tokens.input_ids.to('cuda', dtype=torch.int)
encoder_hidden_states = text_encoder(input_ids)[0]

# Predict the noise residual and compute loss
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
loss.backward()

Expected behavior

The backward pass should complete without raising an error.

Environment

System Info
diffusers version: 0.7.2
Platform: Windows-10-10.0.19041-SP0
Python version: 3.7.7
PyTorch version (GPU?): 1.12.0+cu113 (True)
Huggingface_hub version: 0.10.1
Transformers version: 4.24.0
Using GPU in script?: Yes
Using distributed or parallel set-up in script?: No
xformers version: efdca02

@danthe3rd
Contributor

Thanks for reporting! Could you share the resolution you are using? Also, can you post the output of python -m xformers.info?

@eeyrw
Author

eeyrw commented Nov 22, 2022

@danthe3rd
(stable-diffusion) PS F:\diffusers-test> python -m xformers.info
A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'
xFormers 0.0.15.dev+efdca02.d20221116
memory_efficient_attention.flshatt: available - requires GPU with compute capability 7.5+
memory_efficient_attention.cutlass: available
memory_efficient_attention.small_k: available
swiglu.fused.p.cpp: available
is_triton_available: False
is_functorch_available: False
pytorch.version: 1.12.0+cu113
pytorch.cuda: available
gpu.compute_capability: 8.6
gpu.name: NVIDIA GeForce RTX 3060
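
When comparing environments across reports, output like the above can be turned into a dictionary. This is a minimal sketch that assumes the `key: value` line format shown here, which may differ in other xformers versions; `parse_info` and `SAMPLE` are hypothetical names, and the sample lines are copied from the output above.

```python
from typing import Dict

SAMPLE = """\
Error caught was: No module named 'triton'
xFormers 0.0.15.dev+efdca02.d20221116
memory_efficient_attention.cutlass: available
is_triton_available: False
pytorch.version: 1.12.0+cu113
gpu.compute_capability: 8.6
gpu.name: NVIDIA GeForce RTX 3060
"""


def parse_info(text: str) -> Dict[str, str]:
    """Collect `key: value` lines; skip free-form messages and the version banner."""
    info = {}
    for line in text.splitlines():
        key, sep, value = line.partition(": ")
        key = key.strip()
        # Keys in this output contain no spaces; warnings such as
        # "Error caught was: ..." do, so they are ignored.
        if sep and key and " " not in key:
            info[key] = value.strip()
    return info


info = parse_info(SAMPLE)
print(info["gpu.compute_capability"], info["pytorch.version"])
```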

@eeyrw
Author

eeyrw commented Nov 23, 2022

To make this easier to reproduce, here is a more concise test case:

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers import UNet2DConditionModel

weight_dtype = torch.float32

cfg = {
    "_class_name": "UNet2DConditionModel",
    "_diffusers_version": "0.8.0.dev0",
    "act_fn": "silu",
    "attention_head_dim": 8,
    "block_out_channels": [
        320,
        640,
        1280,
        1280
    ],
    "center_input_sample": False,
    "cross_attention_dim": 768,
    "down_block_types": [
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D"
    ],
    "downsample_padding": 1,
    "flip_sin_to_cos": True,
    "freq_shift": 0,
    "in_channels": 4,
    "layers_per_block": 2,
    "mid_block_scale_factor": 1,
    "norm_eps": 1e-05,
    "norm_num_groups": 32,
    "out_channels": 4,
    "sample_size": 32,
    "up_block_types": [
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D"
    ]
}
unet = UNet2DConditionModel(**cfg)

unet.to('cuda', dtype=weight_dtype)
unet.set_use_memory_efficient_attention_xformers(True)
noise = torch.randn(1, 4, 64, 64).to('cuda', dtype=weight_dtype)
noisy_latents = torch.randn(1, 4, 64, 64).to('cuda', dtype=weight_dtype)
timesteps = torch.tensor(543, device='cuda', dtype=torch.int64)
encoder_hidden_states = torch.randn(1, 10, 768).to('cuda', dtype=weight_dtype)
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
loss.backward()

@danthe3rd
Contributor

Thanks a lot, this is really useful! @artkorenev has been working on this and should have a fix coming soon.

@leeruibin

Is there a solution yet? I have the same problem when training a Stable Diffusion model. The forward pass produces a result, but when I compute the loss and run backward, it raises:

raise NotImplementedError(f"No operator found for this attention: {inp}")
NotImplementedError: No operator found for this attention: Inputs

This error is raised in "xformers/ops/fmha/dispatch.py", line 68, in _dispatch_bw. The "inp" has the shapes {query: (64, 256, 1, 128), key: (64, 77, 1, 128), value: (64, 77, 1, 128)}. Could the dimension of size 1 be causing this dispatch error?

@danthe3rd
Contributor

danthe3rd commented Jan 9, 2023

Closing this as it's resolved now.

@leeruibin this is a different / unrelated issue. Can you open a new one with the entire stacktrace/log of the error? num_heads=1 should be supported without issue. Also include the output of python -m xformers.info

@eeyrw
Author

eeyrw commented Jan 9, 2023

> Closing this as it's resolved now.
>
> @leeruibin this is a different / unrelated issue. Can you open a new one with the entire stacktrace/log of the error? num_heads=1 should be supported without issue. Also include the output of python -m xformers.info

In which release or dev version is the fix included?

@leeruibin

Thanks for your reply. I have opened a new issue at this link:
#628 (comment)

@danthe3rd
Contributor

> In which release or dev version is the fix included?

Whoops, I forgot to circle back here. It has been fixed in 3ea7307.
