From 002e1b6e731007814ce43f95ff592c19d6a0cc9f Mon Sep 17 00:00:00 2001
From: Andi Powers-Holmes
Date: Tue, 9 Jul 2024 05:21:20 +0000
Subject: [PATCH] kill off old vanilla CLIP dependency (not used anymore)

---
 pyproject.toml                        |  1 -
 src/neurosis/fsdp/diffusionpolicy.py  | 12 ++++++------
 src/neurosis/modules/attention.py     | 33 +++++++++++++++++++--------------
 3 files changed, 25 insertions(+), 21 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 2c88992..65abcc7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,7 +24,6 @@ dynamic = ["version"]
 dependencies = [
     "accelerate >= 0.26.1",
     "chardet == 5.1.0",
-    "clip @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33#egg=clip",
     "colorama >= 0.4.3, < 0.5.0",
     "colorcet >= 3.0.1, < 4.0.0",
     "crc32c >= 2.3",
diff --git a/src/neurosis/fsdp/diffusionpolicy.py b/src/neurosis/fsdp/diffusionpolicy.py
index 5e355d9..c3e2c85 100644
--- a/src/neurosis/fsdp/diffusionpolicy.py
+++ b/src/neurosis/fsdp/diffusionpolicy.py
@@ -4,7 +4,6 @@
 
 import lightning.pytorch as pl
 import torch
-from clip.model import ResidualAttentionBlock
 from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
 from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
 from lightning.pytorch.plugins.precision import Precision
@@ -13,9 +12,9 @@ from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.nn import Module
 from torch.nn.modules.batchnorm import _BatchNorm
-from transformers.models.clip.modeling_clip import CLIPEncoderLayer
+from transformers.models.clip.modeling_clip import CLIPAttention, CLIPEncoderLayer
 
-from neurosis.models.autoencoder import AutoencoderKL
+from neurosis.models.autoencoder import AutoencoderKL, FSDPAutoencoderKL
 from neurosis.models.autoencoder_hf import AutoencoderKL as HFAutoencoderKL
 from neurosis.models.text_encoder import (
     FrozenCLIPEmbedder,
@@ -36,13 +35,15 @@
 
 _POLICY = Union[set[Type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy]
 _SHARDING_STRATEGY = Union[
-    ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]
+    ShardingStrategy,
+    Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"],
 ]
 
 
 class DiffusionFsdpPolicy(ModuleWrapPolicy):
     def __init__(self):
         module_classes = {
+            CLIPAttention,
             CLIPEncoderLayer,
             Decoder,
             Encoder,
@@ -50,7 +51,6 @@ def __init__(self):
             FrozenOpenCLIPEmbedder2,
             FrozenT5Embedder,
             GeneralConditioner,
-            ResidualAttentionBlock,
             SpatialTransformer,
             TimestepEmbedSequential,
             UNetModel,
@@ -87,7 +87,7 @@ def __post_init__(self):
 
         if self.vae_fp32:
             fp32_classes.extend(
-                [AutoencoderKL, Decoder, Encoder, HFAutoencoderKL],
+                [AutoencoderKL, Decoder, Encoder, FSDPAutoencoderKL, HFAutoencoderKL],
             )
 
         fp32_classes = sorted(set(fp32_classes), key=lambda x: x.__name__)
diff --git a/src/neurosis/modules/attention.py b/src/neurosis/modules/attention.py
index d064015..7c47e2c 100644
--- a/src/neurosis/modules/attention.py
+++ b/src/neurosis/modules/attention.py
@@ -14,7 +14,7 @@
 logger = logging.getLogger(__name__)
 
 try:
-    from xformers import (
+    from xformers import (  # type: ignore
         __version__ as xformers_version,
         ops as xops,
     )
@@ -512,19 +512,24 @@ def _forward(
         additional_tokens: Optional[Tensor] = None,
         n_times_crossframe_attn_in_self: int = 0,
     ) -> Tensor:
-        x = (
-            self.attn1(
-                self.norm1(x),
-                context=context if self.disable_self_attn else None,
-                additional_tokens=additional_tokens,
-                n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
-                if not self.disable_self_attn
-                else 0,
-            )
-            + x
-        )
-        x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x
-        x = self.ff(self.norm3(x)) + x
+        y = self.norm1(x)
+        if self.disable_self_attn:
+            x = x + self.attn1(
+                y,
+                context=context,
+                additional_tokens=additional_tokens,
+                n_times_crossframe_attn_in_self=0,
+            )
+        else:
+            x = x + self.attn1(
+                y,
+                context=None,
+                additional_tokens=additional_tokens,
+                n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self,
+            )
+
+        x = x + self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens)
+        x = x + self.ff(self.norm3(x))
         return x
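Note on the last hunk: after this patch, BasicTransformerBlock._forward applies every sub-layer as a pre-norm residual branch, i.e. x = x + sublayer(norm(x)), with an explicit if/else on disable_self_attn replacing the old inline ternary for the attn1 arguments. The sketch below is a minimal standalone illustration of that control flow, not code from the repository: it keeps the attn1/attn2/ff/norm1-3 structure and the disable_self_attn switch, but swaps the project's CrossAttention and FeedForward modules (and their additional_tokens / n_times_crossframe_attn_in_self keyword arguments) for torch.nn.MultiheadAttention and a small MLP so it can run on its own.

from typing import Optional

import torch
from torch import Tensor, nn


class TransformerBlockSketch(nn.Module):
    """Pre-norm residual block mirroring the patched _forward control flow."""

    def __init__(self, dim: int, heads: int = 8, disable_self_attn: bool = False):
        super().__init__()
        self.disable_self_attn = disable_self_attn
        # stand-ins for the repo's CrossAttention / FeedForward modules
        self.attn1 = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.attn2 = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x: Tensor, context: Optional[Tensor] = None) -> Tensor:
        # attn1: self-attention, or cross-attention when disable_self_attn is set
        y = self.norm1(x)
        kv = context if (self.disable_self_attn and context is not None) else y
        x = x + self.attn1(y, kv, kv, need_weights=False)[0]

        # attn2: cross-attention against the conditioning context (self-attention if absent)
        y = self.norm2(x)
        kv = context if context is not None else y
        x = x + self.attn2(y, kv, kv, need_weights=False)[0]

        # feed-forward, same pre-norm residual pattern
        x = x + self.ff(self.norm3(x))
        return x


if __name__ == "__main__":
    block = TransformerBlockSketch(dim=64)
    tokens = torch.randn(2, 16, 64)   # (batch, sequence, channels)
    cond = torch.randn(2, 77, 64)     # e.g. text-encoder conditioning tokens
    print(block(tokens, cond).shape)  # torch.Size([2, 16, 64])

Writing the residual add as x = x + attn(...) rather than attn(...) + x is purely cosmetic; the substantive part of the refactor is the explicit branch on disable_self_attn, which keeps the cross-frame attention count at zero whenever attn1 is being fed the conditioning context.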