initial support of lycoris/lora + hunyuan dit
KohakuBlueleaf committed Jun 22, 2024
1 parent 982cf79 commit fb3e8a7
Showing 7 changed files with 167 additions and 38 deletions.
26 changes: 17 additions & 9 deletions hunyuan_test.py
@@ -15,9 +15,9 @@
Very beautiful Steampunk lady, long silver hair, steampunk outfit and weapon, hyperrealism, photorealistic, 8k, unreal engine
"""
# "wrong eyes, bad face, disfigured, bad art, deformed, extra limbs,
# smeared colors, blurry, duplicated, morbid, mutilated"
NEG_PROMPT = "错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺"
CLIP_TOKENS = 75*1 + 2
CLIP_TOKENS = 75 * 1 + 2
ATTN_MODE = "xformers"
STEPS = 16
STEPS = 50
CFG_SCALE = 7
DEVICE = "cuda"
DTYPE = torch.float16
@@ -27,10 +27,18 @@
seed_everything(0)
with torch.inference_mode(True), torch.no_grad():
alphas, sigmas = load_scheduler_sigmas()
denoiser, patch_size, head_dim, clip_tokenizer, clip_encoder, mt5_embedder, vae = (
load_model("./model", dtype=DTYPE, device=DEVICE)
)
(
denoiser,
patch_size,
head_dim,
clip_tokenizer,
clip_encoder,
mt5_embedder,
vae,
) = load_model("./model", dtype=DTYPE, device=DEVICE)
denoiser.eval()
denoiser.disable_fp32_silu()
denoiser.disable_fp32_layer_norm()
denoiser.set_attn_mode(ATTN_MODE)
vae.requires_grad_(False)

@@ -42,25 +50,25 @@
clip_encoder,
# Should be same as original implementation with max_length_clip=77
# Support 75*n + 2
max_length_clip=CLIP_TOKENS
max_length_clip=CLIP_TOKENS,
)
neg_clip_h, neg_clip_m, neg_mt5_h, neg_mt5_m = get_cond(
NEG_PROMPT,
mt5_embedder,
clip_tokenizer,
clip_encoder,
max_length_clip=CLIP_TOKENS
max_length_clip=CLIP_TOKENS,
)
clip_h = torch.concat([clip_h, neg_clip_h], dim=0)
clip_m = torch.concat([clip_m, neg_clip_m], dim=0)
mt5_h = torch.concat([mt5_h, neg_mt5_h], dim=0)
mt5_m = torch.concat([mt5_m, neg_mt5_m], dim=0)
torch.cuda.empty_cache()

style = torch.as_tensor([0]*2, device=DEVICE)
style = torch.as_tensor([0] * 2, device=DEVICE)
# src hw, dst hw, 0, 0
size_cond = [1024, 1024, 1024, 1024, 0, 0]
image_meta_size = torch.as_tensor([size_cond]*2, device=DEVICE)
image_meta_size = torch.as_tensor([size_cond] * 2, device=DEVICE)
freqs_cis_img = calc_rope(1024, 1024, patch_size, head_dim)

denoiser_wrapper = DiscreteVDDPMDenoiser(
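
For context beyond the diff: hunyuan_test.py batches the conditional and unconditional text states along dim 0 for classifier-free guidance, which is why style and image_meta_size are built with * 2. A minimal sketch of that layout (hidden sizes are my assumption, not read from the model config):

import torch

cond_h = torch.randn(1, 77, 1024)    # CLIP hidden states for PROMPT (hypothetical dims)
uncond_h = torch.randn(1, 77, 1024)  # CLIP hidden states for NEG_PROMPT
clip_h = torch.concat([cond_h, uncond_h], dim=0)    # (2, 77, 1024): cond first, uncond second

style = torch.as_tensor([0] * 2)                    # one style id per batch entry
size_cond = [1024, 1024, 1024, 1024, 0, 0]          # src hw, dst hw, 0, 0
image_meta_size = torch.as_tensor([size_cond] * 2)  # (2, 6), duplicated to match the batch
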
54 changes: 40 additions & 14 deletions hunyuan_train_network.py
@@ -29,7 +29,7 @@ def __init__(self):

def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args)
# sdxl_train_util.verify_sdxl_training_args(args)

if args.cache_text_encoder_outputs:
assert (
@@ -70,7 +70,7 @@ def load_target_model(self, args, weight_dtype, accelerator):
)

def load_tokenizer(self, args):
tokenizer = hunyuan_utils.load_tokenizers(args)
tokenizer = hunyuan_utils.load_tokenizers()
return tokenizer

def is_text_encoder_outputs_cached(self, args):
@@ -103,6 +103,8 @@ def get_text_cond(
):
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
print("input_ids1", input_ids1.shape)
print("input_ids2", input_ids2.shape)
with torch.enable_grad():
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
@@ -119,6 +121,8 @@
accelerator=accelerator,
)
)
print("encoder_hidden_states1", encoder_hidden_states1.shape)
print("encoder_hidden_states2", encoder_hidden_states2.shape)
else:
raise NotImplementedError
return encoder_hidden_states1, mask1, encoder_hidden_states2, mask2
@@ -139,17 +143,22 @@ def call_unet(
) # TODO check why noisy_latents is not weight_dtype

# get size embeddings
orig_size = batch["original_sizes_hw"]
crop_size = batch["crop_top_lefts"]
target_size = batch["target_sizes_hw"]
orig_size = batch["original_sizes_hw"] # B, 2
crop_size = batch["crop_top_lefts"] # B, 2
target_size = batch["target_sizes_hw"] # B, 2
B, C, H, W = noisy_latents.shape

# TODO implement correct meta_size info

style = torch.as_tensor([0] * B, device=accelerator.device)
# src hw, dst hw, 0, 0
size_cond = [1024, 1024, 1024, 1024, 0, 0]
image_meta_size = torch.as_tensor([size_cond], device=accelerator.device)
freqs_cis_img = hunyuan_utils.calc_rope(H*8, W*8, 2, 88)
image_meta_size = torch.concat(
[
orig_size,
target_size,
# Not following SDXL but HunYuan's implementation (src hw, dst hw, 0, 0)
# TODO examine if this is correct
torch.zeros_like(target_size),
],
# concat per sample along dim=1: three (B, 2) tensors -> (B, 6)
dim=1,
)
freqs_cis_img = hunyuan_utils.calc_rope(H * 8, W * 8, 2, 88)

# concat embeddings
encoder_hidden_states1, mask1, encoder_hidden_states2, mask2 = text_conds
@@ -160,12 +169,13 @@
text_embedding_mask=mask1,
encoder_hidden_states_t5=encoder_hidden_states2,
text_embedding_mask_t5=mask2,
image_meta_size=None,
image_meta_size=image_meta_size,
style=style,
cos_cis_img=freqs_cis_img[0],
sin_cis_img=freqs_cis_img[1],
)
return noise_pred
# TODO Handle learned sigma correctly
return noise_pred.chunk(2, dim=1)[0]

def sample_images(
self,
@@ -179,7 +189,23 @@
text_encoder,
unet,
):
raise NotImplementedError
steps = global_step
if steps == 0:
if not args.sample_at_first:
return
else:
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps is ignored
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if (
steps % args.sample_every_n_steps != 0 or epoch is not None
): # steps is not divisible or end of epoch
return
logger.warning("Sampling images not supported yet.")


def setup_parser() -> argparse.ArgumentParser:
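
A quick shape check of the meta-size conditioning built in call_unet above; a sketch based on the shapes stated in the inline comments ((B, 2) per size tensor), showing why the concat must run along dim=1:

import torch

B = 4
orig_size = torch.randint(256, 2048, (B, 2))    # original (H, W) per sample
target_size = torch.randint(256, 2048, (B, 2))  # target (H, W) per sample
image_meta_size = torch.concat(
    [orig_size, target_size, torch.zeros_like(target_size)], dim=1
)
assert image_meta_size.shape == (B, 6)  # src hw, dst hw, 0, 0 per sample
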
79 changes: 73 additions & 6 deletions library/hunyuan_models.py
@@ -44,15 +44,13 @@ def __init__(
max_length=128,
):
super().__init__()
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.torch_dtype = torch_dtype or torch.bfloat16
self.max_length = max_length
if model_kwargs is None:
model_kwargs = {
# "low_cpu_mem_usage": True,
"torch_dtype": self.torch_dtype,
}
model_kwargs["device_map"] = {"shared": self.device, "encoder": self.device}
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
if use_tokenizer_only:
return
@@ -61,6 +59,53 @@
.eval()
.to(self.torch_dtype)
)
self.register_buffer("device", torch.tensor(0.0), persistent=False)

def get_token_embedding(self):
return self.model.shared

def gradient_checkpointing_enable(self):
for block in self.model.encoder.block:
block.org_forward = block.forward

def mt5_block_forward(
hidden_states,
attention_mask=None,
position_bias=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
layer_head_mask=None,
cross_attn_layer_head_mask=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
return_dict=True,
# bind eagerly: closures in a for loop are late-bound, so referring
# to `block` directly would route every wrapper through the last block
org_forward=block.org_forward,
):
return checkpoint.checkpoint(
org_forward,
hidden_states,
attention_mask,
position_bias,
encoder_hidden_states,
encoder_attention_mask,
encoder_decoder_position_bias,
layer_head_mask,
cross_attn_layer_head_mask,
past_key_value,
use_cache,
output_attentions,
return_dict,
use_reentrant=False,
)

block.forward = mt5_block_forward

def gradient_checkpointing_disable(self):
for block in self.model.encoder.block:
if hasattr(block, "org_forward"):
block.forward = block.org_forward
delattr(block, "org_forward")

def get_tokens_and_mask(self, texts):
text_tokens_and_mask = self.tokenizer(
@@ -110,6 +155,8 @@ def get_input_ids(self, caption):
).input_ids

def get_hidden_states(self, input_ids, layer_index=-1):
if input_ids.dim() == 3:
input_ids = input_ids.view(input_ids.size(0), -1)
mask = (input_ids != 0).long()
outputs = self.model(
input_ids=input_ids, attention_mask=mask, output_hidden_states=True
@@ -660,6 +707,7 @@ def modulate(x, shift, scale):

class FP32_Layernorm(nn.LayerNorm):
enable_fp32 = True

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
if self.enable_fp32:
return F.layer_norm(
@@ -677,9 +725,12 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:

class FP32_SiLU(nn.SiLU):
enable_fp32 = True

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
if self.enable_fp32:
return torch.nn.functional.silu(inputs.float(), inplace=False).to(inputs.dtype)
return torch.nn.functional.silu(inputs.float(), inplace=False).to(
inputs.dtype
)
return torch.nn.functional.silu(inputs, inplace=False).to(inputs.dtype)


@@ -793,11 +844,19 @@ def _forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):

return x

def forward(self, *args, **kwargs):
def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
if self.gradient_checkpointing and self.training:
return checkpoint.checkpoint(self._forward, *args, **kwargs)
return checkpoint.checkpoint(
self._forward,
x,
c,
text_states,
freq_cis_img,
skip,
use_reentrant=False,
)
else:
return self._forward(*args, **kwargs)
return self._forward(x, c, text_states, freq_cis_img, skip)


class FinalLayer(nn.Module):
@@ -967,6 +1026,14 @@ def set_attn_mode(self, attn_mode):
for block in self.blocks:
block.set_attn_mode(attn_mode)

def set_use_memory_efficient_attention(self, xformers, mem_eff):
if xformers:
self.set_attn_mode("xformers")
elif mem_eff:
self.set_attn_mode("torch")
else:
self.set_attn_mode("vanilla")

def enable_fp32_layer_norm(self):
FP32_Layernorm.enable_fp32 = True

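
A standalone sketch (mine, not part of the commit) of the non-persistent device buffer trick used in MT5Embedder above: the buffer follows the module through .to()/.cuda(), and Tensor.to(Tensor) adopts that tensor's device. It also adopts its dtype, so integer ids should target the buffer's .device attribute instead:

import torch
import torch.nn as nn

class DeviceTracked(nn.Module):
    def __init__(self):
        super().__init__()
        # Not saved in state_dict; exists only so self.device tracks the module.
        self.register_buffer("device", torch.tensor(0.0), persistent=False)

m = DeviceTracked()  # after m.to("cuda"), m.device lives on the GPU
x = torch.randn(3).to(m.device)                 # adopts device and float32 dtype
ids = torch.tensor([1, 2]).to(m.device.device)  # use the torch.device to keep int64
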
11 changes: 6 additions & 5 deletions library/hunyuan_utils.py
@@ -157,10 +157,12 @@ def hunyuan_get_hidden_states(
)
input_ids1 = input_ids1.to(device)
input_ids2 = input_ids2.to(device)
if input_ids1.dim() == 2:
input_ids1 = input_ids1.unsqueeze(0)
clip_hidden_states, clip_mask = clip_get_hidden_states(
input_ids1.unsqueeze(0).to(device),
input_ids1.to(device),
tokenizer1,
clip_encoder,
text_encoder1,
max_token_length=max_token_length + 2,
)
mt5_hidden_states, mt5_mask = text_encoder2.get_hidden_states(input_ids2)
@@ -211,11 +213,11 @@ def load_tokenizers():
subfolder=TOKENIZER1_PATH,
)
tokenizer.eos_token_id = tokenizer.sep_token_id
tokenizer2 = T5Tokenizer(
tokenizer2 = T5Tokenizer.from_pretrained(
BASE_PATH,
subfolder=TOKENIZER2_PATH,
)
return tokenizer, tokenizer2
return [tokenizer, tokenizer2]


def load_scheduler_sigmas():
Expand Down Expand Up @@ -244,7 +246,6 @@ def load_model(model_path: str, dtype=torch.float16, device="cuda"):
.to(device)
.to(dtype)
)
mt5_embedder.device = device

vae = (
AutoencoderKL.from_pretrained(os.path.join(model_path, "vae"))
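
A hedged usage sketch of the changed loader: load_tokenizers now takes no arguments and returns a list, and T5Tokenizer.from_pretrained replaces the bare constructor (which expects a local vocab file rather than a repo/subfolder):

from library import hunyuan_utils

tokenizers = hunyuan_utils.load_tokenizers()
tokenizer1, tokenizer2 = tokenizers  # CLIP-style tokenizer, T5Tokenizer
for t in tokenizers:                 # the list return lets generic code iterate the pair
    print(type(t).__name__)
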
17 changes: 16 additions & 1 deletion library/train_util.py
@@ -40,7 +40,7 @@
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torchvision import transforms
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection, T5Tokenizer
import transformers
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers import (
@@ -68,6 +68,7 @@
import safetensors.torch
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
import library.model_util as model_util
import library.hunyuan_utils as hunyuan_utils
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
import library.deepspeed_utils as deepspeed_utils
@@ -812,6 +813,20 @@ def get_input_ids(self, caption, tokenizer=None):
if tokenizer is None:
tokenizer = self.tokenizers[0]

# HunYuan DiT
if not isinstance(tokenizer, CLIPTokenizer):
if isinstance(tokenizer, T5Tokenizer):
result = tokenizer(
caption,
padding="max_length",
truncation=True,
max_length=256,
return_tensors="pt",
).input_ids
else:
result = hunyuan_utils.clip_get_input_ids(caption, tokenizer, self.tokenizer_max_length)
return result

input_ids = tokenizer(
caption, padding="max_length", truncation=True, max_length=self.tokenizer_max_length, return_tensors="pt"
).input_ids
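
For reference, a sketch of what the new T5 branch in get_input_ids produces, using a hypothetical stand-in checkpoint (the real tokenizer lives in the HunYuan model repo):

from transformers import T5Tokenizer

tok = T5Tokenizer.from_pretrained("google/mt5-base")  # hypothetical stand-in
ids = tok(
    "a photo of a cat",
    padding="max_length",
    truncation=True,
    max_length=256,
    return_tensors="pt",
).input_ids
assert ids.shape == (1, 256)  # flat, zero-padded mT5 ids
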
4 changes: 4 additions & 0 deletions requirements.txt
@@ -32,6 +32,10 @@ imagesize==1.4.1
# for cuda 12.1(default 11.8)
# onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/

# for HunYuanDiT
sentencepiece==0.2.0
timm==1.0.7

# this is for onnx:
# protobuf==3.20.3
# open clip for SDXL
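
A quick sanity check for the two new pins (my sketch): sentencepiece backs the mT5 tokenizer, and timm, as I understand it, supplies model building blocks used by the DiT implementation.

import sentencepiece
import timm

print(sentencepiece.__version__)  # expect 0.2.0
print(timm.__version__)           # expect 1.0.7
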