From 73ce178021338e2cea419a00ae61ec0a6630ef19 Mon Sep 17 00:00:00 2001
From: Dango233
Date: Tue, 11 Jun 2024 18:30:25 +0800
Subject: [PATCH] Remove redundancy in mmdit.py (#3685)

---
 comfy/ldm/modules/diffusionmodules/mmdit.py | 61 ---------------------
 1 file changed, 61 deletions(-)

diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py
index 5e7afc8d31b..be40ab9403d 100644
--- a/comfy/ldm/modules/diffusionmodules/mmdit.py
+++ b/comfy/ldm/modules/diffusionmodules/mmdit.py
@@ -835,72 +835,11 @@ def __init__(
         )
 
         self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
-        # self.initialize_weights()
 
         if compile_core:
             assert False
             self.forward_core_with_concat = torch.compile(self.forward_core_with_concat)
 
-    def initialize_weights(self):
-        # TODO: Init context_embedder?
-        # Initialize transformer layers:
-        def _basic_init(module):
-            if isinstance(module, nn.Linear):
-                torch.nn.init.xavier_uniform_(module.weight)
-                if module.bias is not None:
-                    nn.init.constant_(module.bias, 0)
-
-        self.apply(_basic_init)
-
-        # Initialize (and freeze) pos_embed by sin-cos embedding
-        if self.pos_embed is not None:
-            pos_embed_grid_size = (
-                int(self.x_embedder.num_patches**0.5)
-                if self.pos_embed_max_size is None
-                else self.pos_embed_max_size
-            )
-            pos_embed = get_2d_sincos_pos_embed(
-                self.pos_embed.shape[-1],
-                int(self.x_embedder.num_patches**0.5),
-                pos_embed_grid_size,
-                scaling_factor=self.pos_embed_scaling_factor,
-                offset=self.pos_embed_offset,
-            )
-
-
-            pos_embed = get_2d_sincos_pos_embed(
-                self.pos_embed.shape[-1],
-                int(self.pos_embed.shape[-2]**0.5),
-                scaling_factor=self.pos_embed_scaling_factor,
-            )
-            self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
-
-        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d)
-        w = self.x_embedder.proj.weight.data
-        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
-        nn.init.constant_(self.x_embedder.proj.bias, 0)
-
-        if hasattr(self, "y_embedder"):
-            nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02)
-            nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02)
-
-        # Initialize timestep embedding MLP:
-        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
-        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
-        # Zero-out adaLN modulation layers in DiT blocks:
-        for block in self.joint_blocks:
-            nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0)
-            nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0)
-            nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0)
-            nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0)
-
-        # Zero-out output layers:
-        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
-        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
-        nn.init.constant_(self.final_layer.linear.weight, 0)
-        nn.init.constant_(self.final_layer.linear.bias, 0)
-
     def cropped_pos_embed(self, hw, device=None):
         p = self.x_embedder.patch_size[0]
         h, w = hw