Skip to content

Commit

Permalink
use_checkpoint = False (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
huchenlei authored May 15, 2024
1 parent 1c0a0c4 commit 5a5ac68
Show file tree
Hide file tree
Showing 8 changed files with 13 additions and 10 deletions.
2 changes: 1 addition & 1 deletion configs/alt-diffusion-inference.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ model:
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
use_checkpoint: False
legacy: False

first_stage_config:
Expand Down
2 changes: 1 addition & 1 deletion configs/alt-diffusion-m18-inference.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ model:
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
use_checkpoint: True
use_checkpoint: False
legacy: False

first_stage_config:
Expand Down
2 changes: 1 addition & 1 deletion configs/instruct-pix2pix.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ model:
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
use_checkpoint: False
legacy: False

first_stage_config:
Expand Down
2 changes: 1 addition & 1 deletion configs/sd_xl_inpaint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ model:
params:
adm_in_channels: 2816
num_classes: sequential
use_checkpoint: True
use_checkpoint: False
in_channels: 9
out_channels: 4
model_channels: 320
Expand Down
2 changes: 1 addition & 1 deletion configs/v1-inference.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ model:
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
use_checkpoint: False
legacy: False

first_stage_config:
Expand Down
2 changes: 1 addition & 1 deletion configs/v1-inpainting-inference.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ model:
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
use_checkpoint: False
legacy: False

first_stage_config:
Expand Down
9 changes: 6 additions & 3 deletions modules/sd_hijack_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@
import ldm.modules.diffusionmodules.openaimodel


# Setting flag=False so that torch skips checking parameters.
# parameters checking is expensive in frequent operations.

def BasicTransformerBlock_forward(self, x, context=None):
return checkpoint(self._forward, x, context)
return checkpoint(self._forward, x, context, flag=False)


def AttentionBlock_forward(self, x):
return checkpoint(self._forward, x)
return checkpoint(self._forward, x, flag=False)


def ResBlock_forward(self, x, emb):
return checkpoint(self._forward, x, emb)
return checkpoint(self._forward, x, emb, flag=False)


stored = []
Expand Down
2 changes: 1 addition & 1 deletion modules/sd_models_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def is_using_v_parameterization_for_sd2(state_dict):

with sd_disable_initialization.DisableInitialization():
unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
use_checkpoint=True,
use_checkpoint=False,
use_fp16=False,
image_size=32,
in_channels=4,
Expand Down

0 comments on commit 5a5ac68

Please sign in to comment.