make-fast-test-for-StableDiffusionControlNetPipeline-faster (#5292)
* decrease UNet2DConditionModel & ControlNetModel blocks

* decrease UNet2DConditionModel & ControlNetModel blocks

* decrease even more blocks & number of norm groups

* decrease vae block out channels and number of norm groups

* fix code style

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
m0saan and sayakpaul authored Oct 9, 2023
1 parent 2ed7e05 commit c4d6620
Showing 1 changed file with 20 additions and 10 deletions.
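The speed-up comes from shrinking every dummy model built in get_dummy_components: block_out_channels drops from (32, 64) to (4, 8), and norm_num_groups is now passed explicitly because GroupNorm needs the group count to divide each block's channel count, which the library default of 32 no longer does. A minimal sketch of the effect on model size, reusing the UNet config from the diff below (the tiny_unet helper and the parameter-count comparison are illustrative, not part of the commit):

from diffusers import UNet2DConditionModel

def tiny_unet(block_out_channels, norm_num_groups):
    # Same kwargs as the test's dummy UNet; only the two tuned values vary.
    return UNet2DConditionModel(
        block_out_channels=block_out_channels,
        layers_per_block=2,
        sample_size=32,
        in_channels=4,
        out_channels=4,
        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
        up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
        cross_attention_dim=32,
        norm_num_groups=norm_num_groups,
    )

n_before = sum(p.numel() for p in tiny_unet((32, 64), 32).parameters())  # old config; 32 groups is the library default
n_after = sum(p.numel() for p in tiny_unet((4, 8), 1).parameters())      # new config from this commit
print(n_before, n_after)  # far fewer weights after, so each forward pass in the tests is cheaper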
30 changes: 20 additions & 10 deletions tests/pipelines/controlnet/test_controlnet.py
@@ -119,23 +119,25 @@ class ControlNetPipelineFastTests(
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
-            block_out_channels=(32, 64),
+            block_out_channels=(4, 8),
             layers_per_block=2,
             sample_size=32,
             in_channels=4,
             out_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
             up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
             cross_attention_dim=32,
+            norm_num_groups=1,
         )
         torch.manual_seed(0)
         controlnet = ControlNetModel(
-            block_out_channels=(32, 64),
+            block_out_channels=(4, 8),
             layers_per_block=2,
             in_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
             cross_attention_dim=32,
             conditioning_embedding_out_channels=(16, 32),
+            norm_num_groups=1,
         )
         torch.manual_seed(0)
         scheduler = DDIMScheduler(
@@ -147,12 +149,13 @@ def get_dummy_components(self):
         )
         torch.manual_seed(0)
         vae = AutoencoderKL(
-            block_out_channels=[32, 64],
+            block_out_channels=[4, 8],
             in_channels=3,
             out_channels=3,
             down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
             up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
             latent_channels=4,
+            norm_num_groups=2,
         )
         torch.manual_seed(0)
         text_encoder_config = CLIPTextConfig(
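Note the split in group counts: the UNet and ControlNet use norm_num_groups=1 while the VAE uses norm_num_groups=2; both values satisfy torch.nn.GroupNorm's requirement that num_channels be divisible by num_groups, which the old default of 32 groups would violate for 4- and 8-channel blocks. A quick illustration (not from the commit):

import torch

torch.nn.GroupNorm(num_groups=1, num_channels=4)  # fine: the UNet/ControlNet setting
torch.nn.GroupNorm(num_groups=2, num_channels=8)  # fine: the VAE setting
try:
    torch.nn.GroupNorm(num_groups=32, num_channels=4)  # the default group count no longer fits
except ValueError as err:
    print(err)  # num_channels must be divisible by num_groups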
@@ -230,14 +233,15 @@ class StableDiffusionMultiControlNetPipelineFastTests(
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
-            block_out_channels=(32, 64),
+            block_out_channels=(4, 8),
             layers_per_block=2,
             sample_size=32,
             in_channels=4,
             out_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
             up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
             cross_attention_dim=32,
+            norm_num_groups=1,
         )
         torch.manual_seed(0)

@@ -247,23 +251,25 @@ def init_weights(m):
                 m.bias.data.fill_(1.0)
 
         controlnet1 = ControlNetModel(
-            block_out_channels=(32, 64),
+            block_out_channels=(4, 8),
             layers_per_block=2,
             in_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
             cross_attention_dim=32,
             conditioning_embedding_out_channels=(16, 32),
+            norm_num_groups=1,
         )
         controlnet1.controlnet_down_blocks.apply(init_weights)
 
         torch.manual_seed(0)
         controlnet2 = ControlNetModel(
-            block_out_channels=(32, 64),
+            block_out_channels=(4, 8),
             layers_per_block=2,
             in_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
             cross_attention_dim=32,
             conditioning_embedding_out_channels=(16, 32),
+            norm_num_groups=1,
         )
         controlnet2.controlnet_down_blocks.apply(init_weights)

@@ -277,12 +283,13 @@ def init_weights(m):
         )
         torch.manual_seed(0)
         vae = AutoencoderKL(
-            block_out_channels=[32, 64],
+            block_out_channels=[4, 8],
             in_channels=3,
             out_channels=3,
             down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
             up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
             latent_channels=4,
+            norm_num_groups=2,
         )
         torch.manual_seed(0)
         text_encoder_config = CLIPTextConfig(
@@ -415,14 +422,15 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests(
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
-            block_out_channels=(32, 64),
+            block_out_channels=(4, 8),
             layers_per_block=2,
             sample_size=32,
             in_channels=4,
             out_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
             up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
             cross_attention_dim=32,
+            norm_num_groups=1,
         )
         torch.manual_seed(0)

@@ -432,12 +440,13 @@ def init_weights(m):
                 m.bias.data.fill_(1.0)
 
         controlnet = ControlNetModel(
-            block_out_channels=(32, 64),
+            block_out_channels=(4, 8),
             layers_per_block=2,
             in_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
             cross_attention_dim=32,
             conditioning_embedding_out_channels=(16, 32),
+            norm_num_groups=1,
         )
         controlnet.controlnet_down_blocks.apply(init_weights)

@@ -451,12 +460,13 @@ def init_weights(m):
         )
         torch.manual_seed(0)
         vae = AutoencoderKL(
-            block_out_channels=[32, 64],
+            block_out_channels=[4, 8],
             in_channels=3,
             out_channels=3,
             down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
             up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
             latent_channels=4,
+            norm_num_groups=2,
         )
         torch.manual_seed(0)
         text_encoder_config = CLIPTextConfig(
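Since all three fast-test classes in the file now build the same shrunken dummy models, the whole fast suite should finish in a fraction of the previous time. One way to check locally (a sketch assuming pytest and the diffusers test dependencies are installed; not part of the commit):

import pytest

# -k "FastTests" matches the three fast-test classes in this file.
pytest.main(["tests/pipelines/controlnet/test_controlnet.py", "-k", "FastTests"])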
