Merge pull request apple#37 from huggingface/merg_unet_attn_into_glide
merge unet attention into glide attention
patrickvonplaten authored Jun 28, 2022 · 2 parents 9dccc7d + c45fd74 · commit e372767
Showing 3 changed files with 9 additions and 107 deletions.
56 changes: 0 additions & 56 deletions src/diffusers/models/attention2d.py
@@ -32,62 +32,6 @@ def forward(self, x):
return self.to_out(out)


# unet.py
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels

self.norm = normalization(in_channels, swish=None, eps=1e-6)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

def forward(self, x):
print("x", x.abs().sum())
h_ = x
h_ = self.norm(h_)

print("hid_states shape", h_.shape)
print("hid_states", h_.abs().sum())
print("hid_states - 3 - 3", h_.view(h_.shape[0], h_.shape[1], -1)[:, :3, -3:])

q = self.q(h_)
k = self.k(h_)
v = self.v(h_)

print(self.q)
print("q_shape", q.shape)
print("q", q.abs().sum())
# print("k_shape", k.shape)
# print("k", k.abs().sum())
# print("v_shape", v.shape)
# print("v", v.abs().sum())

# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, h * w) # b,c,hw

w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)

print("weight", w_.abs().sum())

# attend to values
v = v.reshape(b, c, h * w)
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w)

h_ = self.proj_out(h_)

return x + h_


# unet_glide.py & unet_ldm.py
class AttentionBlock(nn.Module):
"""
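Note: the AttnBlock deleted above is plain single-head spatial self-attention over the H*W positions, computed with 1x1 convolutions and batched matmuls. Below is a minimal, debug-print-free sketch of the same computation; it assumes the repo's normalization(in_channels, swish=None, eps=1e-6) helper behaves like a 32-group GroupNorm, which is not shown in this diff. This is the behaviour the merged AttentionBlock(block_in, overwrite_qkv=True) has to reproduce for the test slices below to stay within their 1e-2 tolerance.

import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialSelfAttention(nn.Module):
    """Restatement of the removed AttnBlock: single-head self-attention
    over (H*W) positions, with 1x1 convs for the q/k/v and output projections.
    GroupNorm stands in for the repo's `normalization(...)` helper (assumed)."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.norm = nn.GroupNorm(32, in_channels, eps=1e-6)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h_ = self.norm(x)
        q, k, v = self.q(h_), self.k(h_), self.v(h_)

        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).permute(0, 2, 1)   # (b, hw, c)
        k = k.reshape(b, c, h * w)                    # (b, c, hw)
        attn = torch.bmm(q, k) * (c ** -0.5)          # (b, hw, hw)
        attn = F.softmax(attn, dim=2)

        v = v.reshape(b, c, h * w)                    # (b, c, hw)
        h_ = torch.bmm(v, attn.permute(0, 2, 1))      # (b, c, hw)
        h_ = h_.reshape(b, c, h, w)
        return x + self.proj_out(h_)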
43 changes: 1 addition & 42 deletions src/diffusers/models/unet.py
@@ -32,7 +32,7 @@
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
from .attention2d import AttnBlock, AttentionBlock
from .attention2d import AttentionBlock


def nonlinearity(x):
@@ -86,44 +86,6 @@ def forward(self, x, temb):
return x + h


#class AttnBlock(nn.Module):
# def __init__(self, in_channels):
# super().__init__()
# self.in_channels = in_channels
#
# self.norm = Normalize(in_channels)
# self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
# self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
# self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
# self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
#
# def forward(self, x):
# h_ = x
# h_ = self.norm(h_)
# q = self.q(h_)
# k = self.k(h_)
# v = self.v(h_)
#
# compute attention
# b, c, h, w = q.shape
# q = q.reshape(b, c, h * w)
# q = q.permute(0, 2, 1) # b,hw,c
# k = k.reshape(b, c, h * w) # b,c,hw
# w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
# w_ = w_ * (int(c) ** (-0.5))
# w_ = torch.nn.functional.softmax(w_, dim=2)
#
# attend to values
# v = v.reshape(b, c, h * w)
# w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
# h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
# h_ = h_.reshape(b, c, h, w)
#
# h_ = self.proj_out(h_)
#
# return x + h_


class UNetModel(ModelMixin, ConfigMixin):
def __init__(
self,
@@ -186,7 +148,6 @@ def __init__(
)
block_in = block_out
if curr_res in attn_resolutions:
# attn.append(AttnBlock(block_in))
attn.append(AttentionBlock(block_in, overwrite_qkv=True))
down = nn.Module()
down.block = block
@@ -202,7 +163,6 @@
self.mid.block_1 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
# self.mid.attn_1 = AttnBlock(block_in)
self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
self.mid.block_2 = ResnetBlock(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
@@ -228,7 +188,6 @@ def __init__(
)
block_in = block_out
if curr_res in attn_resolutions:
# attn.append(AttnBlock(block_in))
attn.append(AttentionBlock(block_in, overwrite_qkv=True))
up = nn.Module()
up.block = block
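Note: the diff does not show what overwrite_qkv=True does inside AttentionBlock. One plausible reading, sketched below purely as an assumption, is that checkpoints trained with separate 1x1 q/k/v convolutions are loaded by concatenating their parameters into a single fused qkv projection; fuse_qkv_weights and the attribute names in the usage comments are hypothetical, not part of the library.

import torch


def fuse_qkv_weights(q_conv, k_conv, v_conv):
    # Hypothetical helper: stack separate 1x1 q/k/v conv parameters into one
    # fused projection of shape (3*C, C, 1, 1), assuming the merged block also
    # uses a 1x1 convolution for its combined qkv layer.
    weight = torch.cat([q_conv.weight, k_conv.weight, v_conv.weight], dim=0)
    bias = torch.cat([q_conv.bias, k_conv.bias, v_conv.bias], dim=0)
    return weight, bias


# Possible usage (attribute names assumed):
# qkv_weight, qkv_bias = fuse_qkv_weights(old_block.q, old_block.k, old_block.v)
# new_block.qkv.weight.data.copy_(qkv_weight)
# new_block.qkv.bias.data.copy_(qkv_bias)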
17 changes: 8 additions & 9 deletions tests/test_modeling_utils.py
@@ -858,25 +858,26 @@ def test_ddpm_cifar10(self):
image_slice = image[0, -1, -3:, -3:].cpu()

assert image.shape == (1, 3, 32, 32)
expected_slice = torch.tensor([0.2250, 0.3375, 0.2360, 0.0930, 0.3440, 0.3156, 0.1937, 0.3585, 0.1761])
expected_slice = torch.tensor([0.2249, 0.3375, 0.2359, 0.0929, 0.3439, 0.3156, 0.1937, 0.3585, 0.1761])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

@slow
def test_ddim_cifar10(self):
generator = torch.manual_seed(0)
model_id = "fusing/ddpm-cifar10"

unet = UNetModel.from_pretrained(model_id)
noise_scheduler = DDIMScheduler(tensor_format="pt")

ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler)

generator = torch.manual_seed(0)
image = ddim(generator=generator, eta=0.0)

image_slice = image[0, -1, -3:, -3:].cpu()

assert image.shape == (1, 3, 32, 32)
expected_slice = torch.tensor(
[-0.7383, -0.7385, -0.7298, -0.7364, -0.7414, -0.7239, -0.6737, -0.6813, -0.7068]
[-0.6553, -0.6765, -0.6799, -0.6749, -0.7006, -0.6974, -0.6991, -0.7116, -0.7094]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

@@ -895,7 +896,7 @@ def test_pndm_cifar10(self):

assert image.shape == (1, 3, 32, 32)
expected_slice = torch.tensor(
[-0.7888, -0.7870, -0.7759, -0.7823, -0.8014, -0.7608, -0.6818, -0.7130, -0.7471]
[-0.7925, -0.7902, -0.7789, -0.7796, -0.8000, -0.7596, -0.6852, -0.7125, -0.7494]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

@@ -966,24 +967,22 @@ def test_grad_tts(self):

@slow
def test_score_sde_ve_pipeline(self):
torch.manual_seed(0)

model = NCSNpp.from_pretrained("fusing/ffhq_ncsnpp")
scheduler = ScoreSdeVeScheduler.from_config("fusing/ffhq_ncsnpp")

sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)

torch.manual_seed(0)
image = sde_ve(num_inference_steps=2)

expected_image_sum = 3382810112.0
expected_image_mean = 1075.366455078125
expected_image_sum = 3382849024.0
expected_image_mean = 1075.3788

assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4

@slow
def test_score_sde_vp_pipeline(self):

model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp")
scheduler = ScoreSdeVpScheduler.from_config("fusing/cifar10-ddpmpp-vp")


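Note: besides the new expected slice values, the test updates move torch.manual_seed(0) to immediately before each pipeline call, so the sampled noise (and therefore the expected slices) no longer depends on RNG draws made while the model is constructed or loaded. A minimal sketch of that pattern, using the names from the DDIM test above (the import path is assumed):

import torch

from diffusers import DDIMPipeline, DDIMScheduler, UNetModel  # import path assumed

unet = UNetModel.from_pretrained("fusing/ddpm-cifar10")
noise_scheduler = DDIMScheduler(tensor_format="pt")
ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler)

# Seed right before sampling so the drawn noise is reproducible regardless of
# any RNG use during model construction or weight loading.
generator = torch.manual_seed(0)
image = ddim(generator=generator, eta=0.0)

assert image.shape == (1, 3, 32, 32)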