Skip to content

Commit

Permalink
A quick conversion of train_network from sdxl ver
Browse files Browse the repository at this point in the history
  • Loading branch information
KohakuBlueleaf committed Jun 21, 2024
1 parent 09620b4 commit 982cf79
Showing 1 changed file with 199 additions and 0 deletions.
199 changes: 199 additions & 0 deletions hunyuan_train_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
import argparse

import torch
from library.device_utils import init_ipex, clean_memory_on_device

# NOTE(review): init_ipex() runs before the remaining project imports; presumably
# the backend must be initialised before those modules load — confirm before reordering.
init_ipex()

from library import (
hunyuan_models,
hunyuan_utils,
sdxl_model_util,
sdxl_train_util,
train_util,
)
import train_network
from library.utils import setup_logging

# Logging is configured first, then this module's logger is created below.
setup_logging()
import logging

logger = logging.getLogger(__name__)


class HunYuanNetworkTrainer(train_network.NetworkTrainer):
    """Network trainer for HunYuan v1.1 models.

    This class overrides the model-specific hooks of
    ``train_network.NetworkTrainer`` (model/tokenizer loading, text
    conditioning and the denoiser call). It was converted from the SDXL
    trainer, so several SDXL helpers and flags are still reused.
    """

    def __init__(self):
        super().__init__()
        # Carried over from the SDXL trainer this class was converted from.
        self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
        self.is_sdxl = True

    def assert_extra_args(self, args, train_dataset_group):
        """Validate command-line arguments and dataset settings.

        Runs the base-class checks and the SDXL argument verification, then
        checks the Text Encoder caching options and the bucket resolution
        step.

        Raises:
            AssertionError: if Text Encoder output caching is requested with
                a dataset that cannot be cached, or together with training
                the Text Encoder network.
        """
        super().assert_extra_args(args, train_dataset_group)
        sdxl_train_util.verify_sdxl_training_args(args)

        if args.cache_text_encoder_outputs:
            assert (
                train_dataset_group.is_text_encoder_output_cacheable()
            ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"

        # Cached outputs are only compatible with training the U-Net side alone.
        assert (
            args.network_train_unet_only or not args.cache_text_encoder_outputs
        ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"

        train_dataset_group.verify_bucket_reso_steps(16)

    def load_target_model(self, args, weight_dtype, accelerator):
        """Load the HunYuan v1.1 model components.

        Stores ``load_stable_diffusion_format``, ``logit_scale`` and
        ``ckpt_info`` on the instance.

        Returns:
            A tuple ``(model_version, [text_encoder1, text_encoder2], vae, unet)``.
        """
        (
            load_stable_diffusion_format,
            text_encoder1,
            text_encoder2,
            vae,
            unet,
            logit_scale,
            ckpt_info,
        ) = hunyuan_utils.load_target_model(
            args,
            accelerator,
            hunyuan_models.MODEL_VERSION_HUNYUAN_V1_1,
            weight_dtype,
        )

        self.load_stable_diffusion_format = load_stable_diffusion_format
        self.logit_scale = logit_scale
        self.ckpt_info = ckpt_info

        return (
            hunyuan_models.MODEL_VERSION_HUNYUAN_V1_1,
            [text_encoder1, text_encoder2],
            vae,
            unet,
        )

    def load_tokenizer(self, args):
        """Return whatever ``hunyuan_utils.load_tokenizers`` builds for ``args``."""
        tokenizer = hunyuan_utils.load_tokenizers(args)
        return tokenizer

    def is_text_encoder_outputs_cached(self, args):
        """Return the ``cache_text_encoder_outputs`` setting from ``args``."""
        return args.cache_text_encoder_outputs

    def cache_text_encoder_outputs_if_needed(
        self,
        args,
        accelerator,
        unet,
        vae,
        tokenizers,
        text_encoders,
        dataset: train_util.DatasetGroup,
        weight_dtype,
    ):
        """Prepare the text encoders for training.

        Caching Text Encoder outputs is not implemented for this trainer.
        Without caching, both text encoders are moved to the accelerator
        device in ``weight_dtype``.

        Raises:
            NotImplementedError: if ``args.cache_text_encoder_outputs`` is set.
        """
        if args.cache_text_encoder_outputs:
            raise NotImplementedError
        else:
            # Outputs are fetched from the Text Encoders on every step, so keep
            # them on the GPU.
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)
            text_encoders[1].to(accelerator.device, dtype=weight_dtype)

    def get_text_cond(
        self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype
    ):
        """Encode the batch captions with both text encoders.

        Returns:
            A tuple ``(encoder_hidden_states1, mask1, encoder_hidden_states2,
            mask2)`` as produced by ``hunyuan_utils.hunyuan_get_hidden_states``.

        Raises:
            NotImplementedError: if the batch carries cached text encoder
                outputs (``text_encoder_outputs1_list`` present and not None).
        """
        if (
            "text_encoder_outputs1_list" in batch
            and batch["text_encoder_outputs1_list"] is not None
        ):
            # Cached text encoder outputs are not supported yet.
            raise NotImplementedError

        input_ids1 = batch["input_ids"]
        input_ids2 = batch["input_ids2"]
        with torch.enable_grad():
            input_ids1 = input_ids1.to(accelerator.device)
            input_ids2 = input_ids2.to(accelerator.device)
            encoder_hidden_states1, mask1, encoder_hidden_states2, mask2 = (
                hunyuan_utils.hunyuan_get_hidden_states(
                    args.max_token_length,
                    input_ids1,
                    input_ids2,
                    tokenizers[0],
                    tokenizers[1],
                    text_encoders[0],
                    text_encoders[1],
                    # A dtype is passed only for full fp16 training.
                    None if not args.full_fp16 else weight_dtype,
                    accelerator=accelerator,
                )
            )
        return encoder_hidden_states1, mask1, encoder_hidden_states2, mask2

    def call_unet(
        self,
        args,
        accelerator,
        unet,
        noisy_latents,
        timesteps,
        text_conds,
        batch,
        weight_dtype,
    ):
        """Run the denoiser on ``noisy_latents`` and return its prediction.

        Args:
            noisy_latents: 4-D latent tensor; it is cast to ``weight_dtype``.
            timesteps: timesteps forwarded to the model unchanged.
            text_conds: the 4-tuple returned by ``get_text_cond``.
            batch: current batch; not read here until size conditioning is
                implemented (see the TODO below).

        Returns:
            The output of ``unet(...)``.
        """
        noisy_latents = noisy_latents.to(
            weight_dtype
        )  # TODO check why noisy_latents is not weight_dtype

        batch_size, _, height, width = noisy_latents.shape

        # TODO implement correct meta_size info. The model is currently called
        # with image_meta_size=None, so nothing is built for it here. The draft
        # layout was one row per sample of [src_h, src_w, dst_h, dst_w, 0, 0]
        # (e.g. [1024, 1024, 1024, 1024, 0, 0]), presumably to be derived from
        # batch["original_sizes_hw"], batch["crop_top_lefts"] and
        # batch["target_sizes_hw"] — confirm against the model definition.
        style = torch.as_tensor([0] * batch_size, device=accelerator.device)

        # NOTE(review): 8 is presumably the VAE spatial downscale factor and
        # (2, 88) the patch size / head dimension — confirm with calc_rope.
        freqs_cis_img = hunyuan_utils.calc_rope(height * 8, width * 8, 2, 88)

        encoder_hidden_states1, mask1, encoder_hidden_states2, mask2 = text_conds
        noise_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=encoder_hidden_states1,
            text_embedding_mask=mask1,
            encoder_hidden_states_t5=encoder_hidden_states2,
            text_embedding_mask_t5=mask2,
            image_meta_size=None,
            style=style,
            cos_cis_img=freqs_cis_img[0],
            sin_cis_img=freqs_cis_img[1],
        )
        return noise_pred

    def sample_images(
        self,
        accelerator,
        args,
        epoch,
        global_step,
        device,
        vae,
        tokenizer,
        text_encoder,
        unet,
    ):
        """Sample preview images; not implemented for this trainer.

        NOTE(review): this raises on every call. If the base trainer invokes
        this hook even when no sampling was requested, training will abort
        here — confirm against ``train_network.NetworkTrainer``.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError


def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this training script.

    Starts from the parser provided by ``train_network`` and registers the
    SDXL training arguments on it.
    """
    arg_parser = train_network.setup_parser()
    sdxl_train_util.add_sdxl_training_arguments(arg_parser)
    return arg_parser


if __name__ == "__main__":
    # Parse the CLI, validate it, then merge in values from a config file.
    parser = setup_parser()
    cli_args = parser.parse_args()
    train_util.verify_command_line_training_args(cli_args)
    args = train_util.read_config_from_file(cli_args, parser)

    # Hand the resolved arguments to the trainer.
    HunYuanNetworkTrainer().train(args)

0 comments on commit 982cf79

Please sign in to comment.