
Commit

remove Unet#attn_processors_state_dict
williamberman committed Jun 30, 2023
1 parent 03f28e8 commit 651cc93
Showing 3 changed files with 19 additions and 50 deletions.
21 changes: 19 additions & 2 deletions examples/dreambooth/train_dreambooth_lora.py
@@ -23,6 +23,7 @@
 import shutil
 import warnings
 from pathlib import Path
+from typing import Dict
 
 import numpy as np
 import torch
@@ -650,6 +651,22 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
     return prompt_embeds
 
 
+def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
+    r"""
+    Returns:
+        a state dict containing just the attention processor parameters.
+    """
+    attn_processors = unet.attn_processors
+
+    attn_processors_state_dict = {}
+
+    for attn_processor_key, attn_processor in attn_processors.items():
+        for parameter_key, parameter in attn_processor.state_dict().items():
+            attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
+
+    return attn_processors_state_dict
+
+
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
 
@@ -869,7 +886,7 @@ def save_model_hook(models, weights, output_dir):
 
         for model in models:
             if isinstance(model, type(accelerator.unwrap_model(unet))):
-                unet_lora_layers_to_save = model.attn_processors_state_dict
+                unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
             elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
                 text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
             else:
@@ -1303,7 +1320,7 @@ def compute_text_embeddings(prompt):
     if accelerator.is_main_process:
         unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
-        unet_lora_layers = unet.attn_processors_state_dict
+        unet_lora_layers = unet_attn_processors_state_dict(unet)
 
         if text_encoder is not None and args.train_text_encoder:
             text_encoder = accelerator.unwrap_model(text_encoder)
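The new `unet_attn_processors_state_dict` helper flattens each attention processor's `state_dict()` into fully qualified parameter keys. Below is a minimal sketch (not part of the commit) of the key layout it produces; `ToyProcessor` and the hard-coded module path are hypothetical stand-ins for the entries of `unet.attn_processors`:

```python
# Illustrative only: a stand-in for one entry of unet.attn_processors.
import torch


class ToyProcessor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q_lora = torch.nn.Linear(4, 4, bias=False)


attn_processors = {"down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor": ToyProcessor()}

# Same flattening loop as the helper added above.
flat = {}
for attn_processor_key, attn_processor in attn_processors.items():
    for parameter_key, parameter in attn_processor.state_dict().items():
        flat[f"{attn_processor_key}.{parameter_key}"] = parameter

print(list(flat))
# ['down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.weight']
```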
24 changes: 0 additions & 24 deletions src/diffusers/models/unet_2d_condition.py
@@ -528,30 +528,6 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
 
         return processors
 
-    @property
-    def attn_processors_state_dict(self) -> Dict[str, torch.tensor]:
-        r"""
-        Returns:
-            a state dict containing just the attention processor parameters.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "set_processor"):
-                for processor_key, processor_parameter in module.processor.state_dict().items():
-                    processors[f"{name}.processor.{processor_key}"] = processor_parameter
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Parameters:
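With the property removed from the model, callers outside the training script can rebuild the same flat dictionary from the `attn_processors` property that remains; a hedged sketch, assuming `unet` is a loaded `UNet2DConditionModel` whose attention processors hold the LoRA parameters:

```python
from typing import Dict

import torch


def attn_processors_state_dict(unet) -> Dict[str, torch.Tensor]:
    """Rebuild the flat LoRA state dict from unet.attn_processors
    (the property that remains on the model after this commit)."""
    return {
        f"{processor_name}.{parameter_name}": parameter
        for processor_name, processor in unet.attn_processors.items()
        for parameter_name, parameter in processor.state_dict().items()
    }
```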
24 changes: 0 additions & 24 deletions src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -632,30 +632,6 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
 
         return processors
 
-    @property
-    def attn_processors_state_dict(self) -> Dict[str, torch.tensor]:
-        r"""
-        Returns:
-            a state dict containing just the attention processor parameters.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "set_processor"):
-                for processor_key, processor_parameter in module.processor.state_dict().items():
-                    processors[f"{name}.processor.{processor_key}"] = processor_parameter
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Parameters:
