Skip to content

Commit

Permalink
better impl to fit Kohya style
Browse files Browse the repository at this point in the history
  • Loading branch information
KohakuBlueleaf committed Jun 21, 2024
1 parent 48f7739 commit dc6c97e
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 27 deletions.
2 changes: 1 addition & 1 deletion library/hunyuan_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def get_hidden_states(self, input_ids, layer_index=-1):
outputs = self.model(
input_ids=input_ids, attention_mask=mask, output_hidden_states=True
)
return outputs["hidden_states"][layer_index]
return outputs["hidden_states"][layer_index], mask


def reshape_for_broadcast(
Expand Down
117 changes: 91 additions & 26 deletions library/hunyuan_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .hunyuan_models import MT5Embedder, HunYuanDiT, BertModel, DiT_g_2


def get_input_ids(caption, tokenizer, tokenizer_max_length=225):
def clip_get_input_ids(caption, tokenizer, tokenizer_max_length=225):
tokens = tokenizer(
caption,
padding="max_length",
Expand All @@ -24,13 +24,10 @@ def get_input_ids(caption, tokenizer, tokenizer_max_length=225):
return_tensors="pt",
)
input_ids = tokens["input_ids"]
masks = tokens["attention_mask"]

if tokenizer_max_length > tokenizer.model_max_length:
input_ids = input_ids.squeeze(0)
masks = masks.squeeze(0)
iids_list = []
mask_list = []
for i in range(
1,
tokenizer_max_length - tokenizer.model_max_length + 2,
Expand All @@ -42,12 +39,6 @@ def get_input_ids(caption, tokenizer, tokenizer_max_length=225):
input_ids[-1].unsqueeze(0),
) # PAD or EOS
ids_chunk = torch.cat(ids_chunk)
mask_chunk = (
masks[0].unsqueeze(0),
masks[i : i + tokenizer.model_max_length - 2],
masks[-1].unsqueeze(0),
)
mask_chunk = torch.cat(mask_chunk)

# 末尾が <EOS> <PAD> または <PAD> <PAD> の場合は、何もしなくてよい
# 末尾が x <PAD/EOS> の場合は末尾を <EOS> に変える(x <EOS> なら結果的に変化なし)
Expand All @@ -61,16 +52,13 @@ def get_input_ids(caption, tokenizer, tokenizer_max_length=225):
ids_chunk[1] = tokenizer.eos_token_id

iids_list.append(ids_chunk)
mask_list.append(mask_chunk)

input_ids = torch.stack(iids_list) # 3,77
masks = torch.stack(mask_list) # 3,77
return input_ids, masks
return input_ids


def get_hidden_states(
def clip_get_hidden_states(
input_ids,
masks,
tokenizer,
text_encoder: BertModel,
max_token_length=225,
Expand All @@ -83,7 +71,7 @@ def get_hidden_states(
# input_ids: b,n,77
b_size = input_ids.size(0)
input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
masks = masks.reshape((-1, tokenizer.model_max_length))
masks = (input_ids != tokenizer.pad_token_id).long()

encoder_hidden_states = text_encoder(input_ids, attention_mask=masks)[0]

Expand All @@ -92,6 +80,7 @@ def get_hidden_states(
(b_size, -1, encoder_hidden_states.shape[-1])
)
masks = masks.reshape((b_size, -1))
input_ids = input_ids.reshape((b_size, -1))

if max_token_length is not None:
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
Expand All @@ -100,7 +89,11 @@ def get_hidden_states(
states_list.append(
encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2]
) # <BOS> の後から <EOS> の前まで
mask_list.append(masks[:, i : i + tokenizer.model_max_length - 2])
ids = input_ids[:, i : i + tokenizer.model_max_length - 2]
mask_list.append(
masks[:, i : i + tokenizer.model_max_length - 2]
* (ids[:, :1] != tokenizer.eos_token_id)
)
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
mask_list.append(masks[:, -1].unsqueeze(1))

Expand All @@ -111,7 +104,60 @@ def get_hidden_states(
# this is required for additional network training
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)

return encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])), masks
return (
encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])),
masks,
)


def hunyuan_get_input_ids(
caption: str,
max_token_length: int,
tokenizer1: BertTokenizer,
tokenizer2,
):
input_ids1 = clip_get_input_ids(caption, tokenizer1, max_token_length)
input_ids2 = tokenizer2(
caption,
padding="max_length",
truncation=True,
max_length=256,
return_tensors="pt",
).input_ids
return input_ids1, input_ids2


def hunyuan_get_hidden_states(
max_token_length: int,
input_ids1: torch.Tensor,
input_ids2: torch.Tensor,
tokenizer1: BertTokenizer,
tokenizer2,
text_encoder1: BertModel,
text_encoder2: MT5Embedder,
weight_dtype: Optional[str] = None,
accelerator=None,
):
device = (
accelerator.device
if accelerator is not None
else next(text_encoder1.parameters()).device
)
input_ids1 = input_ids1.to(device)
input_ids2 = input_ids2.to(device)
clip_hidden_states, clip_mask = clip_get_hidden_states(
input_ids1.unsqueeze(0).to(device),
tokenizer1,
clip_encoder,
max_token_length=max_token_length + 2,
)
mt5_hidden_states, mt5_mask = text_encoder2.get_hidden_states(input_ids2)
return (
clip_hidden_states.to(weight_dtype),
clip_mask.long().to(device),
mt5_hidden_states.to(weight_dtype),
mt5_mask.long().to(device),
)


def get_cond(
Expand All @@ -123,16 +169,15 @@ def get_cond(
dtype=None,
device="cuda",
):
'''
"""
Get CLIP and mT5 embeddings for HunYuan DiT
Note that this function support "CLIP Concat" trick.
with max_length_clip = 152/227 or higher.
'''
"""
prompt = prompt.strip()
clip_input_ids, mask = get_input_ids(prompt, clip_tokenizer, max_length_clip)
clip_hidden_states, clip_mask = get_hidden_states(
clip_input_ids = clip_get_input_ids(prompt, clip_tokenizer, max_length_clip)
clip_hidden_states, clip_mask = clip_get_hidden_states(
clip_input_ids.unsqueeze(0).to(device),
mask.to(device),
clip_tokenizer,
clip_encoder,
max_token_length=max_length_clip,
Expand Down Expand Up @@ -451,14 +496,17 @@ def calc_rope(height, width, patch_size=2, head_size=64):
clip_tokenizer = AutoTokenizer.from_pretrained("./model/clip")
clip_tokenizer.eos_token_id = 2
clip_encoder = BertModel.from_pretrained("./model/clip").half().cuda()
print(clip_tokenizer.eos_token_id, clip_tokenizer.eos_token)

mt5_embedder = MT5Embedder(
"./model/mt5", torch_dtype=torch.float16, max_length=256
).cuda()
mt5_embedder.device = "cuda"

clip_h, clip_m, mt5_h, mt5_m = get_cond("""anime style, illustration, masterpiece,
print(clip_tokenizer.pad_token_id, mt5_embedder.tokenizer.pad_token_id)
print(clip_tokenizer.eos_token_id, mt5_embedder.tokenizer.eos_token_id)

clip_h, clip_m, mt5_h, mt5_m = get_cond(
"""anime style, illustration, masterpiece,
1girl,
ciloranko, maccha (mochancc), lobelia (saclia), welchino, yanyo (ogino atsuki),
Expand All @@ -473,7 +521,24 @@ def calc_rope(height, width, patch_size=2, head_size=64):
mt5_embedder,
clip_tokenizer,
clip_encoder,
75*3+2
75 * 3 + 2,
)
print(clip_h.dtype, clip_m.dtype, mt5_h.dtype, mt5_m.dtype)
print(clip_h.shape, clip_m.shape, mt5_h.shape, mt5_m.shape)
print(mt5_m)
print(clip_m)

input_ids1, input_ids2 = hunyuan_get_input_ids(
"Hello HunYuan DiT", 77, clip_tokenizer, mt5_embedder.tokenizer
)
clip_h, clip_m, mt5_h, mt5_m = hunyuan_get_hidden_states(
77,
input_ids1,
input_ids2,
clip_tokenizer,
mt5_embedder.tokenizer,
clip_encoder,
mt5_embedder,
)
print(clip_h.dtype, clip_m.dtype, mt5_h.dtype, mt5_m.dtype)
print(clip_h.shape, clip_m.shape, mt5_h.shape, mt5_m.shape)
Expand Down

0 comments on commit dc6c97e

Please sign in to comment.