From f173a7efac0a246898c039883ba2572e5ccdc6cc Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 4 Aug 2023 20:06:38 +0200
Subject: [PATCH] [SDXL] Allow SDXL LoRA to be run with less than 16GB of VRAM
 (#4470)

* correct

* correct blocks

* finish

* finish

* finish

* Apply suggestions from code review

* fix

* up

* up

* up

* Update examples/dreambooth/README_sdxl.md

Co-authored-by: Sayak Paul

* Apply suggestions from code review

---------

Co-authored-by: Sayak Paul
---
 models/transformer_2d.py                      | 33 ++++++++----
 models/unet_2d_blocks.py                      | 50 ++++++++++++++----
 models/unet_2d_condition.py                   |  6 +--
 .../versatile_diffusion/modeling_text_unet.py | 52 +++++++++++++++----
 4 files changed, 108 insertions(+), 33 deletions(-)

diff --git a/models/transformer_2d.py b/models/transformer_2d.py
index 998535c58a73..344a9441ced1 100644
--- a/models/transformer_2d.py
+++ b/models/transformer_2d.py
@@ -204,6 +204,8 @@ def __init__(
             self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
             self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
 
+        self.gradient_checkpointing = False
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -289,15 +291,28 @@ def forward(
 
         # 2. Blocks
         for block in self.transformer_blocks:
-            hidden_states = block(
-                hidden_states,
-                attention_mask=attention_mask,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                timestep=timestep,
-                cross_attention_kwargs=cross_attention_kwargs,
-                class_labels=class_labels,
-            )
+            if self.training and self.gradient_checkpointing:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    block,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    timestep,
+                    cross_attention_kwargs,
+                    class_labels,
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states = block(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    timestep=timestep,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    class_labels=class_labels,
+                )
 
         # 3. Output
         if self.is_input_continuous:
diff --git a/models/unet_2d_blocks.py b/models/unet_2d_blocks.py
index 8d7e864dfcab..6f3037d624f9 100644
--- a/models/unet_2d_blocks.py
+++ b/models/unet_2d_blocks.py
@@ -623,6 +623,8 @@ def __init__(
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
+        self.gradient_checkpointing = False
+
     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -634,15 +636,45 @@ def forward(
     ) -> torch.FloatTensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            hidden_states = attn(
-                hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                cross_attention_kwargs=cross_attention_kwargs,
-                attention_mask=attention_mask,
-                encoder_attention_mask=encoder_attention_mask,
-                return_dict=False,
-            )[0]
-            hidden_states = resnet(hidden_states, temb)
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False),
+                    hidden_states,
+                    encoder_hidden_states,
+                    None,  # timestep
+                    None,  # class_labels
+                    cross_attention_kwargs,
+                    attention_mask,
+                    encoder_attention_mask,
+                    **ckpt_kwargs,
+                )[0]
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+                hidden_states = resnet(hidden_states, temb)
 
         return hidden_states
 
diff --git a/models/unet_2d_condition.py b/models/unet_2d_condition.py
index cede2ed9d36a..fea1b4cd7823 100644
--- a/models/unet_2d_condition.py
+++ b/models/unet_2d_condition.py
@@ -36,12 +36,8 @@
 )
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
-    CrossAttnDownBlock2D,
-    CrossAttnUpBlock2D,
-    DownBlock2D,
     UNetMidBlock2DCrossAttn,
     UNetMidBlock2DSimpleCrossAttn,
-    UpBlock2D,
     get_down_block,
     get_up_block,
 )
@@ -694,7 +690,7 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
             fn_recursive_set_attention_slice(module, reversed_slice_size)
 
     def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
+        if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
 
     def forward(
diff --git a/pipelines/versatile_diffusion/modeling_text_unet.py b/pipelines/versatile_diffusion/modeling_text_unet.py
index 7a69a7908efa..adb41a8dfd07 100644
--- a/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -800,7 +800,7 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
             fn_recursive_set_attention_slice(module, reversed_slice_size)
 
     def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)):
+        if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
 
     def forward(
@@ -1784,6 +1784,8 @@ def __init__(
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
+        self.gradient_checkpointing = False
+
     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -1795,15 +1797,45 @@ def forward(
     ) -> torch.FloatTensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            hidden_states = attn(
-                hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                cross_attention_kwargs=cross_attention_kwargs,
-                attention_mask=attention_mask,
-                encoder_attention_mask=encoder_attention_mask,
-                return_dict=False,
-            )[0]
-            hidden_states = resnet(hidden_states, temb)
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False),
+                    hidden_states,
+                    encoder_hidden_states,
+                    None,  # timestep
+                    None,  # class_labels
+                    cross_attention_kwargs,
+                    attention_mask,
+                    encoder_attention_mask,
+                    **ckpt_kwargs,
+                )[0]
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+                hidden_states = resnet(hidden_states, temb)
 
         return hidden_states
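
Usage note: the net effect of the patch is that enable_gradient_checkpointing() on a UNet now also flips the new gradient_checkpointing flags on the mid block and on Transformer2DModel, which is what brings SDXL LoRA training under 16GB of VRAM. Below is a minimal sketch of how a training script would turn this on, assuming the public diffusers UNet2DConditionModel API; the checkpoint id and dtype are illustrative and not taken from the patch.

import torch
from diffusers import UNet2DConditionModel

# Load the SDXL UNet (illustrative checkpoint id).
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float32
)

# With this patch, _set_gradient_checkpointing matches any submodule that defines a
# gradient_checkpointing attribute, so the mid block and its transformers are
# checkpointed too, trading extra compute in the backward pass for lower peak VRAM.
unet.enable_gradient_checkpointing()
unet.train()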