diff --git a/hunyuan_test_ft.py b/hunyuan_test_ft.py
new file mode 100644
index 000000000..d79a725e8
--- /dev/null
+++ b/hunyuan_test_ft.py
@@ -0,0 +1,143 @@
+import numpy as np
+import torch
+from pathlib import Path
+
+from k_diffusion.external import DiscreteVDDPMDenoiser
+from k_diffusion.sampling import sample_euler_ancestral, get_sigmas_exponential
+
+from PIL import Image
+from pytorch_lightning import seed_everything
+
+from library.hunyuan_models import *
+from library.hunyuan_utils import *
+
+
+PROMPT = """
+qinglongshengzhe, 1girl, solo, breasts, looking at viewer, smile, open mouth, bangs, hair between eyes, bare shoulders, collarbone, upper body, detached sleeves, midriff, crop top, black background
+"""
+NEG_PROMPT = "错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺"
+CLIP_TOKENS = 75 * 3 + 2
+ATTN_MODE = "xformers"
+H = 1024
+W = 1024
+STEPS = 30
+CFG_SCALE = 5
+DEVICE = "cuda"
+DTYPE = torch.float16
+USE_EXTRA_COND = False
+BETA_END = 0.02
+
+
+def load_scheduler_sigmas(beta_start=0.00085, beta_end=0.018, num_train_timesteps=1000):
+    betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+
+    sigmas = np.array(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5)
+    sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
+    sigmas = torch.from_numpy(sigmas)
+    return alphas_cumprod, sigmas
+
+
+if __name__ == "__main__":
+    seed_everything(0)
+    with torch.inference_mode(True), torch.no_grad():
+        alphas, sigmas = load_scheduler_sigmas(beta_end=BETA_END)
+        (
+            denoiser,
+            patch_size,
+            head_dim,
+            clip_tokenizer,
+            clip_encoder,
+            mt5_embedder,
+            vae,
+        ) = load_model("/root/albertxyu/HunYuanDiT-V1.2-fp16-pruned", dtype=DTYPE, device=DEVICE,
+            # dit_path="./output_pro/debug-000006.ckpt",
+        )
+
+        denoiser.eval()
+        denoiser.disable_fp32_silu()
+        denoiser.disable_fp32_layer_norm()
+        denoiser.set_attn_mode(ATTN_MODE)
+        vae.requires_grad_(False)
+        mt5_embedder.to(torch.float16)
+
+
+        with torch.autocast("cuda"):
+            clip_h, clip_m, mt5_h, mt5_m = get_cond(
+                PROMPT,
+                mt5_embedder,
+                clip_tokenizer,
+                clip_encoder,
+                # Should be same as original implementation with max_length_clip=77
+                # Support 75*n + 2
+                max_length_clip=CLIP_TOKENS,
+            )
+            neg_clip_h, neg_clip_m, neg_mt5_h, neg_mt5_m = get_cond(
+                NEG_PROMPT,
+                mt5_embedder,
+                clip_tokenizer,
+                clip_encoder,
+                max_length_clip=CLIP_TOKENS,
+            )
+            clip_h = torch.concat([clip_h, neg_clip_h], dim=0)
+            clip_m = torch.concat([clip_m, neg_clip_m], dim=0)
+            mt5_h = torch.concat([mt5_h, neg_mt5_h], dim=0)
+            mt5_m = torch.concat([mt5_m, neg_mt5_m], dim=0)
+            torch.cuda.empty_cache()
+
+        if USE_EXTRA_COND:
+            style = torch.as_tensor([0] * 2, device=DEVICE)
+            # src hw, dst hw, 0, 0
+            size_cond = [H, W, H, W, 0, 0]
+            image_meta_size = torch.as_tensor([size_cond] * 2, device=DEVICE)
+        else:
+            style = None
+            image_meta_size = None
+        freqs_cis_img = calc_rope(H, W, patch_size, head_dim)
+
+        denoiser_wrapper = DiscreteVDDPMDenoiser(
+            # A quick patch for learn_sigma
+            lambda *args, **kwargs: denoiser(*args, **kwargs).chunk(2, dim=1)[0],
+            alphas,
+            False,
+        ).to(DEVICE)
+
+        def cfg_denoise_func(x, sigma):
+            cond, uncond = denoiser_wrapper(
+                x.repeat(2, 1, 1, 1),
+                sigma.repeat(2),
+                encoder_hidden_states=clip_h,
+                text_embedding_mask=clip_m,
+                encoder_hidden_states_t5=mt5_h,
+                text_embedding_mask_t5=mt5_m,
+                image_meta_size=image_meta_size,
+                style=style,
+                cos_cis_img=freqs_cis_img[0],
+                sin_cis_img=freqs_cis_img[1],
+            ).chunk(2,
dim=0) + return uncond + (cond - uncond) * CFG_SCALE + + sigmas = denoiser_wrapper.get_sigmas(STEPS).to(DEVICE) + sigmas = get_sigmas_exponential( + STEPS, denoiser_wrapper.sigma_min, denoiser_wrapper.sigma_max, DEVICE + ) + x1 = torch.randn(1, 4, H//8, W//8, dtype=torch.float16, device=DEVICE) + + Path('imgs').mkdir(exist_ok=True, parents=True) + with torch.autocast("cuda"): + sample = sample_euler_ancestral( + cfg_denoise_func, + x1 * sigmas[0], + sigmas, + ) + torch.cuda.empty_cache() + with torch.no_grad(): + latent = sample / 0.13025 + image = vae.decode(latent).sample + image = (image / 2 + 0.5).clamp(0, 1) + image = image.permute(0, 2, 3, 1).cpu().numpy() + image = (image * 255).round().astype(np.uint8) + image = [Image.fromarray(im) for im in image] + for im in image: + im.save("imgs/test_opro.png") diff --git a/hunyuan_train.py b/hunyuan_train.py new file mode 100644 index 000000000..2afaf0058 --- /dev/null +++ b/hunyuan_train.py @@ -0,0 +1,889 @@ +# ================================================ +# HunyuanDiT training scripts (with captions) + +import argparse +import math +import os +from multiprocessing import Value +from typing import List +import toml + +from tqdm import tqdm + +import torch +from library.device_utils import init_ipex, clean_memory_on_device + + +init_ipex() + +from accelerate.utils import set_seed +from diffusers import DDPMScheduler +from library import deepspeed_utils, sdxl_model_util + +import library.train_util as train_util + +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.config_util as config_util +import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + apply_snr_weight, + prepare_scheduler_for_custom_training, + scale_v_prediction_loss_like_noise_prediction, + add_v_prediction_like_loss, + apply_debiased_estimation, + apply_masked_loss, +) +import library.hunyuan_utils as hunyuan_utils + +UNET_NUM_BLOCKS_FOR_BLOCK_LR = 23 + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + assert ( + not args.weighted_captions + ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + assert ( + not args.train_text_encoder or not args.cache_text_encoder_outputs + ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + + if args.block_lr: + block_lrs = [float(lr) for lr in args.block_lr.split(",")] + assert ( + len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR + ), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください" + else: + block_lrs = None + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + tokenizer1, tokenizer2 = hunyuan_utils.load_tokenizers() + + # Prepare datasets + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = 
config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2]) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(32) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # Prepare accelerator + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # Prepare types that supports mixed precision and casts as needed. 
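+    # Note: vae_dtype below falls back to float32 when --no_half_vae is given; otherwise the VAE
+    # follows the mixed-precision weight dtype returned by prepare_dtype().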
+ weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # Load models + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + hydit, + logit_scale, + ckpt_info, + ) = hunyuan_utils.load_target_model(args, accelerator, "hydit", weight_dtype, args.use_extra_cond) + if args.use_extra_cond: + hydit_version = 'v1.1' + else: + hydit_version = 'v1.2' + + # verify load/save model formats + if load_stable_diffusion_format: + src_stable_diffusion_ckpt = args.pretrained_model_name_or_path + src_diffusers_model_path = None + else: + src_stable_diffusion_ckpt = None + src_diffusers_model_path = args.pretrained_model_name_or_path + + if args.save_model_as is None: + save_stable_diffusion_format = load_stable_diffusion_format + use_safetensors = args.use_safetensors + else: + save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors" + use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) + # assert save_stable_diffusion_format, "save_model_as must be ckpt or safetensors / save_model_asはckptかsafetensorsである必要があります" + + # Setting the flag for using Diffusers version of xformers function + def set_diffusers_xformers_flag(model, valid): + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid) + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + fn_recursive_set_mem_eff(model) + + # Integrate xformers and memory efficient attention into the model + if args.diffusers_xformers: + # もうU-Netを独自にしたので動かないけどVAEのxformersは動くはず + accelerator.print("Use xformers by Diffusers") + # set_diffusers_xformers_flag(hydit, True) + set_diffusers_xformers_flag(vae, True) + else: + # The Windows version of xformers may not be able to train with float, so there is a need to enable settings that don't use xformers. + accelerator.print("Disable Diffusers' xformers") + train_util.replace_unet_modules(hydit, args.mem_eff_attn, args.xformers, args.sdpa) + if torch.__version__ >= "2.0.0": # The following can be used with xformers compatible with PyTorch 2.0.0 and above. 
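+            # toggle memory-efficient attention on the VAE according to --xformers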
+ vae.set_use_memory_efficient_attention_xformers(args.xformers) + + # Prepare vae latents + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # Prepare for learning: Get the model into a proper state + if args.gradient_checkpointing: + hydit.enable_gradient_checkpointing() + train_hydit = args.learning_rate != 0 + train_text_encoder1 = False + train_text_encoder2 = False + + if args.train_text_encoder: + raise NotImplementedError("Training text encoder is not supported yet for HunyuanDiT") + else: + text_encoder1.to(weight_dtype) + text_encoder2.to(weight_dtype) + text_encoder1.requires_grad_(False) + text_encoder2.requires_grad_(False) + text_encoder1.eval() + text_encoder2.eval() + + # Cache the output of Textencoder + if args.cache_text_encoder_outputs: + raise NotImplementedError("Caching text encoder outputs in HunyuanDiT is not supported yet") + # TODO: We just copy the code from sdxl_train.py, need to rewrite `cache_text_encoder_outputs` + # for supporting SDXL and HunyuanDiT at the same time. + # Text Encodes are eval and no grad + with torch.no_grad(), accelerator.autocast(): + train_dataset_group.cache_text_encoder_outputs( + (tokenizer1, tokenizer2), + (text_encoder1, text_encoder2), + accelerator.device, + None, + args.cache_text_encoder_outputs_to_disk, + accelerator.is_main_process, + ) + accelerator.wait_for_everyone() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + hydit.requires_grad_(train_hydit) + if not train_hydit: + hydit.to(accelerator.device, dtype=weight_dtype) # because of hydit is not prepared + + training_models = [] + params_to_optimize = [] + if train_hydit: + training_models.append(hydit) + if block_lrs is None: + params_to_optimize.append({"params": list(hydit.parameters()), "lr": args.learning_rate}) + else: + raise NotImplementedError("block_lr is not supported yet for HunyuanDiT") + + if train_text_encoder1: + training_models.append(text_encoder1) + params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) + if train_text_encoder2: + training_models.append(text_encoder2) + params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"train hydit: {train_hydit}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}") + accelerator.print(f"number of models: {len(training_models)}") + accelerator.print(f"number of trainable parameters: {n_params}") + + # Prepare the tools necessary for training + accelerator.print("prepare optimizer, data loader etc.") + + if args.fused_optimizer_groups: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # This balances memory usage and management complexity. 
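+        # Parameters are split into roughly equal-sized groups (params_per_group below); one optimizer
+        # is created per group so each group can be stepped, and its gradients freed, independently.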
+ + # calculate total number of parameters + n_total_params = sum(len(params["params"]) for params in params_to_optimize) + params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) + + # split params into groups, keeping the learning rate the same for all params in a group + # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) + grouped_params = [] + param_group = [] + param_group_lr = -1 + for group in params_to_optimize: + lr = group["lr"] + for p in group["params"]: + # if the learning rate is different for different params, start a new group + if lr != param_group_lr: + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = lr + + param_group.append(p) + + # if the group has enough parameters, start a new group + if len(param_group) == params_per_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = -1 + + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + + # Prepare the DataLoader + # Note that the number of DataLoader processes: 0 cannot use persistent_workers + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # Calculate the number of training steps + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # Send the training steps to the dataset side + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # Prepare a learning rate scheduler + if args.fused_optimizer_groups: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # Experimental Feature: Conducting fp16/bf16 learning, including gradients, converting the entire model to fp16/bf16. 
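+    # --full_fp16 / --full_bf16 cast the model weights themselves (not just autocast activations),
+    # so the matching --mixed_precision value is required; this is asserted below.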
+ if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + hydit.to(weight_dtype) + text_encoder1.to(weight_dtype) + text_encoder2.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + hydit.to(weight_dtype) + text_encoder1.to(weight_dtype) + text_encoder2.to(weight_dtype) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model( + args, + hydit=hydit if train_hydit else None, + text_encoder1=text_encoder1 if train_text_encoder1 else None, + text_encoder2=text_encoder2 if train_text_encoder2 else None, + ) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # acceleratorがなんかよろしくやってくれるらしい + if train_hydit: + hydit = accelerator.prepare(hydit) + if train_text_encoder1: + text_encoder1 = accelerator.prepare(text_encoder1) + if train_text_encoder2: + text_encoder2 = accelerator.prepare(text_encoder2) + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # TextEncoderの出力をキャッシュするときにはCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 + text_encoder1.to("cpu", dtype=torch.float32) + text_encoder2.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + else: + # make sure Text Encoders are on GPU + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. 
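+        # i.e. patch the accelerator so that grad scaling also works with fp16 gradients
+        # (see the note above: "patch PyTorch to enable grad scale in fp16").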
+ train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + + elif args.fused_optimizer_groups: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def optimizer_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(optimizer_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + train_noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=args.beta_end, beta_schedule="scaled_linear", num_train_timesteps=1000, 
clip_sample=False, + steps_offset=1, + ) + prepare_scheduler_for_custom_training(train_noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(train_noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + # For --sample_at_first + sdxl_train_util.sample_images( + accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], hydit + ) + + loss_recorder = train_util.LossRecorder() + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + + if args.fused_optimizer_groups: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + with torch.no_grad(): + # latentに変換 + latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + + if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: + input_ids1 = batch["input_ids"] + input_ids2 = batch["input_ids2"] + with torch.set_grad_enabled(args.train_text_encoder): + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, mask1, encoder_hidden_states2, mask2 = ( + hunyuan_utils.hunyuan_get_hidden_states( + args.max_token_length, + input_ids1, + input_ids2, + tokenizer1, + tokenizer2, + text_encoder1, + text_encoder2, + None if not args.full_fp16 else weight_dtype, + accelerator=accelerator, + ) + ) + logger.debug("encoder_hidden_states1", encoder_hidden_states1.shape) + logger.debug("encoder_hidden_states2", encoder_hidden_states2.shape) + else: + raise NotImplementedError + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, train_noise_scheduler, latents + ) + + noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + B, C, H, W = noisy_latents.shape + + if args.use_extra_cond: + # get size embeddings + orig_size = batch["original_sizes_hw"] + crop_size = batch["crop_top_lefts"] + target_size = batch["target_sizes_hw"] + style = torch.as_tensor([0] * B, device=accelerator.device) + image_meta_size = torch.concat([orig_size, target_size, crop_size]) + else: + style = None + image_meta_size = None + + # RoPE embeddings + freqs_cis_img = hunyuan_utils.calc_rope(H * 8, W * 8, 2, 88) + + # Predict the 
noise residual + with accelerator.autocast(): + noise_pred = hydit( + noisy_latents, + timesteps, + encoder_hidden_states=encoder_hidden_states1, + text_embedding_mask=mask1, + encoder_hidden_states_t5=encoder_hidden_states2, + text_embedding_mask_t5=mask2, + image_meta_size=image_meta_size, + style=style, + cos_cis_img=freqs_cis_img[0], + sin_cis_img=freqs_cis_img[1], + ) + # `noise_pred` has 8 channels. The first four channels are used for the noise prediction, and the + # last four channels are used for the variance prediction. During inference, we found that the + # predicted variance has imperceptible affect on the quality of the generated images. Therefore, we + # only use the first four channels for the noise prediction. See the following link for details. + # https://github.com/Tencent/HunyuanDiT/blob/5657364143e44ac90f72aeb47b81bd505a95665d/hydit/diffusion/gaussian_diffusion.py#L562 + noise_pred, _ = noise_pred.chunk(2, dim=1) + + if args.v_parameterization: + # v-parameterization training + target = train_noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + if ( + args.min_snr_gamma + or args.scale_v_pred_loss_like_noise_pred + or args.v_pred_like_loss + or args.debiased_estimation_loss + or args.masked_loss + ): + # do not mean over batch dimension for snr weight or scale v-pred loss + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, train_noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, train_noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, train_noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, train_noise_scheduler) + + loss = loss.mean() # mean over batch dimension + else: + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + ) + + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.fused_optimizer_groups): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.fused_optimizer_groups: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + sdxl_train_util.sample_images( + accelerator, + args, + None, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2], + hydit, + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else 
src_diffusers_model_path + hunyuan_utils.save_hydit_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(text_encoder1), + accelerator.unwrap_model(text_encoder2), + accelerator.unwrap_model(hydit), + vae, + logit_scale, + ckpt_info, + hydit_version, + ) + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if args.logging_dir is not None: + logs = {"loss": current_loss} + if block_lrs is None: + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_hydit) + else: + append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type) # U-Net is included in block_lrs + + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + hunyuan_utils.save_hydit_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(text_encoder1), + accelerator.unwrap_model(text_encoder2), + accelerator.unwrap_model(hydit), + vae, + logit_scale, + ckpt_info, + hydit_version, + ) + + sdxl_train_util.sample_images( + accelerator, + args, + epoch + 1, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2], + hydit, + ) + + is_main_process = accelerator.is_main_process + # if is_main_process: + hydit = accelerator.unwrap_model(hydit) + text_encoder1 = accelerator.unwrap_model(text_encoder1) + text_encoder2 = accelerator.unwrap_model(text_encoder2) + + accelerator.end_training() + + if args.save_state or args.save_state_on_train_end: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + hunyuan_utils.save_hydit_model_on_train_end( + args, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + global_step, + text_encoder1, + text_encoder2, + hydit, + vae, + logit_scale, + ckpt_info, + hydit_version, + ) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser) + + hunyuan_utils.add_hydit_arguments(parser) + + parser.add_argument( + 
"--learning_rate_te1", + type=float, + default=None, + help="learning rate for text encoder 1 (CLIP) / text encoder 1 (ViT-L)の学習率", + ) + parser.add_argument( + "--learning_rate_te2", + type=float, + default=None, + help="learning rate for text encoder 2 (mT5) / text encoder 2 (BiG-G)の学習率", + ) + + parser.add_argument( + "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" + ) + parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) + parser.add_argument( + "--block_lr", + type=str, + default=None, + help=f"learning rates for each block of HunyuanDiT, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " + + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", + ) + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", + ) + return parser + + +if __name__ == '__main__': + parser = setup_parser() + + args = parser.parse_args() + print(args) + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/library/hunyuan_models.py b/library/hunyuan_models.py index db5aadaaf..e65a7bd88 100644 --- a/library/hunyuan_models.py +++ b/library/hunyuan_models.py @@ -947,6 +947,7 @@ def __init__( mlp_ratio=4.0, log_fn=print, attn_mode="xformers", + use_extra_cond=False, ): super().__init__() self.log_fn = log_fn @@ -962,6 +963,7 @@ def __init__( self.text_len = text_len self.text_len_t5 = text_len_t5 self.norm = norm + self.use_extra_cond = use_extra_cond log_fn(f" Use {attn_mode} attention implementation.") qk_norm = qk_norm # See http://arxiv.org/abs/2302.05442 for details. @@ -981,20 +983,24 @@ def __init__( ) # Attention pooling + pooler_out_dim = 1024 self.pooler = AttentionPool( - self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=1024 + self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=pooler_out_dim ) - # Here we use a default learned embedder layer for future extension. - self.style_embedder = nn.Embedding(1, hidden_size) + # Dimension of the extra input vectors + self.extra_in_dim = pooler_out_dim - # Image size and crop size conditions - self.extra_in_dim = 256 * 6 + hidden_size + if self.use_extra_cond: + # Image source size, image target size and crop size conditions + self.extra_in_dim += 6 * 256 + # Here we use a default learned embedder layer for future extension. 
+ self.style_embedder = nn.Embedding(1, hidden_size) + self.extra_in_dim += hidden_size # Text embedding for `add` self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size) self.t_embedder = TimestepEmbedder(hidden_size) - self.extra_in_dim += 1024 self.extra_embedder = nn.Sequential( nn.Linear(self.extra_in_dim, hidden_size * 4), FP32_SiLU(), @@ -1092,9 +1098,9 @@ def forward( T5 text embedding, (B, L_t5, D) text_embedding_mask_t5: torch.Tensor T5 text embedding mask, (B, L_t5) - image_meta_size: torch.Tensor + image_meta_size: None or torch.Tensor (B, 6) - style: torch.Tensor + style: None or torch.Tensor (B) cos_cis_img: torch.Tensor sin_cis_img: torch.Tensor @@ -1144,17 +1150,18 @@ def forward( # Build text tokens with pooling extra_vec = self.pooler(encoder_hidden_states_t5) - # Build image meta size tokens - image_meta_size = timestep_embedding( - image_meta_size.view(-1), 256 - ) # [B * 6, 256] - image_meta_size = image_meta_size.to(x) - image_meta_size = image_meta_size.view(-1, 6 * 256) - extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256] - - # Build style tokens - style_embedding = self.style_embedder(style) - extra_vec = torch.cat([extra_vec, style_embedding], dim=1) + if self.use_extra_cond: + # Build image meta size tokens + image_meta_size = timestep_embedding( + image_meta_size.view(-1), 256 + ) # [B * 6, 256] + image_meta_size = image_meta_size.to(x) + image_meta_size = image_meta_size.view(-1, 6 * 256) + extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256] + + # Build style tokens + style_embedding = self.style_embedder(style) + extra_vec = torch.cat([extra_vec, style_embedding], dim=1) # Concatenate all extra vectors c = t + self.extra_embedder(extra_vec) # [B, D] diff --git a/library/hunyuan_utils.py b/library/hunyuan_utils.py index 1d4a0d334..40105b07d 100644 --- a/library/hunyuan_utils.py +++ b/library/hunyuan_utils.py @@ -1,5 +1,6 @@ import os from typing import Tuple, Union, Optional, Any +import argparse import numpy as np import torch @@ -13,9 +14,12 @@ BertTokenizer, ) +from library import model_util from library.device_utils import init_ipex, clean_memory_on_device from .hunyuan_models import MT5Embedder, HunYuanDiT, BertModel, DiT_g_2 from .utils import setup_logging +import library.train_util as train_util +from safetensors.torch import load_file, save_file setup_logging() import logging @@ -222,17 +226,22 @@ def load_tokenizers(): def load_scheduler_sigmas(): scheduler: LMSDiscreteScheduler = LMSDiscreteScheduler.from_pretrained( - "Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", + "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="scheduler", ) return scheduler.alphas_cumprod, scheduler.sigmas -def load_model(model_path: str, dtype=torch.float16, device="cuda"): +def load_model(model_path: str, dtype=torch.float16, device="cuda", use_extra_cond=False, dit_path=None): denoiser: HunYuanDiT - denoiser, patch_size, head_dim = DiT_g_2(input_size=(128, 128)) - sd = torch.load(os.path.join(model_path, "denoiser/pytorch_model_module.pt")) - denoiser.load_state_dict(sd) + denoiser, patch_size, head_dim = DiT_g_2(input_size=(128, 128), use_extra_cond=use_extra_cond) + if dit_path is not None: + state_dict = torch.load(dit_path) + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + else: + state_dict = torch.load(os.path.join(model_path, "denoiser/pytorch_model_module.pt")) + denoiser.load_state_dict(state_dict) denoiser.to(device).to(dtype) clip_tokenizer = 
AutoTokenizer.from_pretrained(os.path.join(model_path, "clip")) @@ -279,7 +288,8 @@ def match_mixed_precision(args, weight_dtype): return None -def load_target_model(args, accelerator, model_version: str, weight_dtype): +def load_target_model(args, accelerator, model_version: str, weight_dtype, use_extra_cond=False): + _ = model_version # unused model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16 for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: @@ -299,6 +309,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): args.pretrained_model_name_or_path, model_dtype, accelerator.device if args.lowram else "cpu", + use_extra_cond, ) # work on low-ram device @@ -568,6 +579,158 @@ def calc_rope(height, width, patch_size=2, head_size=64): return rope +def add_hydit_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--use_extra_cond", action="store_true", help="Use image_meta_size and style conditions for the model" + ) + parser.add_argument( + "--beta_end", type=float, default=0.018, help="End value of beta for DDPM training" + ) + + +def save_hydit_checkpoint( + output_file, + text_encoder1, + text_encoder2, + hydit, + epochs, + steps, + ckpt_info, + vae, + logit_scale, + metadata, + save_dtype=None, +): + state_dict = {} + + def update_state(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + # Convert the hydit model + update_state('', hydit.state_dict()) + + # Put together new checkpoint + key_count = len(state_dict.keys()) + new_ckpt = {"state_dict": state_dict} + + # epoch and global_step are sometimes not int + if ckpt_info is not None: + epochs += ckpt_info[0] + steps += ckpt_info[1] + + new_ckpt["epoch"] = epochs + new_ckpt["global_step"] = steps + + if model_util.is_safetensors(output_file): + save_file(state_dict, output_file, metadata) + else: + torch.save(new_ckpt, output_file) + + return key_count + + +def save_hydit_model_on_train_end( + args: argparse.Namespace, + src_path: str, + save_stable_diffusion_format: bool, + use_safetensors: bool, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + text_encoder1, + text_encoder2, + unet, + vae, + logit_scale, + ckpt_info, + hydit, +): + def hydit_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=False, + hydit=hydit) + save_hydit_checkpoint( + ckpt_file, + text_encoder1, + text_encoder2, + unet, + epoch_no, + global_step, + ckpt_info, + vae, + logit_scale, + sai_metadata, + save_dtype, + ) + + def diffusers_saver(out_dir): + _ = out_dir + raise NotImplementedError("Diffusers saving is not supported yet for HunYuan DiT") + + train_util.save_sd_model_on_train_end_common( + args, save_stable_diffusion_format, use_safetensors, epoch, global_step, hydit_saver, diffusers_saver + ) + + +# Save epochs and steps, integrate because the metadata includes epochs/steps and the arguments are identical. +# on_epoch_end: If true, at the end of epoch, if false, after the steps have been completed. 
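+# The trailing `hydit` argument is the model version string ('v1.1' or 'v1.2'); it is forwarded to
+# train_util.get_sai_model_spec() so the saved checkpoint metadata records the matching
+# hunyuan-dit architecture.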
+def save_hydit_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + src_path, + save_stable_diffusion_format: bool, + use_safetensors: bool, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + text_encoder1, + text_encoder2, + unet, + vae, + logit_scale, + ckpt_info, + hydit, +): + def hydit_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=False, + hydit=hydit) + save_hydit_checkpoint( + ckpt_file, + text_encoder1, + text_encoder2, + unet, + epoch_no, + global_step, + ckpt_info, + vae, + logit_scale, + sai_metadata, + save_dtype, + ) + + def diffusers_saver(out_dir): + _ = out_dir + raise NotImplementedError("Diffusers saving is not supported yet for HunYuan DiT") + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + save_stable_diffusion_format, + use_safetensors, + epoch, + num_train_epochs, + global_step, + hydit_saver, + diffusers_saver, + ) + + if __name__ == "__main__": clip_tokenizer = AutoTokenizer.from_pretrained("./model/clip") clip_tokenizer.eos_token_id = 2 diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index a63bd82ec..7535fe02c 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -56,11 +56,15 @@ ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v" ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" +ARCH_HYDIT_V1_1 = "hunyuan-dit-g2-v1_1" +ARCH_HYDIT_V1_2 = "hunyuan-dit-g2-v1_2" + ADAPTER_LORA = "lora" ADAPTER_TEXTUAL_INVERSION = "textual-inversion" IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models" IMPL_DIFFUSERS = "diffusers" +IMPL_HUNYUAN_DIT = "https://github.com/Tencent/HunyuanDiT" PRED_TYPE_EPSILON = "epsilon" PRED_TYPE_V = "v" @@ -113,6 +117,7 @@ def build_metadata( merged_from: Optional[str] = None, timesteps: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, + hydit: Optional[str] = None, ): # if state_dict is None, hash is not calculated @@ -124,7 +129,14 @@ def build_metadata( # hash = precalculate_safetensors_hashes(state_dict) # metadata["modelspec.hash_sha256"] = hash - if sdxl: + if hydit: + if hydit == 'v1.1': + arch = ARCH_HYDIT_V1_1 + elif hydit == 'v1.2': + arch = ARCH_HYDIT_V1_2 + else: + raise ValueError(f"Invalid hydit version: {hydit}") + elif sdxl: arch = ARCH_SD_XL_V1_BASE elif v2: if v_parameterization: @@ -144,7 +156,9 @@ def build_metadata( if not lora and not textual_inversion and is_stable_diffusion_ckpt is None: is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion - if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: + if hydit: + impl = IMPL_HUNYUAN_DIT + elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: # Stable Diffusion ckpt, TI, SDXL LoRA impl = IMPL_STABILITY_AI else: diff --git a/library/train_util.py b/library/train_util.py index 2a027b1bd..07c3502ce 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2922,6 +2922,7 @@ def get_sai_model_spec( lora: bool, textual_inversion: bool, is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA + hydit: str = None, ): timestamp = time.time() @@ -2955,6 +2956,7 @@ def get_sai_model_spec( tags=args.metadata_tags, timesteps=timesteps, clip_skip=args.clip_skip, # None or int + hydit=hydit, ) return metadata