Commit
Only quantize flux T5 if also quantizing model. Load TE from original name and path if fine tuning.
jaretburkett committed Oct 29, 2024
1 parent 4747716 commit 4aa19b5
Showing 2 changed files with 11 additions and 34 deletions.
2 changes: 2 additions & 0 deletions toolkit/config_modules.py
@@ -394,6 +394,8 @@ def __init__(self, **kwargs):
 class ModelConfig:
     def __init__(self, **kwargs):
         self.name_or_path: str = kwargs.get('name_or_path', None)
+        # name or path is updated on fine tuning. Keep a copy of the original
+        self.name_or_path_original: str = self.name_or_path
         self.is_v2: bool = kwargs.get('is_v2', False)
         self.is_xl: bool = kwargs.get('is_xl', False)
         self.is_pixart: bool = kwargs.get('is_pixart', False)
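The constructor now snapshots the configured path before any later code repoints name_or_path at a training checkpoint. A tiny sketch of the intended behavior (the local checkpoint path below is hypothetical, not from the commit):

    from toolkit.config_modules import ModelConfig

    config = ModelConfig(name_or_path="black-forest-labs/FLUX.1-schnell")

    # during fine tuning the loader may repoint name_or_path at the weights being trained
    config.name_or_path = "/output/flux_finetune"  # hypothetical local checkpoint

    # the copy taken in __init__ still identifies the original base repo
    assert config.name_or_path_original == "black-forest-labs/FLUX.1-schnell"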
43 changes: 9 additions & 34 deletions toolkit/stable_diffusion_model.py
@@ -527,7 +527,8 @@ def load_model(self):
 
         elif self.model_config.is_flux:
             print("Loading Flux model")
-            base_model_path = "black-forest-labs/FLUX.1-schnell"
+            # base_model_path = "black-forest-labs/FLUX.1-schnell"
+            base_model_path = self.model_config.name_or_path_original
             print("Loading transformer")
             subfolder = 'transformer'
             transformer_path = model_path
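The point of switching base_model_path to name_or_path_original is that, during a fine tune, name_or_path (and with it model_path) point at the transformer checkpoint being trained, while the tokenizers and text encoders still need to resolve against the original base repo. A rough sketch of that split, using the standard diffusers/transformers subfolder loading seen in this file (the fine-tuned checkpoint path is hypothetical):

    import torch
    from diffusers import FluxTransformer2DModel
    from transformers import CLIPTextModel, CLIPTokenizer

    dtype = torch.bfloat16
    base_model_path = "black-forest-labs/FLUX.1-schnell"   # name_or_path_original
    transformer_path = "/path/to/finetuned/transformer"    # hypothetical fine-tuned checkpoint

    # transformer weights come from the (possibly fine-tuned) checkpoint
    transformer = FluxTransformer2DModel.from_pretrained(transformer_path, torch_dtype=dtype)

    # text encoder and tokenizer still come from the original base repo
    text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
    tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer")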
@@ -688,11 +689,12 @@ def load_model(self):
             text_encoder_2.to(self.device_torch, dtype=dtype)
             flush()
 
-            print("Quantizing T5")
-            quantize(text_encoder_2, weights=qfloat8)
-            freeze(text_encoder_2)
-            flush()
+            if self.model_config.quantize:
+                print("Quantizing T5")
+                quantize(text_encoder_2, weights=qfloat8)
+                freeze(text_encoder_2)
+                flush()
 
             print("Loading clip")
             text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
             tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
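The quantize/freeze calls are optimum-quanto; the change simply gates them on the existing quantize flag so a full-precision fine tune keeps an unquantized T5. A minimal, self-contained sketch of the same pattern (the local flag stands in for self.model_config.quantize, and the small T5 checkpoint is only for illustration):

    import torch
    from transformers import T5EncoderModel
    from optimum.quanto import quantize, freeze, qfloat8

    quantize_te = True  # stands in for self.model_config.quantize
    text_encoder_2 = T5EncoderModel.from_pretrained("google/t5-v1_1-small", torch_dtype=torch.bfloat16)
    if quantize_te:
        quantize(text_encoder_2, weights=qfloat8)  # swap Linear weights for qfloat8 quantized versions
        freeze(text_encoder_2)                     # materialize the quantized weights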
@@ -2304,34 +2306,7 @@ def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False
                 named_params[name] = param
         if unet:
             if self.is_flux:
-                # Just train the middle 2 blocks of each transformer block
-                # block_list = []
-                # num_transformer_blocks = 2
-                # start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2)
-                # for i in range(num_transformer_blocks):
-                #     block_list.append(self.unet.transformer_blocks[start_block + i])
-                #
-                # num_single_transformer_blocks = 4
-                # start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2)
-                # for i in range(num_single_transformer_blocks):
-                #     block_list.append(self.unet.single_transformer_blocks[start_block + i])
-                #
-                # for block in block_list:
-                #     for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
-                #         named_params[name] = param
-
-                # train the guidance embedding
-                # if self.unet.config.guidance_embeds:
-                #     transformer: FluxTransformer2DModel = self.unet
-                #     for name, param in transformer.time_text_embed.named_parameters(recurse=True,
-                #                                                                     prefix=f"{SD_PREFIX_UNET}"):
-                #         named_params[name] = param
-
-                for name, param in self.unet.transformer_blocks.named_parameters(recurse=True,
-                                                                                 prefix="transformer.transformer_blocks"):
-                    named_params[name] = param
-                for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True,
-                                                                                        prefix="transformer.single_transformer_blocks"):
+                for name, param in self.unet.named_parameters(recurse=True, prefix="transformer"):
                     named_params[name] = param
             else:
                 for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
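The replacement loop widens what counts as trainable for Flux: self.unet.named_parameters(recurse=True, prefix="transformer") returns every parameter of the transformer (input and context embedders, time-text embedding, final norm, proj_out), not just the double and single transformer block stacks. One way to see the difference, assuming a diffusers FluxTransformer2DModel (this loads the full schnell transformer, so it downloads the weights):

    from diffusers import FluxTransformer2DModel

    unet = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", subfolder="transformer"
    )

    # before: only the two block stacks were collected
    old_keys = {n for n, _ in unet.transformer_blocks.named_parameters(prefix="transformer.transformer_blocks")}
    old_keys |= {n for n, _ in unet.single_transformer_blocks.named_parameters(prefix="transformer.single_transformer_blocks")}

    # after: the whole transformer is collected under the same top-level prefix
    new_keys = {n for n, _ in unet.named_parameters(recurse=True, prefix="transformer")}

    # the new set is a superset; the extras are the embedders, norms and proj_out
    print(sorted(new_keys - old_keys)[:5])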
