Add switch for new xformers
#86 - patch provided by @SaiZyca
city96 committed Dec 11, 2024
1 parent ee982c5 commit 4a11543
Showing 2 changed files with 12 additions and 5 deletions.
PixArt/model/blocks.py: 10 changes (7 additions, 3 deletions)
@@ -14,6 +14,7 @@
import torch.nn.functional as F
from einops import rearrange

import comfy.ldm.common_dit
from .utils import to_2tuple

sdpa_32b = None
@@ -25,11 +26,14 @@

from comfy import model_management
if model_management.xformers_enabled():
    import xformers
    import xformers.ops
    if int((xformers.__version__).split(".")[2]) >= 28:
        block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
    else:
        block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
else:
    if model_management.xpu_available:
        import intel_extension_for_pytorch as ipex
        import intel_extension_for_pytorch as ipex # type: ignore
        import os
        if not torch.xpu.has_fp64_dtype() and not os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None):
            from ...utils.IPEX.attention import scaled_dot_product_attention_32_bit
@@ -70,7 +74,7 @@ def forward(self, x, cond, mask=None):
        if model_management.xformers_enabled():
            attn_bias = None
            if mask is not None:
                attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
                attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
            x = xformers.ops.memory_efficient_attention(
                q, k, v,
                p=self.attn_drop.p,
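For reference, a minimal standalone sketch of the switch this patch adds to both files (not part of the commit; it assumes xformers is installed and that the third component of its version string is numeric, as the patch itself assumes):

    import xformers
    import xformers.ops

    # The patch picks the attn_bias path when the third version component is 28
    # or higher (e.g. 0.0.28) and the shorter fmha path otherwise, then binds the
    # result to one name so callers never branch on the installed version.
    if int((xformers.__version__).split(".")[2]) >= 28:
        block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
    else:
        block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens

The forward() changes in both files then call block_diagonal_mask_from_seqlens instead of spelling out the version-specific path.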
Sana/models/sana_blocks.py: 7 changes (5 additions, 2 deletions)
@@ -39,8 +39,11 @@

from comfy import model_management
if model_management.xformers_enabled():
    import xformers
    import xformers.ops
    if int((xformers.__version__).split(".")[2]) >= 28:
        block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
    else:
        block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
else:
    if model_management.xpu_available:
        import intel_extension_for_pytorch as ipex # type: ignore
@@ -94,7 +97,7 @@ def forward(self, x, cond, mask=None):
        if model_management.xformers_enabled():
            attn_bias = None
            if mask is not None:
                attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
                attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
            x = xformers.ops.memory_efficient_attention(
                q, k, v,
                p=self.attn_drop.p,
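For illustration only, a hypothetical usage sketch of the alias together with xformers memory-efficient attention; the tensor shapes, sequence lengths, and device below are invented and do not come from the commit:

    import torch
    import xformers.ops

    # Alias as resolved by the patch, shown here for a newer xformers (>= 0.0.28);
    # older releases would use xformers.ops.fmha.BlockDiagonalMask.from_seqlens.
    block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens

    B, N, heads, head_dim = 2, 6, 8, 64   # hypothetical: 2 samples, 6 queries each
    kv_lens = [4, 2]                      # hypothetical per-sample key/value lengths

    # Block-diagonal bias so each sample attends only to its own key/value tokens.
    attn_bias = block_diagonal_mask_from_seqlens([N] * B, kv_lens)

    # xformers expects packed sequences of shape (1, total_tokens, heads, head_dim)
    # with this bias; a CUDA device is assumed.
    q = torch.randn(1, B * N, heads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn(1, sum(kv_lens), heads, head_dim, device="cuda", dtype=torch.float16)
    v = torch.randn(1, sum(kv_lens), heads, head_dim, device="cuda", dtype=torch.float16)

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=attn_bias)

In the commit's forward() methods the same call additionally passes p=self.attn_drop.p for attention dropout.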
