forked from kohya-ss/sd-scripts
Commit 982cf79 (parent 09620b4)

A quick conversion of train_network from the SDXL version, adding a HunYuanNetworkTrainer for HunYuan-DiT v1.1.

Showing 1 changed file with 199 additions and 0 deletions.
import argparse

import torch
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from library import (
    hunyuan_models,
    hunyuan_utils,
    sdxl_model_util,
    sdxl_train_util,
    train_util,
)
import train_network
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

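# The class below plugs HunYuan-DiT into sd-scripts' generic network (LoRA)
# trainer: it subclasses NetworkTrainer from train_network.py and reuses the
# SDXL code paths, overriding only the model-specific hooks (model loading,
# tokenization, text conditioning, and the transformer forward call).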
class HunYuanNetworkTrainer(train_network.NetworkTrainer):
    def __init__(self):
        super().__init__()
        self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
        self.is_sdxl = True  # treat this model like SDXL in the base trainer (e.g. two text encoders)

    def assert_extra_args(self, args, train_dataset_group):
        super().assert_extra_args(args, train_dataset_group)
        sdxl_train_util.verify_sdxl_training_args(args)

        if args.cache_text_encoder_outputs:
            assert (
                train_dataset_group.is_text_encoder_output_cacheable()
            ), "when caching Text Encoder outputs, caption_dropout_rate, shuffle_caption, token_warmup_step, and caption_tag_dropout_rate cannot be used"

        assert (
            args.network_train_unet_only or not args.cache_text_encoder_outputs
        ), "the Text Encoder network cannot be trained while caching Text Encoder outputs"

        train_dataset_group.verify_bucket_reso_steps(16)

    def load_target_model(self, args, weight_dtype, accelerator):
        (
            load_stable_diffusion_format,
            text_encoder1,
            text_encoder2,
            vae,
            unet,
            logit_scale,
            ckpt_info,
        ) = hunyuan_utils.load_target_model(
            args,
            accelerator,
            hunyuan_models.MODEL_VERSION_HUNYUAN_V1_1,
            weight_dtype,
        )

        self.load_stable_diffusion_format = load_stable_diffusion_format
        self.logit_scale = logit_scale
        self.ckpt_info = ckpt_info

        return (
            hunyuan_models.MODEL_VERSION_HUNYUAN_V1_1,
            [text_encoder1, text_encoder2],
            vae,
            unet,
        )

    def load_tokenizer(self, args):
        # load_tokenizers returns both tokenizers; downstream code indexes them
        # as tokenizers[0] / tokenizers[1]
        tokenizer = hunyuan_utils.load_tokenizers(args)
        return tokenizer

    def is_text_encoder_outputs_cached(self, args):
        return args.cache_text_encoder_outputs

    def cache_text_encoder_outputs_if_needed(
        self,
        args,
        accelerator,
        unet,
        vae,
        tokenizers,
        text_encoders,
        dataset: train_util.DatasetGroup,
        weight_dtype,
    ):
        if args.cache_text_encoder_outputs:
            raise NotImplementedError
        else:
            # outputs are fetched from the Text Encoders on every step, so keep them on the GPU
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)
            text_encoders[1].to(accelerator.device, dtype=weight_dtype)

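    # HunYuan-DiT conditions on two text encoders. Judging from the *_t5 kwargs
    # in call_unet below, text_encoders[0] is the CLIP-style encoder and
    # text_encoders[1] the mT5-style one (an inference from the call site, not
    # confirmed by this commit). The hook returns hidden states and attention
    # masks for both.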
    def get_text_cond(
        self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype
    ):
        if (
            "text_encoder_outputs1_list" not in batch
            or batch["text_encoder_outputs1_list"] is None
        ):
            input_ids1 = batch["input_ids"]
            input_ids2 = batch["input_ids2"]
            with torch.enable_grad():
                input_ids1 = input_ids1.to(accelerator.device)
                input_ids2 = input_ids2.to(accelerator.device)
                encoder_hidden_states1, mask1, encoder_hidden_states2, mask2 = (
                    hunyuan_utils.hunyuan_get_hidden_states(
                        args.max_token_length,
                        input_ids1,
                        input_ids2,
                        tokenizers[0],
                        tokenizers[1],
                        text_encoders[0],
                        text_encoders[1],
                        None if not args.full_fp16 else weight_dtype,
                        accelerator=accelerator,
                    )
                )
        else:
            # cached Text Encoder outputs are not supported yet
            raise NotImplementedError
        return encoder_hidden_states1, mask1, encoder_hidden_states2, mask2

    def call_unet(
        self,
        args,
        accelerator,
        unet,
        noisy_latents,
        timesteps,
        text_conds,
        batch,
        weight_dtype,
    ):
        noisy_latents = noisy_latents.to(
            weight_dtype
        )  # TODO check why noisy_latents is not weight_dtype

        # get size embeddings (fetched but not used until meta_size is implemented)
        orig_size = batch["original_sizes_hw"]
        crop_size = batch["crop_top_lefts"]
        target_size = batch["target_sizes_hw"]
        B, C, H, W = noisy_latents.shape

        # TODO implement correct meta_size info
        style = torch.as_tensor([0] * B, device=accelerator.device)
        # src hw, dst hw, 0, 0
        size_cond = [1024, 1024, 1024, 1024, 0, 0]
        image_meta_size = torch.as_tensor([size_cond], device=accelerator.device)
        freqs_cis_img = hunyuan_utils.calc_rope(H * 8, W * 8, 2, 88)
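        # Assumed reading of the calc_rope call: H and W are latent-space dims
        # and the VAE downsamples by 8x, so H * 8 and W * 8 recover the pixel
        # resolution; 2 and 88 would then be the patch size and per-head channel
        # count of HunYuan-DiT. This is an inference from the call site, not
        # something this commit states.
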
        # concat embeddings
        encoder_hidden_states1, mask1, encoder_hidden_states2, mask2 = text_conds
        noise_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=encoder_hidden_states1,
            text_embedding_mask=mask1,
            encoder_hidden_states_t5=encoder_hidden_states2,
            text_embedding_mask_t5=mask2,
            image_meta_size=None,  # the image_meta_size computed above is not passed yet (see TODO)
            style=style,
            cos_cis_img=freqs_cis_img[0],
            sin_cis_img=freqs_cis_img[1],
        )
        return noise_pred

    def sample_images(
        self,
        accelerator,
        args,
        epoch,
        global_step,
        device,
        vae,
        tokenizer,
        text_encoder,
        unet,
    ):
        # sample-image generation during training is not implemented yet
        raise NotImplementedError


def setup_parser() -> argparse.ArgumentParser:
    parser = train_network.setup_parser()
    sdxl_train_util.add_sdxl_training_arguments(parser)
    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    trainer = HunYuanNetworkTrainer()
    trainer.train(args)
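
Usage note: setup_parser() extends train_network's argument parser with the SDXL training arguments, so the script should accept the standard sd-scripts training flags. A minimal sketch of driving the trainer programmatically, mirroring the __main__ block above (all paths are hypothetical placeholders, and support for each flag with HunYuan checkpoints is assumed rather than confirmed by this commit):

parser = setup_parser()
args = parser.parse_args([
    "--pretrained_model_name_or_path", "/path/to/hunyuan-dit-v1.1",  # hypothetical path
    "--dataset_config", "/path/to/dataset.toml",                     # hypothetical path
    "--output_dir", "/path/to/output",
    "--network_module", "networks.lora",
    "--network_train_unet_only",
])
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
HunYuanNetworkTrainer().train(args)

--network_train_unet_only restricts training to the transformer; per assert_extra_args it is required whenever --cache_text_encoder_outputs is set (and the caching path itself still raises NotImplementedError).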