
Commit

Cleaner code.
comfyanonymous committed Aug 9, 2024
1 parent 037c38e commit 11200de
Showing 1 changed file with 9 additions and 9 deletions.
comfy/ldm/flux/layers.py (18 changes: 9 additions & 9 deletions)
@@ -8,6 +8,7 @@
 from .math import attention, rope
 import comfy.ops
 
+
 class EmbedND(nn.Module):
     def __init__(self, dim: int, theta: int, axes_dim: list):
         super().__init__()
@@ -174,20 +175,19 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
 
         # run actual attention
-        q = torch.cat((txt_q, img_q), dim=2)
-        k = torch.cat((txt_k, img_k), dim=2)
-        v = torch.cat((txt_v, img_v), dim=2)
+        attn = attention(torch.cat((txt_q, img_q), dim=2),
+                         torch.cat((txt_k, img_k), dim=2),
+                         torch.cat((txt_v, img_v), dim=2), pe=pe)
 
-        attn = attention(q, k, v, pe=pe)
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
 
         # calculate the img bloks
-        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
-        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+        img += img_mod1.gate * self.img_attn.proj(img_attn)
+        img += img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
 
         # calculate the txt bloks
-        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
+        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
 
         if txt.dtype == torch.float16:
             txt = txt.clip(-65504, 65504)
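The first change in the hunk above is a pure refactor: before and after, the text and image q/k/v are concatenated along the sequence axis, run through a single joint attention pass, and the result is split back into the two streams; the new version simply inlines the torch.cat calls into the attention(...) call, so the computation is unchanged. A minimal standalone sketch of that pattern, using torch.nn.functional.scaled_dot_product_attention as a stand-in for the repository's rope-aware attention() helper (the sizes below are illustrative, not taken from the model):

import torch
import torch.nn.functional as F

# illustrative sizes, not the model's actual configuration
B, H, L_txt, L_img, D = 1, 24, 77, 1024, 128

txt_q, txt_k, txt_v = (torch.randn(B, H, L_txt, D) for _ in range(3))
img_q, img_k, img_v = (torch.randn(B, H, L_img, D) for _ in range(3))

# concatenate the text and image tokens along the sequence axis (dim=2) ...
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)

# ... run one attention pass over the joint sequence (stand-in for attention(),
# which in the repo also applies the rotary positional embedding pe)
out = F.scaled_dot_product_attention(q, k, v)               # (B, H, L_txt + L_img, D)
out = out.transpose(1, 2).reshape(B, L_txt + L_img, H * D)  # merge heads: (B, L, H*D)

# ... and split the joint result back into the text and image streams
txt_attn, img_attn = out[:, :L_txt], out[:, L_txt:]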
@@ -243,7 +243,7 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
         attn = attention(q, k, v, pe=pe)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        x = x + mod.gate * output
+        x += mod.gate * output
         if x.dtype == torch.float16:
             x = x.clip(-65504, 65504)
         return x
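The remaining edits, here and in the previous hunk, replace x = x + y with x += y on the residual adds. The commit message only says "Cleaner code.", but on tensors the two spellings are not fully interchangeable: += behaves like the in-place add_() and writes into the existing buffer, while x = x + y allocates a new tensor (and in-place writes can clash with autograd when the overwritten value is needed for a backward pass). A small sketch of the difference, with illustrative tensors:

import torch

x = torch.zeros(4)
y = torch.ones(4)

# out-of-place: x = x + y allocates a new tensor and rebinds the name
before = x
x = x + y
print(x.data_ptr() == before.data_ptr())   # False: x now points at new storage

# in-place: x += y behaves like x.add_(y) and writes into the existing buffer
x = torch.zeros(4)
before = x
x += y
print(x.data_ptr() == before.data_ptr())   # True: same storage reused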
