kill off old vanilla CLIP dependency (not used anymore)
neggles committed Jul 9, 2024
1 parent d447edc commit 002e1b6
Showing 3 changed files with 29 additions and 15 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
@@ -24,7 +24,6 @@ dynamic = ["version"]
 dependencies = [
     "accelerate >= 0.26.1",
     "chardet == 5.1.0",
-    "clip @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33#egg=clip",
     "colorama >= 0.4.3, < 0.5.0",
     "colorcet >= 3.0.1, < 4.0.0",
     "crc32c >= 2.3",
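
With the git-pinned openai/CLIP package dropped from the dependency list, CLIP functionality presumably comes from the transformers package that the rest of this commit switches to. A minimal sketch of that replacement path, assuming a Hugging Face CLIP text encoder (the checkpoint name is illustrative, not taken from this repository):

# Sketch: CLIP text features via transformers rather than the removed openai/CLIP package.
# The checkpoint name below is an assumption for illustration only.
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

tokens = tokenizer(["a photo of a cat"], padding="max_length", return_tensors="pt")
with torch.no_grad():
    hidden = text_model(**tokens).last_hidden_state  # shape (batch, 77, hidden_dim)
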
12 changes: 6 additions & 6 deletions src/neurosis/fsdp/diffusionpolicy.py
@@ -4,7 +4,6 @@

 import lightning.pytorch as pl
 import torch
-from clip.model import ResidualAttentionBlock
 from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
 from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
 from lightning.pytorch.plugins.precision import Precision
@@ -13,9 +12,9 @@
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.nn import Module
 from torch.nn.modules.batchnorm import _BatchNorm
-from transformers.models.clip.modeling_clip import CLIPEncoderLayer
+from transformers.models.clip.modeling_clip import CLIPAttention, CLIPEncoderLayer

-from neurosis.models.autoencoder import AutoencoderKL
+from neurosis.models.autoencoder import AutoencoderKL, FSDPAutoencoderKL
 from neurosis.models.autoencoder_hf import AutoencoderKL as HFAutoencoderKL
 from neurosis.models.text_encoder import (
     FrozenCLIPEmbedder,
@@ -36,21 +35,22 @@

 _POLICY = Union[set[Type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy]
 _SHARDING_STRATEGY = Union[
-    ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]
+    ShardingStrategy,
+    Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"],
 ]


 class DiffusionFsdpPolicy(ModuleWrapPolicy):
     def __init__(self):
         module_classes = {
+            CLIPAttention,
             CLIPEncoderLayer,
             Decoder,
             Encoder,
             FrozenCLIPEmbedder,
             FrozenOpenCLIPEmbedder2,
             FrozenT5Embedder,
             GeneralConditioner,
-            ResidualAttentionBlock,
             SpatialTransformer,
             TimestepEmbedSequential,
             UNetModel,
@@ -87,7 +87,7 @@ def __post_init__(self):

         if self.vae_fp32:
             fp32_classes.extend(
-                [AutoencoderKL, Decoder, Encoder, HFAutoencoderKL],
+                [AutoencoderKL, Decoder, Encoder, FSDPAutoencoderKL, HFAutoencoderKL],
             )

         fp32_classes = sorted(set(fp32_classes), key=lambda x: x.__name__)
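
Net effect in this file: the FSDP wrap set now targets transformers' CLIPAttention and CLIPEncoderLayer instead of openai/CLIP's ResidualAttentionBlock, the sharding-strategy alias gains the "_HYBRID_SHARD_ZERO2" member of torch's ShardingStrategy enum, and FSDPAutoencoderKL joins the classes kept in fp32. A rough sketch of how a class-based wrap policy like this plugs into an FSDP strategy; the wiring below is an assumption for illustration, not code from this repository:

# Sketch (assumed wiring, not from this repo): wrap the transformers CLIP blocks as FSDP
# units and run the sharded model in reduced precision.
import torch
from lightning.pytorch.strategies import FSDPStrategy
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from transformers.models.clip.modeling_clip import CLIPAttention, CLIPEncoderLayer

wrap_policy = ModuleWrapPolicy({CLIPAttention, CLIPEncoderLayer})

strategy = FSDPStrategy(
    auto_wrap_policy=wrap_policy,
    # the string maps to torch's ShardingStrategy enum by name; "_HYBRID_SHARD_ZERO2" is one such name
    sharding_strategy="HYBRID_SHARD",
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
)

The fp32_classes list in the second hunk presumably feeds an exclusion mechanism so the VAE stack stays in full precision; that plumbing is not visible in this diff.
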
31 changes: 23 additions & 8 deletions src/neurosis/modules/attention.py
@@ -14,7 +14,7 @@
 logger = logging.getLogger(__name__)

 try:
-    from xformers import (
+    from xformers import (  # type: ignore
         __version__ as xformers_version,
         ops as xops,
     )
@@ -512,19 +512,34 @@ def _forward(
         additional_tokens: Optional[Tensor] = None,
         n_times_crossframe_attn_in_self: int = 0,
     ) -> Tensor:
-        x = (
+        y = self.norm1(x)
+        if self.disable_self_attn:
+            x = x + self.attn1(
+                y,
+                context=context,
+                additional_tokens=additional_tokens,
+                n_times_crossframe_attn_in_self=0,
+            )
+        else:
+            x = x + self.attn1(
+                y,
+                context=None,
+                additional_tokens=additional_tokens,
+                n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self,
+            )
+
+        x = x + (
             self.attn1(
                 self.norm1(x),
                 context=context if self.disable_self_attn else None,
                 additional_tokens=additional_tokens,
-                n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
-                if not self.disable_self_attn
-                else 0,
+                n_times_crossframe_attn_in_self=0
+                if self.disable_self_attn
+                else n_times_crossframe_attn_in_self,
             )
-            + x
         )
-        x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x
-        x = self.ff(self.norm3(x)) + x
+        x = x + self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens)
+        x = x + self.ff(self.norm3(x))
         return x


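The _forward rewrite keeps the residual structure but moves from one parenthesized expression to explicit x = x + ... steps, normalizing once into y and branching on disable_self_attn. A compact, self-contained illustration of that pre-norm residual pattern using plain torch modules (a stand-in with illustrative names and sizes, not the repository's own attention classes; the cross-attention step, attn2, is omitted for brevity):

# Sketch: pre-norm residual transformer block in the "x = x + f(norm(x))" style.
import torch
from torch import nn


class TinyPreNormBlock(nn.Module):
    def __init__(self, dim: int = 64, heads: int = 4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn1 = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm3 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.norm1(x)               # normalize once, reuse the result
        x = x + self.attn1(y, y, y)[0]  # residual around self-attention
        x = x + self.ff(self.norm3(x))  # residual around the feed-forward
        return x


block = TinyPreNormBlock()
out = block(torch.randn(2, 16, 64))  # (batch, tokens, dim) in, same shape out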
