[SDXL] Allow SDXL LoRA to be run with less than 16GB of VRAM (huggingface#4470)

* correct

* correct blocks

* finish

* finish

* finish

* Apply suggestions from code review

* fix

* up

* up

* up

* Update examples/dreambooth/README_sdxl.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Apply suggestions from code review

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
patrickvonplaten and sayakpaul authored Aug 4, 2023
1 parent 72e06b0 commit f173a7e
Showing 4 changed files with 108 additions and 33 deletions.
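
The VRAM reduction comes from extending gradient checkpointing to the modules changed below (the Transformer2DModel blocks and the cross-attention mid blocks), so their intermediate activations are recomputed during the backward pass instead of being held in memory for the whole step. As a rough sketch of how a training script would switch this on (model id and dtype are illustrative; enable_gradient_checkpointing() is the existing ModelMixin helper that applies _set_gradient_checkpointing to every submodule):

import torch
from diffusers import UNet2DConditionModel

# Illustrative only: load the SDXL UNet and enable gradient checkpointing before
# LoRA training. With this commit the mid block and the Transformer2DModel blocks
# also honor the flag, which is what brings training under 16GB of VRAM.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)
unet.enable_gradient_checkpointing()
unet.train()

The DreamBooth LoRA example exposes the same switch through a --gradient_checkpointing flag rather than calling the helper directly.
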
33 changes: 24 additions & 9 deletions models/transformer_2d.py
@@ -204,6 +204,8 @@ def __init__(
            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -289,15 +291,28 @@ def forward(

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )
            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    class_labels,
                    use_reentrant=False,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                )

        # 3. Output
        if self.is_input_continuous:
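
The checkpointing added above follows the standard torch.utils.checkpoint pattern: during training the block runs under checkpoint(), its inputs are passed positionally, and its intermediate activations are recomputed in the backward pass instead of being cached. A self-contained sketch of the same trade-off on a toy block (module, shapes, and names invented for illustration):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy stand-in for one transformer block; sizes are arbitrary.
block = nn.Sequential(nn.Linear(320, 1280), nn.GELU(), nn.Linear(1280, 320))
x = torch.randn(2, 77, 320, requires_grad=True)

# Checkpointed forward: activations inside the block are not stored and are
# recomputed on backward, trading extra compute for lower peak memory.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
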
50 changes: 41 additions & 9 deletions models/unet_2d_blocks.py
@@ -623,6 +623,8 @@ def __init__(
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
@@ -634,15 +636,45 @@
    ) -> torch.FloatTensor:
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]
            hidden_states = resnet(hidden_states, temb)
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                    None,  # timestep
                    None,  # class_labels
                    cross_attention_kwargs,
                    attention_mask,
                    encoder_attention_mask,
                    **ckpt_kwargs,
                )[0]
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
                hidden_states = resnet(hidden_states, temb)

        return hidden_states
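
Two details in the mid-block change above are easy to miss. First, the reentrant form of torch.utils.checkpoint.checkpoint only forwards positional arguments to the wrapped callable, so create_custom_forward binds return_dict=False in a closure and None placeholders keep the remaining positional arguments lined up with the attention block's signature. Second, use_reentrant=False is only passed where the keyword is available, via the is_torch_version gate on ckpt_kwargs. A generic sketch of the closure trick on a toy module (names and shapes are illustrative, not the diffusers API):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ToyAttn(nn.Module):
    # Stand-in for an attention block whose forward takes a keyword argument,
    # similar to the return_dict flag on the real attention blocks.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)

    def forward(self, hidden_states, return_dict=True):
        out = self.proj(hidden_states)
        return {"sample": out} if return_dict else (out,)

def create_custom_forward(module, return_dict=None):
    # Bind the keyword argument in a closure so checkpoint() only has to pass tensors.
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)
    return custom_forward

attn = ToyAttn()
x = torch.randn(2, 16, 64, requires_grad=True)
out = checkpoint(create_custom_forward(attn, return_dict=False), x, use_reentrant=False)[0]
out.mean().backward()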

6 changes: 1 addition & 5 deletions models/unet_2d_condition.py
@@ -36,12 +36,8 @@
)
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
    UNetMidBlock2DSimpleCrossAttn,
    UpBlock2D,
    get_down_block,
    get_up_block,
)
@@ -694,7 +690,7 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
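
Replacing the isinstance() whitelist with a hasattr() check means any submodule that declares a gradient_checkpointing attribute, now including Transformer2DModel and the cross-attention mid block, is toggled automatically when checkpointing is enabled, and the list no longer has to be kept in sync by hand. A simplified stand-in for how the toggle propagates (not the actual ModelMixin code, just an illustration of the duck-typed check combined with nn.Module.apply):

import torch.nn as nn

class Block(nn.Module):
    # A block opts in simply by declaring the flag, mirroring the
    # self.gradient_checkpointing = False lines added in this commit.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)
        self.gradient_checkpointing = False

class ToyUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.down = Block()
        self.mid = Block()
        self.out = nn.Linear(8, 8)  # no flag, so it is simply skipped

    def _set_gradient_checkpointing(self, module, value=False):
        # Duck-typed check instead of an isinstance() whitelist.
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def enable_gradient_checkpointing(self):
        # nn.Module.apply visits every submodule, so each opted-in block is toggled.
        self.apply(lambda m: self._set_gradient_checkpointing(m, value=True))

model = ToyUNet()
model.enable_gradient_checkpointing()
assert model.down.gradient_checkpointing and model.mid.gradient_checkpointing
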
52 changes: 42 additions & 10 deletions pipelines/versatile_diffusion/modeling_text_unet.py
@@ -800,7 +800,7 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

def forward(
Expand Down Expand Up @@ -1784,6 +1784,8 @@ def __init__(
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)

self.gradient_checkpointing = False

def forward(
self,
hidden_states: torch.FloatTensor,
@@ -1795,15 +1797,45 @@ def forward(
    ) -> torch.FloatTensor:
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]
            hidden_states = resnet(hidden_states, temb)
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                    None,  # timestep
                    None,  # class_labels
                    cross_attention_kwargs,
                    attention_mask,
                    encoder_attention_mask,
                    **ckpt_kwargs,
                )[0]
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
                hidden_states = resnet(hidden_states, temb)

        return hidden_states

