diff --git a/PixArt/model/blocks.py b/PixArt/model/blocks.py
index 19b76d2..d540d96 100644
--- a/PixArt/model/blocks.py
+++ b/PixArt/model/blocks.py
@@ -14,6 +14,7 @@
 import torch.nn.functional as F
 from einops import rearrange
 
+import comfy.ldm.common_dit
 from .utils import to_2tuple
 
 sdpa_32b = None
@@ -25,11 +26,14 @@
 from comfy import model_management
 
 if model_management.xformers_enabled():
-    import xformers
     import xformers.ops
+    if int((xformers.__version__).split(".")[2]) >= 28:
+        block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
+    else:
+        block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
 else:
     if model_management.xpu_available:
-        import intel_extension_for_pytorch as ipex
+        import intel_extension_for_pytorch as ipex  # type: ignore
         import os
         if not torch.xpu.has_fp64_dtype() and not os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None):
             from ...utils.IPEX.attention import scaled_dot_product_attention_32_bit
@@ -70,7 +74,7 @@ def forward(self, x, cond, mask=None):
         if model_management.xformers_enabled():
             attn_bias = None
             if mask is not None:
-                attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
+                attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
             x = xformers.ops.memory_efficient_attention(
                 q, k, v,
                 p=self.attn_drop.p,
diff --git a/Sana/models/sana_blocks.py b/Sana/models/sana_blocks.py
index 09ca72a..894b23c 100644
--- a/Sana/models/sana_blocks.py
+++ b/Sana/models/sana_blocks.py
@@ -39,8 +39,11 @@
 from comfy import model_management
 
 if model_management.xformers_enabled():
-    import xformers
     import xformers.ops
+    if int((xformers.__version__).split(".")[2]) >= 28:
+        block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
+    else:
+        block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
 else:
     if model_management.xpu_available:
         import intel_extension_for_pytorch as ipex  # type: ignore
@@ -94,7 +97,7 @@ def forward(self, x, cond, mask=None):
         if model_management.xformers_enabled():
             attn_bias = None
             if mask is not None:
-                attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
+                attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
             x = xformers.ops.memory_efficient_attention(
                 q, k, v,
                 p=self.attn_drop.p,
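
For context, a minimal standalone sketch of the version-gated alias both files now share, assuming xformers is installed; `kv_lens` and the sample values of `B` and `N` are illustrative placeholders, not names taken from the diff:

```python
# Sketch of the version-gated alias from the diff above (assumes xformers
# is importable; not tied to ComfyUI or the surrounding modules).
import xformers.ops

# Mirror the diff: use the attn_bias path on xformers >= 0.0.28 and the
# older fmha-level re-export otherwise. The check reads the third component
# of xformers.__version__ (e.g. "0.0.28.post1" -> 28).
if int(xformers.__version__.split(".")[2]) >= 28:
    block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
else:
    block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens

# Hypothetical usage matching the patched call sites: B sequences of N query
# tokens, each attending to a different number of context tokens (kv_lens).
B, N = 2, 16
kv_lens = [4, 7]
attn_bias = block_diagonal_mask_from_seqlens([N] * B, kv_lens)
```

Resolving the attribute path once at import time keeps both call sites working across old and new xformers releases without pinning a specific version.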