Skip to content

Commit

Permalink
Support train TE
Browse files Browse the repository at this point in the history
  • Loading branch information
KohakuBlueleaf committed Jun 22, 2024
1 parent 32f15c9 commit 6155efc
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
8 changes: 8 additions & 0 deletions hunyuan_train_network.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import warnings

warnings.filterwarnings("ignore")

import argparse

import torch
Expand Down Expand Up @@ -103,6 +107,8 @@ def get_text_cond(
):
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
logger.debug("input_ids1", input_ids1.shape)
logger.debug("input_ids2", input_ids2.shape)
with torch.enable_grad():
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
Expand All @@ -119,6 +125,8 @@ def get_text_cond(
accelerator=accelerator,
)
)
logger.debug("encoder_hidden_states1", encoder_hidden_states1.shape)
logger.debug("encoder_hidden_states2", encoder_hidden_states2.shape)
else:
raise NotImplementedError
return encoder_hidden_states1, mask1, encoder_hidden_states2, mask2
Expand Down
10 changes: 9 additions & 1 deletion train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,15 @@ def train(self, args):

# set top parameter requires_grad = True for gradient checkpointing works
if train_text_encoder:
t_enc.text_model.embeddings.requires_grad_(True)
if hasattr(t_enc, "text_model"):
t_enc.text_model.embeddings.requires_grad_(True)
elif hasattr(t_enc, "embeddings"):
# HunYuan Bert(CLIP)
t_enc.embeddings.requires_grad_(True)
elif hasattr(t_enc, "get_token_embedding"):
# Others (mT5 or other encoder, will have custom method to get the correct embedding)
t_enc.get_token_embedding().requires_grad_(True)


else:
unet.eval()
Expand Down

0 comments on commit 6155efc

Please sign in to comment.