From 5b19bda85c2ce01e4a1c7f324b7ef14bffed3315 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 12:35:46 -0500 Subject: [PATCH 01/76] Add validation loss --- library/train_util.py | 4 ++ train_network.py | 117 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 120 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index cc9ac4555..e26f39799 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4736,6 +4736,10 @@ def __call__(self, examples): else: dataset = self.dataset + # If we split a dataset we will get a Subset + if type(dataset) is torch.utils.data.Subset: + dataset = dataset.dataset + # set epoch and step dataset.set_current_epoch(self.current_epoch.value) dataset.set_current_step(self.current_step.value) diff --git a/train_network.py b/train_network.py index d50916b74..58767b6f7 100644 --- a/train_network.py +++ b/train_network.py @@ -345,8 +345,21 @@ def train(self, args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + if args.validation_ratio > 0.0: + train_ratio = 1 - args.validation_ratio + validation_ratio = args.validation_ratio + train, val = torch.utils.data.random_split( + train_dataset_group, + [train_ratio, validation_ratio] + ) + print(f"split dataset by ratio: train {train_ratio}, validation {validation_ratio}") + print(f"train images: {len(train)}, validation images: {len(val)}") + else: + train = train_dataset_group + val = [] + train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, + train, batch_size=1, shuffle=True, collate_fn=collator, @@ -354,6 +367,15 @@ def train(self, args): persistent_workers=args.persistent_data_loader_workers, ) + val_dataloader = torch.utils.data.DataLoader( + val, + shuffle=False, + batch_size=1, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + # 学習ステップ数を計算する if args.max_train_epochs is not None: args.max_train_steps = args.max_train_epochs * math.ceil( @@ -711,6 +733,8 @@ def train(self, args): ) loss_recorder = train_util.LossRecorder() + val_loss_recorder = train_util.LossRecorder() + del train_dataset_group # callback for step start @@ -752,6 +776,8 @@ def remove_model(old_ckpt_name): network.on_epoch_start(text_encoder, unet) + # TRAINING + for step, batch in enumerate(train_dataloader): current_step.value = global_step with accelerator.accumulate(network): @@ -877,6 +903,87 @@ def remove_model(old_ckpt_name): if global_step >= args.max_train_steps: break + # VALIDATION + + if len(val_dataloader) > 0: + print("Validating バリデーション処理...") + + with torch.no_grad(): + for val_step, batch in enumerate(val_dataloader): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(device=accelerator.device, dtype=vae_dtype)).latent_dist.sample() + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = latents * self.vae_scale_factor + b_size = latents.shape[0] + + # Get the text embedding for conditioning + if args.weighted_captions: + text_encoder_conds = get_weighted_text_embeddings( + tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + 
clip_skip=args.clip_skip, + ) + else: + text_encoder_conds = self.get_text_cond( + args, accelerator, batch, tokenizers, text_encoders, weight_dtype + ) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = self.call_unet( + args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight + + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + current_loss = loss.detach().item() + + val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) + + if len(val_dataloader) > 0: + avr_loss: float = val_loss_recorder.moving_average + + if args.logging_dir is not None: + logs = {"loss/validation": avr_loss} + accelerator.log(logs, step=epoch + 1) + + if args.logging_dir is not None: logs = {"loss/epoch": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) @@ -999,6 +1106,14 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + + parser.add_argument( + "--validation_ratio", + type=float, + default=0.0, + help="Ratio for validation images out of the training dataset" + ) + return parser From 33c311ed19821c9be7094ba89371777d7478b028 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 12:37:37 -0500 Subject: [PATCH 02/76] new ratio code --- train_network.py | 48 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/train_network.py b/train_network.py index 58767b6f7..967c95fb4 100644 --- a/train_network.py +++ b/train_network.py @@ -345,10 +345,48 @@ def train(self, args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + def get_indices_without_reg(dataset: torch.utils.data.Dataset): + return [id for id, (key, item) in enumerate(dataset.image_data.items()) if item.is_reg is False] + + from typing import Sequence, Union + from torch._utils import _accumulate + import warnings + from torch.utils.data.dataset import Subset + + def random_split(dataset: torch.utils.data.Dataset, lengths: Sequence[Union[int, float]]): + indices = get_indices_without_reg(dataset) + random.shuffle(indices) + + subset_lengths = [] + + for i, frac in enumerate(lengths): + if frac < 0 or frac > 1: + raise ValueError(f"Fraction at index {i} is not between 0 and 1") + n_items_in_split = int(math.floor(len(indices) * frac)) + subset_lengths.append(n_items_in_split) 
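+
+            # math.floor above rounds every split down, so some items may be
+            # left unassigned; the remainder loop below hands them back to
+            # the splits one at a time.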
+ + remainder = len(indices) - sum(subset_lengths) + + for i in range(remainder): + idx_to_add_at = i % len(subset_lengths) + subset_lengths[idx_to_add_at] += 1 + + lengths = subset_lengths + for i, length in enumerate(lengths): + if length == 0: + warnings.warn(f"Length of split at index {i} is 0. " + f"This might result in an empty dataset.") + + if sum(lengths) != len(indices): + raise ValueError("Sum of input lengths does not equal the length of the input dataset!") + + return [Subset(dataset, indices[offset - length: offset]) for offset, length in zip(_accumulate(lengths), lengths)] + + if args.validation_ratio > 0.0: train_ratio = 1 - args.validation_ratio validation_ratio = args.validation_ratio - train, val = torch.utils.data.random_split( + train, val = random_split( train_dataset_group, [train_ratio, validation_ratio] ) @@ -358,6 +396,8 @@ def train(self, args): train = train_dataset_group val = [] + + train_dataloader = torch.utils.data.DataLoader( train, batch_size=1, @@ -898,7 +938,7 @@ def remove_model(old_ckpt_name): if args.logging_dir is not None: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) - accelerator.log(logs, step=global_step) + accelerator.log(logs) if global_step >= args.max_train_steps: break @@ -973,13 +1013,11 @@ def remove_model(old_ckpt_name): loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし current_loss = loss.detach().item() - val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) if len(val_dataloader) > 0: - avr_loss: float = val_loss_recorder.moving_average - if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation": avr_loss} accelerator.log(logs, step=epoch + 1) From 3de9e6c443037abf99832d1be60f4fc9c0d67b8c Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 01:45:23 -0500 Subject: [PATCH 03/76] Add validation split of datasets --- library/config_util.py | 145 ++++++++++++++++++++++++++--------------- library/train_util.py | 26 ++++++++ train_network.py | 67 ++++--------------- 3 files changed, 128 insertions(+), 110 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index e8e0fda7c..1bf7ed955 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -85,6 +85,8 @@ class BaseDatasetParams: max_token_length: int = None resolution: Optional[Tuple[int, int]] = None debug_dataset: bool = False + validation_seed: Optional[int] = None + validation_split: float = 0.0 @dataclass class DreamBoothDatasetParams(BaseDatasetParams): @@ -200,6 +202,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "enable_bucket": bool, "max_bucket_reso": int, "min_bucket_reso": int, + "validation_seed": int, + "validation_split": float, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), } @@ -427,64 +431,89 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params)) datasets.append(dataset) - # print info - info = "" - for i, dataset in enumerate(datasets): - is_dreambooth = isinstance(dataset, DreamBoothDataset) - is_controlnet = isinstance(dataset, ControlNetDataset) - info += dedent(f"""\ - [Dataset 
{i}] - batch_size: {dataset.batch_size} - resolution: {(dataset.width, dataset.height)} - enable_bucket: {dataset.enable_bucket} - """) - - if dataset.enable_bucket: - info += indent(dedent(f"""\ - min_bucket_reso: {dataset.min_bucket_reso} - max_bucket_reso: {dataset.max_bucket_reso} - bucket_reso_steps: {dataset.bucket_reso_steps} - bucket_no_upscale: {dataset.bucket_no_upscale} - \n"""), " ") + val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.params.validation_split <= 0.0: + continue + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset else: - info += "\n" - - for j, subset in enumerate(dataset.subsets): - info += indent(dedent(f"""\ - [Subset {j} of Dataset {i}] - image_dir: "{subset.image_dir}" - image_count: {subset.img_count} - num_repeats: {subset.num_repeats} - shuffle_caption: {subset.shuffle_caption} - keep_tokens: {subset.keep_tokens} - caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} - caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} - caption_prefix: {subset.caption_prefix} - caption_suffix: {subset.caption_suffix} - color_aug: {subset.color_aug} - flip_aug: {subset.flip_aug} - face_crop_aug_range: {subset.face_crop_aug_range} - random_crop: {subset.random_crop} - token_warmup_min: {subset.token_warmup_min}, - token_warmup_step: {subset.token_warmup_step}, - """), " ") - - if is_dreambooth: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) + val_datasets.append(dataset) + + # print info + def print_info(_datasets): + info = "" + for i, dataset in enumerate(_datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent(f"""\ + [Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + """) + + if dataset.enable_bucket: info += indent(dedent(f"""\ - is_reg: {subset.is_reg} - class_tokens: {subset.class_tokens} - caption_extension: {subset.caption_extension} - \n"""), " ") - elif not is_controlnet: + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n"""), " ") + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): info += indent(dedent(f"""\ - metadata_file: {subset.metadata_file} - \n"""), " ") - - print(info) + [Subset {j} of Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + caption_dropout_rate: {subset.caption_dropout_rate} + caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + caption_prefix: {subset.caption_prefix} + caption_suffix: {subset.caption_suffix} + color_aug: {subset.color_aug} + flip_aug: {subset.flip_aug} + 
face_crop_aug_range: {subset.face_crop_aug_range} + random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min}, + token_warmup_step: {subset.token_warmup_step}, + """), " ") + + if is_dreambooth: + info += indent(dedent(f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n"""), " ") + elif not is_controlnet: + info += indent(dedent(f"""\ + metadata_file: {subset.metadata_file} + \n"""), " ") + + print(info) + + print_info(datasets) + + if len(val_datasets) > 0: + print("Validation dataset") + print_info(val_datasets) # make buckets first because it determines the length of dataset # and set the same seed for all datasets @@ -494,7 +523,15 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset.make_buckets() dataset.set_seed(seed) - return DatasetGroup(datasets) + for i, dataset in enumerate(val_datasets): + print(f"[Validation Dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) + + return ( + DatasetGroup(datasets), + DatasetGroup(val_datasets) if val_datasets else None + ) def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): diff --git a/library/train_util.py b/library/train_util.py index e26f39799..ba37ec13d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -123,6 +123,22 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" +def split_train_val(paths, is_train, validation_split, validation_seed): + if validation_seed is not None: + print(f"Using validation seed: {validation_seed}") + prevstate = random.getstate() + random.seed(validation_seed) + random.shuffle(paths) + random.setstate(prevstate) + else: + random.shuffle(paths) + + if is_train: + return paths[0:math.ceil(len(paths) * (1 - validation_split))] + else: + return paths[len(paths) - round(len(paths) * validation_split):] + + class ImageInfo: def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: self.image_key: str = image_key @@ -1314,6 +1330,7 @@ class DreamBoothDataset(BaseDataset): def __init__( self, subsets: Sequence[DreamBoothSubset], + is_train: bool, batch_size: int, tokenizer, max_token_length, @@ -1324,12 +1341,18 @@ def __init__( bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, + validation_split: float, + validation_seed: Optional[int], debug_dataset, ) -> None: super().__init__(tokenizer, max_token_length, resolution, debug_dataset) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" + self.is_train = is_train + self.validation_split = validation_split + self.validation_seed = validation_seed + self.batch_size = batch_size self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight @@ -1382,6 +1405,9 @@ def load_dreambooth_dir(subset: DreamBoothSubset): return [], [] img_paths = glob_images(subset.image_dir, "*") + + if self.validation_split > 0.0: + img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う diff --git a/train_network.py b/train_network.py index 967c95fb4..97ecfe7be 100644 --- a/train_network.py +++ b/train_network.py @@ -189,10 +189,11 @@ def train(self, args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = 
config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) + val_dataset_group = None # placeholder until validation dataset supported for arbitrary current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -212,6 +213,10 @@ def train(self, args): assert ( train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + if val_dataset_group is not None: + assert ( + val_dataset_group.is_latent_cacheable() + ), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" self.assert_extra_args(args, train_dataset_group) @@ -264,6 +269,9 @@ def train(self, args): vae.eval() with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + if val_dataset_group is not None: + print("Cache validation latents...") + val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -345,61 +353,8 @@ def train(self, args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで - def get_indices_without_reg(dataset: torch.utils.data.Dataset): - return [id for id, (key, item) in enumerate(dataset.image_data.items()) if item.is_reg is False] - - from typing import Sequence, Union - from torch._utils import _accumulate - import warnings - from torch.utils.data.dataset import Subset - - def random_split(dataset: torch.utils.data.Dataset, lengths: Sequence[Union[int, float]]): - indices = get_indices_without_reg(dataset) - random.shuffle(indices) - - subset_lengths = [] - - for i, frac in enumerate(lengths): - if frac < 0 or frac > 1: - raise ValueError(f"Fraction at index {i} is not between 0 and 1") - n_items_in_split = int(math.floor(len(indices) * frac)) - subset_lengths.append(n_items_in_split) - - remainder = len(indices) - sum(subset_lengths) - - for i in range(remainder): - idx_to_add_at = i % len(subset_lengths) - subset_lengths[idx_to_add_at] += 1 - - lengths = subset_lengths - for i, length in enumerate(lengths): - if length == 0: - warnings.warn(f"Length of split at index {i} is 0. 
" - f"This might result in an empty dataset.") - - if sum(lengths) != len(indices): - raise ValueError("Sum of input lengths does not equal the length of the input dataset!") - - return [Subset(dataset, indices[offset - length: offset]) for offset, length in zip(_accumulate(lengths), lengths)] - - - if args.validation_ratio > 0.0: - train_ratio = 1 - args.validation_ratio - validation_ratio = args.validation_ratio - train, val = random_split( - train_dataset_group, - [train_ratio, validation_ratio] - ) - print(f"split dataset by ratio: train {train_ratio}, validation {validation_ratio}") - print(f"train images: {len(train)}, validation images: {len(val)}") - else: - train = train_dataset_group - val = [] - - - train_dataloader = torch.utils.data.DataLoader( - train, + train_dataset_group, batch_size=1, shuffle=True, collate_fn=collator, @@ -408,7 +363,7 @@ def random_split(dataset: torch.utils.data.Dataset, lengths: Sequence[Union[int, ) val_dataloader = torch.utils.data.DataLoader( - val, + val_dataset_group if val_dataset_group is not None else [], shuffle=False, batch_size=1, collate_fn=collator, From a93c524b3a0e5c80a58c1317211dec93b6c137a7 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 02:07:39 -0500 Subject: [PATCH 04/76] Update args to validation_seed and validation_split --- train_network.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 97ecfe7be..f9e5debdb 100644 --- a/train_network.py +++ b/train_network.py @@ -1099,12 +1099,17 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) - parser.add_argument( - "--validation_ratio", + "--validation_seed", + type=int, + default=None, + help="Validation seed" + ) + parser.add_argument( + "--validation_split", type=float, default=0.0, - help="Ratio for validation images out of the training dataset" + help="Split for validation images out of the training dataset" ) return parser From c89252101e8e8bd74cb3ab09ae33b548fd828e15 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 16:27:36 -0500 Subject: [PATCH 05/76] Add process_batch for train_network --- train_network.py | 211 ++++++++++++++++++----------------------------- 1 file changed, 82 insertions(+), 129 deletions(-) diff --git a/train_network.py b/train_network.py index f9e5debdb..387b94b1c 100644 --- a/train_network.py +++ b/train_network.py @@ -130,6 +130,75 @@ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_cond def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) + def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(accelerator.device, dtype=vae_dtype)).latent_dist.sample() + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = latents * self.vae_scale_factor + b_size = 
latents.shape[0] + + with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): + # Get the text embedding for conditioning + if args.weighted_captions: + text_encoder_conds = get_weighted_text_embeddings( + tokenizers[0], + text_encoders[0], + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + text_encoder_conds = self.get_text_cond( + args, accelerator, batch, tokenizers, text_encoders, weight_dtype + ) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) + + # Predict the noise residual + with torch.set_grad_enabled(is_train), accelerator.autocast(): + noise_pred = self.call_unet( + args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + return loss + + def train(self, args): session_id = random.randint(0, 2**32) training_started_at = time.time() @@ -777,71 +846,8 @@ def remove_model(old_ckpt_name): current_step.value = global_step with accelerator.accumulate(network): on_step_start(text_encoder, unet) - - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() - - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) - latents = latents * self.vae_scale_factor - b_size = latents.shape[0] - - with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): - # Get the text embedding for conditioning - if args.weighted_captions: - text_encoder_conds = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - text_encoder_conds = self.get_text_cond( - args, accelerator, batch, tokenizers, text_encoders, weight_dtype - ) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) - - # Predict the noise residual - with accelerator.autocast(): - 
noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype - ) - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + is_train = True + loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=train_text_encoder) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: @@ -893,7 +899,7 @@ def remove_model(old_ckpt_name): if args.logging_dir is not None: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) - accelerator.log(logs) + accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break @@ -905,80 +911,27 @@ def remove_model(old_ckpt_name): with torch.no_grad(): for val_step, batch in enumerate(val_dataloader): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(device=accelerator.device, dtype=vae_dtype)).latent_dist.sample() - - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) - latents = latents * self.vae_scale_factor - b_size = latents.shape[0] - - # Get the text embedding for conditioning - if args.weighted_captions: - text_encoder_conds = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - text_encoder_conds = self.get_text_cond( - args, accelerator, batch, tokenizers, text_encoders, weight_dtype - ) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype - ) - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight - - loss = loss * loss_weights - - if 
args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + is_train = False + loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) current_loss = loss.detach().item() val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) + if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/validation_current": current_loss} + accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + if len(val_dataloader) > 0: if args.logging_dir is not None: avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/validation": avr_loss} + logs = {"loss/validation_average": avr_loss} accelerator.log(logs, step=epoch + 1) if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} + # logs = {"loss/epoch": loss_recorder.moving_average} + logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone() From e545fdfd9affabff83f8bd2e7680369bb34dd301 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 16:56:36 -0500 Subject: [PATCH 06/76] Removed/cleanup a line --- train_network.py | 1 - 1 file changed, 1 deletion(-) diff --git a/train_network.py b/train_network.py index 387b94b1c..a4125e9f2 100644 --- a/train_network.py +++ b/train_network.py @@ -930,7 +930,6 @@ def remove_model(old_ckpt_name): if args.logging_dir is not None: - # logs = {"loss/epoch": loss_recorder.moving_average} logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) From 9c591bdb12ce663b3fe9e91c0963d2cf71461bad Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 16:58:20 -0500 Subject: [PATCH 07/76] Remove unnecessary subset line from collate --- library/train_util.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index ba37ec13d..1979207b0 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4762,10 +4762,6 @@ def __call__(self, examples): else: dataset = self.dataset - # If we split a dataset we will get a Subset - if type(dataset) is torch.utils.data.Subset: - dataset = dataset.dataset - # set epoch and step dataset.set_current_epoch(self.current_epoch.value) dataset.set_current_step(self.current_step.value) From 569ca72fc4cda2f4ce30e43b1c62989e79e3c3b3 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 7 Nov 2023 11:59:30 -0500 Subject: [PATCH 08/76] Set grad enabled if is_train and train_text_encoder We only want to be enabling grad if we are training. 
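
A minimal sketch of the gating this change introduces, for review context
(`encode_captions` below is only an illustrative helper, not a function in
this repo; the `torch.set_grad_enabled(is_train and train_text_encoder)`
pattern is the part taken from the diff):

    import torch

    def encode_captions(is_train: bool, train_text_encoder: bool,
                        text_encoder, tokens):
        # Gradients flow through the text encoder only for training batches
        # where the text encoder itself is being trained; validation batches
        # (is_train=False) always run inference-only, so no autograd graph
        # is built for them.
        with torch.set_grad_enabled(is_train and train_text_encoder):
            return text_encoder(tokens)

Without the `is_train` term, validation batches would build (and waste
memory on) an autograd graph for the text encoder even though no backward
pass follows.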
--- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index a4125e9f2..edd3ff944 100644 --- a/train_network.py +++ b/train_network.py @@ -145,7 +145,7 @@ def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, n latents = latents * self.vae_scale_factor b_size = latents.shape[0] - with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: text_encoder_conds = get_weighted_text_embeddings( From b558a5b73d07a7e15ad90d9d15c2b55c5d2b3d61 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 10 Mar 2024 04:37:16 +0800 Subject: [PATCH 09/76] val --- library/config_util.py | 176 ++++++++++++++++++++++------------------- library/train_util.py | 22 ++++++ train_network.py | 135 ++++++++++++++++++++++++++++--- 3 files changed, 241 insertions(+), 92 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index fc4b36175..17fc17818 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -98,7 +98,8 @@ class BaseDatasetParams: resolution: Optional[Tuple[int, int]] = None network_multiplier: float = 1.0 debug_dataset: bool = False - + validation_seed: Optional[int] = None + validation_split: float = 0.0 @dataclass class DreamBoothDatasetParams(BaseDatasetParams): @@ -109,8 +110,7 @@ class DreamBoothDatasetParams(BaseDatasetParams): bucket_reso_steps: int = 64 bucket_no_upscale: bool = False prior_loss_weight: float = 1.0 - - + @dataclass class FineTuningDatasetParams(BaseDatasetParams): batch_size: int = 1 @@ -222,8 +222,11 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "enable_bucket": bool, "max_bucket_reso": int, "min_bucket_reso": int, + "validation_seed": int, + "validation_split": float, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, + } # options handled by argparse but not handled by user config @@ -460,100 +463,107 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params)) datasets.append(dataset) - # print info - info = "" - for i, dataset in enumerate(datasets): - is_dreambooth = isinstance(dataset, DreamBoothDataset) - is_controlnet = isinstance(dataset, ControlNetDataset) - info += dedent( - f"""\ - [Dataset {i}] - batch_size: {dataset.batch_size} - resolution: {(dataset.width, dataset.height)} - enable_bucket: {dataset.enable_bucket} - network_multiplier: {dataset.network_multiplier} - """ - ) + val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.params.validation_split <= 0.0: + continue + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = 
[subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) + val_datasets.append(dataset) + + def print_info(_datasets): + info = "" + for i, dataset in enumerate(_datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent(f"""\ + [Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + """) if dataset.enable_bucket: - info += indent( - dedent( - f"""\ - min_bucket_reso: {dataset.min_bucket_reso} - max_bucket_reso: {dataset.max_bucket_reso} - bucket_reso_steps: {dataset.bucket_reso_steps} - bucket_no_upscale: {dataset.bucket_no_upscale} - \n""" - ), - " ", - ) + info += indent(dedent(f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n"""), " ") else: info += "\n" - for j, subset in enumerate(dataset.subsets): - info += indent( - dedent( - f"""\ - [Subset {j} of Dataset {i}] - image_dir: "{subset.image_dir}" - image_count: {subset.img_count} - num_repeats: {subset.num_repeats} - shuffle_caption: {subset.shuffle_caption} - keep_tokens: {subset.keep_tokens} - keep_tokens_separator: {subset.keep_tokens_separator} - caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} - caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} - caption_prefix: {subset.caption_prefix} - caption_suffix: {subset.caption_suffix} - color_aug: {subset.color_aug} - flip_aug: {subset.flip_aug} - face_crop_aug_range: {subset.face_crop_aug_range} - random_crop: {subset.random_crop} - token_warmup_min: {subset.token_warmup_min}, - token_warmup_step: {subset.token_warmup_step}, - """ - ), - " ", - ) - - if is_dreambooth: - info += indent( - dedent( - f"""\ - is_reg: {subset.is_reg} - class_tokens: {subset.class_tokens} - caption_extension: {subset.caption_extension} - \n""" - ), - " ", - ) - elif not is_controlnet: - info += indent( - dedent( - f"""\ - metadata_file: {subset.metadata_file} - \n""" - ), - " ", - ) - - logger.info(f'{info}') - + info += indent(dedent(f"""\ + [Subset {j} of Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + caption_dropout_rate: {subset.caption_dropout_rate} + caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + caption_prefix: {subset.caption_prefix} + caption_suffix: {subset.caption_suffix} + color_aug: {subset.color_aug} + flip_aug: {subset.flip_aug} + face_crop_aug_range: {subset.face_crop_aug_range} + random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min}, + token_warmup_step: {subset.token_warmup_step}, + """), " ") + + if is_dreambooth: + info += indent(dedent(f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n"""), " ") + elif not is_controlnet: + info += indent(dedent(f"""\ + metadata_file: {subset.metadata_file} + \n"""), " ") + + print(info) + + print_info(datasets) + + if len(val_datasets) > 0: + print("Validation dataset") + 
print_info(val_datasets) + # make buckets first because it determines the length of dataset # and set the same seed for all datasets seed = random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): - logger.info(f"[Dataset {i}]") + print(f"[Dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) + + for i, dataset in enumerate(val_datasets): + print(f"[Validation Dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) - return DatasetGroup(datasets) - - + return ( + DatasetGroup(datasets), + DatasetGroup(val_datasets) if val_datasets else None + ) + def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): def extract_dreambooth_params(name: str) -> Tuple[int, str]: tokens = name.split("_") diff --git a/library/train_util.py b/library/train_util.py index d2b69edb5..753539e04 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -134,6 +134,20 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" +def split_train_val(paths, is_train, validation_split, validation_seed): + if validation_seed is not None: + print(f"Using validation seed: {validation_seed}") + prevstate = random.getstate() + random.seed(validation_seed) + random.shuffle(paths) + random.setstate(prevstate) + else: + random.shuffle(paths) + + if is_train: + return paths[0:math.ceil(len(paths) * (1 - validation_split))] + else: + return paths[len(paths) - round(len(paths) * validation_split):] class ImageInfo: def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: @@ -1360,6 +1374,7 @@ class DreamBoothDataset(BaseDataset): def __init__( self, subsets: Sequence[DreamBoothSubset], + is_train: bool, batch_size: int, tokenizer, max_token_length, @@ -1371,12 +1386,17 @@ def __init__( bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, + validation_split: float, + validation_seed: Optional[int], debug_dataset: bool, ) -> None: super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" + self.is_train = is_train + self.validation_split = validation_split + self.validation_seed = validation_seed self.batch_size = batch_size self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight @@ -1429,6 +1449,8 @@ def load_dreambooth_dir(subset: DreamBoothSubset): return [], [] img_paths = glob_images(subset.image_dir, "*") + if self.validation_split > 0.0: + img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う diff --git a/train_network.py b/train_network.py index e0fa69458..db7000e82 100644 --- a/train_network.py +++ b/train_network.py @@ -136,6 +136,67 @@ def all_reduce_network(self, accelerator, network): def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) + def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True): + + total_loss = 0.0 + timesteps_list = [10, 350, 500, 650, 990] + + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not 
None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(accelerator.device, dtype=vae_dtype)).latent_dist.sample() + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = latents * self.vae_scale_factor + b_size = latents.shape[0] + + with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): + # Get the text embedding for conditioning + if args.weighted_captions: + text_encoder_conds = get_weighted_text_embeddings( + tokenizers[0], + text_encoders[0], + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + text_encoder_conds = self.get_text_cond( + args, accelerator, batch, tokenizers, text_encoders, weight_dtype + ) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, _ = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) + for timesteps in timesteps_list: + # Predict the noise residual + with torch.set_grad_enabled(is_train), accelerator.autocast(): + noise_pred = self.call_unet( + args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + total_loss += loss + + average_loss = total_loss / len(timesteps_list) + return average_loss + def train(self, args): session_id = random.randint(0, 2**32) training_started_at = time.time() @@ -196,11 +257,12 @@ def train(self, args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) - + val_dataset_group = None # placeholder until validation dataset supported for arbitrary + current_epoch = Value("i", 0) current_step = Value("i", 0) ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None @@ -219,7 +281,11 @@ def train(self, args): assert ( train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - + if val_dataset_group is not None: + assert ( + val_dataset_group.is_latent_cacheable() + ), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + self.assert_extra_args(args, train_dataset_group) # acceleratorを準備する @@ -271,6 +337,9 @@ def train(self, args): vae.eval() with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + if val_dataset_group is not None: + print("Cache validation latents...") + val_dataset_group.cache_latents(vae, args.vae_batch_size, 
args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -360,6 +429,15 @@ def train(self, args): num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) + + val_dataloader = torch.utils.data.DataLoader( + val_dataset_group if val_dataset_group is not None else [], + shuffle=False, + batch_size=1, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -707,6 +785,8 @@ def train(self, args): ) loss_recorder = train_util.LossRecorder() + val_loss_recorder = train_util.LossRecorder() + del train_dataset_group # callback for step start @@ -755,7 +835,8 @@ def remove_model(old_ckpt_name): current_step.value = global_step with accelerator.accumulate(network): on_step_start(text_encoder, unet) - + + is_train = True with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) @@ -780,7 +861,7 @@ def remove_model(old_ckpt_name): # print(f"set multiplier: {multipliers}") accelerator.unwrap_model(network).set_multiplier(multipliers) - with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: text_encoder_conds = get_weighted_text_embeddings( @@ -810,7 +891,7 @@ def remove_model(old_ckpt_name): t.requires_grad_(True) # Predict the noise residual - with accelerator.autocast(): + with torch.set_grad_enabled(is_train), accelerator.autocast(): noise_pred = self.call_unet( args, accelerator, @@ -844,7 +925,7 @@ def remove_model(old_ckpt_name): loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - + accelerator.backward(loss) if accelerator.sync_gradients: self.all_reduce_network(accelerator, network) # sync DDP grad manually @@ -898,14 +979,38 @@ def remove_model(old_ckpt_name): if args.logging_dir is not None: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - + + if global_step % 25 == 0: + if len(val_dataloader) > 0: + print("Validating バリデーション処理...") + + with torch.no_grad(): + val_dataloader_iter = iter(val_dataloader) + batch = next(val_dataloader_iter) + is_train = False + loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + + current_loss = loss.detach().item() + val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/validation_current": current_loss} + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} + logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) + if len(val_dataloader) > 0: + if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/validation_epoch_average": avr_loss} + accelerator.log(logs, step=epoch + 1) + accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 @@ -1045,6 +1150,18 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 
VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + parser.add_argument( + "--validation_seed", + type=int, + default=None, + help="Validation seed" + ) + parser.add_argument( + "--validation_split", + type=float, + default=0.0, + help="Split for validation images out of the training dataset" + ) return parser From 78cfb01922ff97bbc62ff12a4d69eaaa2d89d7c1 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 10 Mar 2024 18:55:48 +0800 Subject: [PATCH 10/76] improve --- library/config_util.py | 260 +++++++++++++++++++++++++++++------------ train_network.py | 67 +++++++---- 2 files changed, 234 insertions(+), 93 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 17fc17818..d198cee35 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -41,12 +41,17 @@ DatasetGroup, ) from .utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + def add_config_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") + parser.add_argument( + "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル" + ) # TODO: inherit Params class in Subset, Dataset @@ -60,6 +65,8 @@ class BaseSubsetParams: caption_separator: str = (",",) keep_tokens: int = 0 keep_tokens_separator: str = (None,) + secondary_separator: Optional[str] = None + enable_wildcard: bool = False color_aug: bool = False flip_aug: bool = False face_crop_aug_range: Optional[Tuple[float, float]] = None @@ -181,6 +188,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "shuffle_caption": bool, "keep_tokens": int, "keep_tokens_separator": str, + "secondary_separator": str, + "enable_wildcard": bool, "token_warmup_min": int, "token_warmup_step": Any(float, int), "caption_prefix": str, @@ -247,9 +256,10 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] } def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: - assert ( - support_dreambooth or support_finetuning or support_controlnet - ), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" + assert support_dreambooth or support_finetuning or support_controlnet, ( + "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more." + + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。" + ) self.db_subset_schema = self.__merge_dict( self.SUBSET_ASCENDABLE_SCHEMA, @@ -361,7 +371,9 @@ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> return self.argparse_config_validator(argparse_namespace) except MultipleInvalid: # XXX: this should be a bug - logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") + logger.error( + "Invalid cmdline parsed arguments. This should be a bug. 
/ コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。" + ) raise # NOTE: value would be overwritten by latter dict if there is already the same key @@ -447,7 +459,6 @@ def search_value(key: str, fallbacks: Sequence[dict], default_value=None): return default_value - def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] @@ -467,7 +478,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu datasets.append(dataset) val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] - + for dataset_blueprint in dataset_group_blueprint.datasets: if dataset_blueprint.params.validation_split <= 0.0: continue @@ -485,75 +496,174 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) - def print_info(_datasets): - info = "" - for i, dataset in enumerate(_datasets): - is_dreambooth = isinstance(dataset, DreamBoothDataset) - is_controlnet = isinstance(dataset, ControlNetDataset) - info += dedent(f"""\ - [Dataset {i}] - batch_size: {dataset.batch_size} - resolution: {(dataset.width, dataset.height)} - enable_bucket: {dataset.enable_bucket} - """) + # print info + info = "" + for i, dataset in enumerate(datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent( + f"""\ + [Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + network_multiplier: {dataset.network_multiplier} + """ + ) if dataset.enable_bucket: - info += indent(dedent(f"""\ - min_bucket_reso: {dataset.min_bucket_reso} - max_bucket_reso: {dataset.max_bucket_reso} - bucket_reso_steps: {dataset.bucket_reso_steps} - bucket_no_upscale: {dataset.bucket_no_upscale} - \n"""), " ") + info += indent( + dedent( + f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n""" + ), + " ", + ) else: info += "\n" + for j, subset in enumerate(dataset.subsets): - info += indent(dedent(f"""\ - [Subset {j} of Dataset {i}] - image_dir: "{subset.image_dir}" - image_count: {subset.img_count} - num_repeats: {subset.num_repeats} - shuffle_caption: {subset.shuffle_caption} - keep_tokens: {subset.keep_tokens} - caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} - caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} - caption_prefix: {subset.caption_prefix} - caption_suffix: {subset.caption_suffix} - color_aug: {subset.color_aug} - flip_aug: {subset.flip_aug} - face_crop_aug_range: {subset.face_crop_aug_range} - random_crop: {subset.random_crop} - token_warmup_min: {subset.token_warmup_min}, - token_warmup_step: {subset.token_warmup_step}, - """), " ") - - if is_dreambooth: - info += indent(dedent(f"""\ - is_reg: {subset.is_reg} - class_tokens: {subset.class_tokens} - caption_extension: {subset.caption_extension} - \n"""), " ") - elif not is_controlnet: - info += indent(dedent(f"""\ - metadata_file: {subset.metadata_file} - \n"""), " ") - - print(info) - - print_info(datasets) - - if len(val_datasets) > 0: - print("Validation dataset") - print_info(val_datasets) - + info += 
indent( + dedent( + f"""\ + [Subset {j} of Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + keep_tokens_separator: {subset.keep_tokens_separator} + caption_dropout_rate: {subset.caption_dropout_rate} + caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + caption_prefix: {subset.caption_prefix} + caption_suffix: {subset.caption_suffix} + color_aug: {subset.color_aug} + flip_aug: {subset.flip_aug} + face_crop_aug_range: {subset.face_crop_aug_range} + random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min}, + token_warmup_step: {subset.token_warmup_step}, + """ + ), + " ", + ) + + if is_dreambooth: + info += indent( + dedent( + f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n""" + ), + " ", + ) + elif not is_controlnet: + info += indent( + dedent( + f"""\ + metadata_file: {subset.metadata_file} + \n""" + ), + " ", + ) + + logger.info(f'{info}') + + # print validation info + info = "" + for i, dataset in enumerate(val_datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent( + f"""\ + [Validation Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + network_multiplier: {dataset.network_multiplier} + """ + ) + + if dataset.enable_bucket: + info += indent( + dedent( + f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n""" + ), + " ", + ) + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): + info += indent( + dedent( + f"""\ + [Subset {j} of Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + keep_tokens_separator: {subset.keep_tokens_separator} + caption_dropout_rate: {subset.caption_dropout_rate} + caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + caption_prefix: {subset.caption_prefix} + caption_suffix: {subset.caption_suffix} + color_aug: {subset.color_aug} + flip_aug: {subset.flip_aug} + face_crop_aug_range: {subset.face_crop_aug_range} + random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min}, + token_warmup_step: {subset.token_warmup_step}, + """ + ), + " ", + ) + + if is_dreambooth: + info += indent( + dedent( + f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n""" + ), + " ", + ) + elif not is_controlnet: + info += indent( + dedent( + f"""\ + metadata_file: {subset.metadata_file} + \n""" + ), + " ", + ) + + logger.info(f'{info}') + # make buckets first because it determines the length of dataset # and set the same seed for all datasets seed = random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): - print(f"[Dataset {i}]") + logger.info(f"[Dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) - + for i, dataset in enumerate(val_datasets): print(f"[Validation 
Dataset {i}]") dataset.make_buckets() @@ -562,8 +672,8 @@ def print_info(_datasets): return ( DatasetGroup(datasets), DatasetGroup(val_datasets) if val_datasets else None - ) - + ) + def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): def extract_dreambooth_params(name: str) -> Tuple[int, str]: tokens = name.split("_") @@ -642,13 +752,17 @@ def load_user_config(file: str) -> dict: with open(file, "r") as f: config = json.load(f) except Exception: - logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") + logger.error( + f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) raise elif file.name.lower().endswith(".toml"): try: config = toml.load(file) except Exception: - logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") + logger.error( + f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) raise else: raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") @@ -675,13 +789,13 @@ def load_user_config(file: str) -> dict: train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) logger.info("[argparse_namespace]") - logger.info(f'{vars(argparse_namespace)}') + logger.info(f"{vars(argparse_namespace)}") user_config = load_user_config(config_args.dataset_config) logger.info("") logger.info("[user_config]") - logger.info(f'{user_config}') + logger.info(f"{user_config}") sanitizer = ConfigSanitizer( config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout @@ -690,10 +804,10 @@ def load_user_config(file: str) -> dict: logger.info("") logger.info("[sanitized_user_config]") - logger.info(f'{sanitized_user_config}') + logger.info(f"{sanitized_user_config}") blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) logger.info("") logger.info("[blueprint]") - logger.info(f'{blueprint}') + logger.info(f"{blueprint}") diff --git a/train_network.py b/train_network.py index db7000e82..d3e34eb7e 100644 --- a/train_network.py +++ b/train_network.py @@ -44,6 +44,7 @@ setup_logging() import logging +import itertools logger = logging.getLogger(__name__) @@ -438,6 +439,7 @@ def train(self, args): num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) + cyclic_val_dataloader = itertools.cycle(val_dataloader) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -979,23 +981,24 @@ def remove_model(old_ckpt_name): if args.logging_dir is not None: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - - if global_step % 25 == 0: - if len(val_dataloader) > 0: - print("Validating バリデーション処理...") - - with torch.no_grad(): - val_dataloader_iter = iter(val_dataloader) - batch = next(val_dataloader_iter) - is_train = False - loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - - current_loss = loss.detach().item() - val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) + + if args.validation_every_n_step is not None: + if global_step % (args.validation_every_n_step) == 0: + if 
len(val_dataloader) > 0: + print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + for val_step in min(len(val_dataloader), args.validation_batches): + is_train = False + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / args.validation_batches + val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) if args.logging_dir is not None: avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/validation_current": current_loss} + logs = {"loss/avr_val_loss": avr_loss} accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: @@ -1005,12 +1008,24 @@ def remove_model(old_ckpt_name): logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) - if len(val_dataloader) > 0: - if args.logging_dir is not None: - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/validation_epoch_average": avr_loss} - accelerator.log(logs, step=epoch + 1) - + if args.validation_every_n_step is None: + if len(val_dataloader) > 0: + print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + for val_step in min(len(val_dataloader), args.validation_batches): + is_train = False + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / args.validation_batches + val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/val_epoch_average": avr_loss} + accelerator.log(logs, step=epoch + 1) + accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 @@ -1162,6 +1177,18 @@ def setup_parser() -> argparse.ArgumentParser: default=0.0, help="Split for validation images out of the training dataset" ) + parser.add_argument( + "--validation_every_n_step", + type=int, + default=None, + help="Number of steps for counting validation loss. By default, validation per epoch is performed" + ) + parser.add_argument( + "--validation_batches", + type=int, + default=1, + help="Number of val steps for counting validation loss. 
By default, validation one batch is performed" + ) return parser From 923b761ce3622a3132bf0db7768e6b97df21c607 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 10 Mar 2024 20:01:40 +0800 Subject: [PATCH 11/76] Update train_network.py --- train_network.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index d3e34eb7e..821100666 100644 --- a/train_network.py +++ b/train_network.py @@ -988,6 +988,7 @@ def remove_model(old_ckpt_name): print("Validating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): + validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) for val_step in min(len(val_dataloader), args.validation_batches): is_train = False batch = next(cyclic_val_dataloader) @@ -1013,6 +1014,7 @@ def remove_model(old_ckpt_name): print("Validating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): + validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) for val_step in min(len(val_dataloader), args.validation_batches): is_train = False batch = next(cyclic_val_dataloader) @@ -1186,8 +1188,8 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--validation_batches", type=int, - default=1, - help="Number of val steps for counting validation loss. By default, validation one batch is performed" + default=None, + help="Number of val steps for counting validation loss. By default, validation for all val_dataset is performed" ) return parser From 47359b8fac9602415f56b1f7e3f25a00255a1d78 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 10 Mar 2024 20:17:40 +0800 Subject: [PATCH 12/76] Update train_network.py --- train_network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index 821100666..d549378cc 100644 --- a/train_network.py +++ b/train_network.py @@ -989,7 +989,7 @@ def remove_model(old_ckpt_name): total_loss = 0.0 with torch.no_grad(): validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) - for val_step in min(len(val_dataloader), args.validation_batches): + for val_step in range(validation_steps): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) @@ -1015,7 +1015,7 @@ def remove_model(old_ckpt_name): total_loss = 0.0 with torch.no_grad(): validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader) - for val_step in min(len(val_dataloader), args.validation_batches): + for val_step in range(validation_steps): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) From a51723cc2a3dd50b45e60945f97bc5adfe753d1f Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Mon, 11 Mar 2024 09:42:58 +0800 Subject: [PATCH 13/76] fix timesteps --- train_network.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/train_network.py b/train_network.py index d549378cc..f0f27ea74 100644 --- a/train_network.py +++ b/train_network.py @@ -141,7 +141,6 @@ def process_val_batch(self, batch, is_train, tokenizers, 
text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
         total_loss = 0.0
         timesteps_list = [10, 350, 500, 650, 990]
-
         with torch.no_grad():
             if "latents" in batch and batch["latents"] is not None:
                 latents = batch["latents"].to(accelerator.device)
@@ -174,16 +173,17 @@ def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
             # Sample noise, sample a random timestep for each image, and add noise to the latents,
             # with noise offset and/or multires noise if specified
-            noise, noisy_latents, _ = train_util.get_noise_noisy_latents_and_timesteps(
-                args, noise_scheduler, latents
-            )
-            for timesteps in timesteps_list:
-                # Predict the noise residual
+
+            for fixed_timesteps in timesteps_list:
                 with torch.set_grad_enabled(is_train), accelerator.autocast():
+                    noise = torch.randn_like(latents, device=latents.device)
+                    b_size = latents.shape[0]
+                    timesteps = torch.randint(fixed_timesteps, fixed_timesteps, (b_size,), device=latents.device)
+                    timesteps = timesteps.long()
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                     noise_pred = self.call_unet(
                         args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
                     )
-                if args.v_parameterization:
-                    # v-parameterization training
-                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
-                else:
-                    target = noise
+                    if args.v_parameterization:
+                        # v-parameterization training
+                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    else:
+                        target = noise
@@ -988,7 +988,7 @@ def remove_model(old_ckpt_name):
                         print("Validating バリデーション処理...")
                         total_loss = 0.0
                         with torch.no_grad():
-                            validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader)
+                            validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
                             for val_step in range(validation_steps):
                                 is_train = False
                                 batch = next(cyclic_val_dataloader)
@@ -999,7 +999,7 @@ def remove_model(old_ckpt_name):
                         if args.logging_dir is not None:
                             avr_loss: float = val_loss_recorder.moving_average
-                            logs = {"loss/avr_val_loss": avr_loss}
+                            logs = {"loss/average_val_loss": avr_loss}
                             accelerator.log(logs, step=global_step)
 
                 if global_step >= args.max_train_steps:
@@ -1014,7 +1014,7 @@ def remove_model(old_ckpt_name):
                 print("Validating バリデーション処理...")
                 total_loss = 0.0
                 with torch.no_grad():
-                    validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader)
+                    validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
                     for val_step in range(validation_steps):
                         is_train = False
                         batch = next(cyclic_val_dataloader)
                         loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)

From 7d84ac2177a603e9aa6834fd1c0ee19a463eb5a0 Mon Sep 17 00:00:00 2001
From: gesen2egee <79357052+gesen2egee@users.noreply.github.com>
Date: Mon, 11 Mar 2024 14:41:51 +0800
Subject: [PATCH 14/76] Use only train subsets for validation

---
 library/config_util.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/library/config_util.py b/library/config_util.py
index d198cee35..1a6cef971 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -492,7 +492,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
             subset_klass = FineTuningSubset
             dataset_klass = FineTuningDataset
 
-        subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
+        subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets if subset_blueprint.params.is_reg is False]
         dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params))
         val_datasets.append(dataset)
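Evaluating the validation loss at a fixed list of timesteps, as process_val_batch now does, removes the variance that random timestep sampling adds to the metric: every validation pass measures the model at the same noise levels, so the resulting curve is comparable from step to step. A minimal standalone sketch of the same idea, assuming a diffusers-style DDPMScheduler and an epsilon-prediction UNet; the names below are illustrative, not taken from the patch:

    import torch
    from diffusers import DDPMScheduler, UNet2DConditionModel

    @torch.no_grad()
    def fixed_timestep_val_loss(unet: UNet2DConditionModel, scheduler: DDPMScheduler,
                                latents: torch.Tensor, cond: torch.Tensor,
                                timesteps_list=(10, 350, 500, 650, 990)) -> torch.Tensor:
        # Average the denoising MSE over the same timesteps on every call,
        # so successive validation runs are directly comparable.
        total = torch.zeros((), device=latents.device)
        for t in timesteps_list:
            noise = torch.randn_like(latents)
            ts = torch.full((latents.shape[0],), t, dtype=torch.long, device=latents.device)
            noisy = scheduler.add_noise(latents, noise, ts)
            pred = unet(noisy, ts, encoder_hidden_states=cond).sample
            total = total + torch.nn.functional.mse_loss(pred.float(), noise.float())
        return total / len(timesteps_list)

The remaining randomness comes only from torch.randn_like; averaging over more batches, or fixing the generator seed, tightens the estimate further.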
From befbec5335ed1f8018d22b65993b376571ea2989 Mon Sep 17 00:00:00 2001
From: gesen2egee <79357052+gesen2egee@users.noreply.github.com>
Date: Mon, 11 Mar 2024 18:47:04 +0800
Subject: [PATCH 15/76] Update train_network.py

---
 train_network.py | 32 +++++++++++++++++---------------
 1 file changed, 17 insertions(+), 15 deletions(-)

diff --git a/train_network.py b/train_network.py
index f0f27ea74..cbc107b6b 100644
--- a/train_network.py
+++ b/train_network.py
@@ -174,7 +174,7 @@ def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
             # Sample noise, sample a random timestep for each image, and add noise to the latents,
             # with noise offset and/or multires noise if specified
 
-            for fixed_timesteps in timesteps_list:
+            for fixed_timesteps in tqdm(timesteps_list, desc='Training Progress'):
                 with torch.set_grad_enabled(is_train), accelerator.autocast():
                     noise = torch.randn_like(latents, device=latents.device)
                     b_size = latents.shape[0]
@@ -184,16 +184,16 @@ def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
                     noise_pred = self.call_unet(
                         args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
                     )
-                if args.v_parameterization:
-                    # v-parameterization training
-                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
-                else:
-                    target = noise
+                    if args.v_parameterization:
+                        # v-parameterization training
+                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    else:
+                        target = noise
 
-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
-                loss = loss.mean([1, 2, 3])
-                loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
-                total_loss += loss
+                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                    loss = loss.mean([1, 2, 3])
+                    loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
+                    total_loss += loss
 
         average_loss = total_loss / len(timesteps_list)
         return average_loss
@@ -985,7 +985,7 @@ def remove_model(old_ckpt_name):
                 if args.validation_every_n_step is not None:
                     if global_step % (args.validation_every_n_step) == 0:
                         if len(val_dataloader) > 0:
-                            print("Validating バリデーション処理...")
+                            print(f"\nValidating バリデーション処理...")
                             total_loss = 0.0
                             with torch.no_grad():
                                 validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
                                 for val_step in range(validation_steps):
                                     is_train = False
                                     batch = next(cyclic_val_dataloader)
                                     loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
                                     total_loss += loss.detach().item()
-                                current_loss = total_loss / args.validation_batches
-                                val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
+                                current_loss = total_loss / args.validation_batches
+                                val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
 
                             if args.logging_dir is not None:
+                                logs = {"loss/current_val_loss": current_loss}
+                                accelerator.log(logs, step=global_step)
                                 avr_loss: float = val_loss_recorder.moving_average
                                 logs = {"loss/average_val_loss": avr_loss}
                                 accelerator.log(logs, step=global_step)
@@ -1011,7 +1013,7 @@ def
remove_model(old_ckpt_name): if args.logging_dir is not None: avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/val_epoch_average": avr_loss} + logs = {"loss/epoch_val_average": avr_loss} accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone() From 63e58f78e3df7608045071cdc247bb26bd19a333 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Mon, 11 Mar 2024 19:15:55 +0800 Subject: [PATCH 16/76] Update train_network.py --- train_network.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index cbc107b6b..82d72df24 100644 --- a/train_network.py +++ b/train_network.py @@ -178,8 +178,7 @@ def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, va with torch.set_grad_enabled(is_train), accelerator.autocast(): noise = torch.randn_like(latents, device=latents.device) b_size = latents.shape[0] - timesteps = torch.randint(fixed_timesteps, fixed_timesteps, (b_size,), device=latents.device) - timesteps = timesteps.long() + timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) noise_pred = self.call_unet( args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype From a6c41c6bea0465112c7bd472dff68b7e8ecea46e Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Mon, 11 Mar 2024 19:23:48 +0800 Subject: [PATCH 17/76] Update train_network.py --- train_network.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 82d72df24..6eefdb2be 100644 --- a/train_network.py +++ b/train_network.py @@ -174,7 +174,7 @@ def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, va # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - for fixed_timesteps in tqdm(timesteps_list, desc='Training Progress'): + for fixed_timesteps in timesteps_list: with torch.set_grad_enabled(is_train), accelerator.autocast(): noise = torch.randn_like(latents, device=latents.device) b_size = latents.shape[0] @@ -988,7 +988,7 @@ def remove_model(old_ckpt_name): total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) - for val_step in range(validation_steps): + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) @@ -1016,7 +1016,7 @@ def remove_model(old_ckpt_name): total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) - for val_step in range(validation_steps): + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) From bd7e2295b7c4d1444a9e844309e1685cb29c6961 Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Wed, 13 Mar 2024 17:54:21 +0800 Subject: [PATCH 18/76] fix --- train_network.py | 38 
+++++++++----------------------------- 1 file changed, 9 insertions(+), 29 deletions(-) diff --git a/train_network.py b/train_network.py index 6eefdb2be..128690fba 100644 --- a/train_network.py +++ b/train_network.py @@ -981,20 +981,19 @@ def remove_model(old_ckpt_name): logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - if args.validation_every_n_step is not None: - if global_step % (args.validation_every_n_step) == 0: - if len(val_dataloader) > 0: + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or step == len(train_dataloader) - 1 or global_step >= args.max_train_steps: print(f"\nValidating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): - validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) for val_step in tqdm(range(validation_steps), desc='Validation Steps'): is_train = False batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) total_loss += loss.detach().item() - current_loss = total_loss / args.validation_batches - val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=epoch, step=step, loss=current_loss) if args.logging_dir is not None: logs = {"loss/current_val_loss": current_loss} @@ -1009,25 +1008,6 @@ def remove_model(old_ckpt_name): if args.logging_dir is not None: logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) - - if args.validation_every_n_step is None: - if len(val_dataloader) > 0: - print(f"\nValidating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - is_train = False - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / args.validation_batches - val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/epoch_val_average": avr_loss} - accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone() @@ -1184,14 +1164,14 @@ def setup_parser() -> argparse.ArgumentParser: "--validation_every_n_step", type=int, default=None, - help="Number of steps for counting validation loss. By default, validation per epoch is performed" + help="Number of train steps for counting validation loss. By default, validation per train epoch is performed" ) parser.add_argument( - "--validation_batches", + "--max_validation_steps", type=int, default=None, - help="Number of val steps for counting validation loss. By default, validation for all val_dataset is performed" - ) + help="Number of max validation steps for counting validation loss. 
By default, validation will run entire validation dataset" + ) return parser From d05965dbadf430dab6a05f171292f6d2077ec946 Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Wed, 13 Mar 2024 18:33:51 +0800 Subject: [PATCH 19/76] Update train_network.py --- train_network.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 864bfd708..cc9fcbbed 100644 --- a/train_network.py +++ b/train_network.py @@ -987,8 +987,8 @@ def remove_model(old_ckpt_name): accelerator.log(logs, step=global_step) if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or step == len(train_dataloader) - 1 or global_step >= args.max_train_steps: - print(f"\nValidating バリデーション処理...") + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) @@ -998,7 +998,7 @@ def remove_model(old_ckpt_name): loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) total_loss += loss.detach().item() current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) if args.logging_dir is not None: logs = {"loss/current_val_loss": current_loss} From b5e8045df40ed4a437492ed2b6ea6d5be7282080 Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Sat, 16 Mar 2024 11:51:11 +0800 Subject: [PATCH 20/76] fix control net --- library/config_util.py | 6 ++++-- library/train_util.py | 15 ++++++++++++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index ec6ef4b2b..0da0b1437 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -491,8 +491,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu else: subset_klass = FineTuningSubset dataset_klass = FineTuningDataset - - subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets if subset_blueprint.params.is_reg is False] + if subset_klass == DreamBoothSubset: + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets if subset_blueprint.params.is_reg is False] + else: + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) diff --git a/library/train_util.py b/library/train_util.py index 892979628..ae7968d73 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1816,6 +1816,7 @@ class ControlNetDataset(BaseDataset): def __init__( self, subsets: Sequence[ControlNetSubset], + is_train: bool, batch_size: int, tokenizer, max_token_length, @@ -1826,6 +1827,8 @@ def __init__( max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, + validation_split: float, + validation_seed: Optional[int], debug_dataset: float, ) -> None: super().__init__(tokenizer, max_token_length, resolution, 
network_multiplier, debug_dataset) @@ -1860,6 +1863,7 @@ def __init__( self.dreambooth_dataset_delegate = DreamBoothDataset( db_subsets, + is_train, batch_size, tokenizer, max_token_length, @@ -1871,6 +1875,8 @@ def __init__( bucket_reso_steps, bucket_no_upscale, 1.0, + validation_split, + validation_seed, debug_dataset, ) @@ -1878,7 +1884,10 @@ def __init__( self.image_data = self.dreambooth_dataset_delegate.image_data self.batch_size = batch_size self.num_train_images = self.dreambooth_dataset_delegate.num_train_images - self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images + self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images + self.is_train = is_train + self.validation_split = validation_split + self.validation_seed = validation_seed # assert all conditioning data exists missing_imgs = [] @@ -1911,8 +1920,8 @@ def __init__( [cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img] ) - assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" - assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" + #assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" + #assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" self.conditioning_image_transforms = IMAGE_TRANSFORMS From 36d4023431d10718b00673d5ba34f426690c62de Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Thu, 11 Apr 2024 01:39:17 +0800 Subject: [PATCH 21/76] Update config_util.py --- library/config_util.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index a7e0024e3..c6667690e 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -498,10 +498,21 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu else: subset_klass = FineTuningSubset dataset_klass = FineTuningDataset - if subset_klass == DreamBoothSubset: - subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets if subset_blueprint.params.is_reg is False] - else: - subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + + subsets = [] + for subset_blueprint in dataset_blueprint.subsets: + subset_blueprint.params.num_repeats = 1 + subset_blueprint.params.color_aug = False + subset_blueprint.params.flip_aug = False + subset_blueprint.params.random_crop = False + subset_blueprint.params.random_crop = None + subset_blueprint.params.caption_dropout_rate = 0.0 + subset_blueprint.params.caption_dropout_every_n_epochs = 0 + subset_blueprint.params.caption_tag_dropout_rate = 0.0 + subset_blueprint.params.token_warmup_step = 0 + if subset_klass != DreamBoothSubset or not subset_blueprint.params.is_reg: + subsets.append(subset_klass(**asdict(subset_blueprint.params))) + dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) From 229c5a38ef4e93e2023d748b4fa1588d490340ad Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Thu, 11 Apr 2024 01:45:49 +0800 Subject: [PATCH 22/76] Update train_util.py --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 
832be75d5..b143e85a8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3123,7 +3123,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: ) parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument( - "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする" + "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする" ) parser.add_argument( "--gradient_accumulation_steps", From 3b251b758dae6e4f11e0bbc7e544dc9542c836ff Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Thu, 11 Apr 2024 01:50:32 +0800 Subject: [PATCH 23/76] Update config_util.py --- library/config_util.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index c6667690e..8f01e1f60 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -510,8 +510,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu subset_blueprint.params.caption_dropout_every_n_epochs = 0 subset_blueprint.params.caption_tag_dropout_rate = 0.0 subset_blueprint.params.token_warmup_step = 0 - if subset_klass != DreamBoothSubset or not subset_blueprint.params.is_reg: - subsets.append(subset_klass(**asdict(subset_blueprint.params))) + + if subset_klass != DreamBoothSubset or (subset_klass == DreamBoothSubset and not subset_blueprint.params.is_reg): + subset = subset_klass(**asdict(subset_blueprint.params)) + subsets.append(subset) dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) From 459b12539b0ae1a92da98e38568ea0a61db1e89f Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Thu, 11 Apr 2024 01:52:14 +0800 Subject: [PATCH 24/76] Update config_util.py --- library/config_util.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 8f01e1f60..6f243aac3 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -512,8 +512,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu subset_blueprint.params.token_warmup_step = 0 if subset_klass != DreamBoothSubset or (subset_klass == DreamBoothSubset and not subset_blueprint.params.is_reg): - subset = subset_klass(**asdict(subset_blueprint.params)) - subsets.append(subset) + subsets.append(subset_klass(**asdict(subset_blueprint.params))) dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) From 89ad69b6a0d35791627cb58630a711befc6bb3b5 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Thu, 11 Apr 2024 08:42:31 +0800 Subject: [PATCH 25/76] Update train_util.py --- library/train_util.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index b143e85a8..8bf6823bb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1511,17 +1511,6 @@ def load_dreambooth_dir(subset: DreamBoothSubset): logger.warning(f"not directory: {subset.image_dir}") return [], [] - img_paths = glob_images(subset.image_dir, "*") - if self.validation_split > 0.0: - img_paths = split_train_val(img_paths, self.is_train, self.validation_split, 
self.validation_seed)
-        logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
-
-        # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
-        captions = []
-        missing_captions = []
-        for img_path in img_paths:
-            cap_for_img = read_caption(img_path, subset.caption_extension)
-            if cap_for_img is None and subset.class_tokens is None:
         info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE)
         use_cached_info_for_subset = subset.cache_info
         if use_cached_info_for_subset:
@@ -1545,6 +1534,8 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
             # we may need to check image size and existence of image files, but it takes time, so user should check it before training
         else:
             img_paths = glob_images(subset.image_dir, "*")
+            if self.validation_split > 0.0:
+                img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed)
             sizes = [None] * len(img_paths)
             logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")

From fde8026c2d92fe4991927eed6fa1ff373e8d38d2 Mon Sep 17 00:00:00 2001
From: gesen2egee <79357052+gesen2egee@users.noreply.github.com>
Date: Thu, 11 Apr 2024 11:29:26 +0800
Subject: [PATCH 26/76] Update config_util.py

---
 library/config_util.py | 10 +---------
 1 file changed, 1 insertion(+), 9 deletions(-)

diff --git a/library/config_util.py b/library/config_util.py
index 6f243aac3..a1b02bd1e 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -636,19 +636,11 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
                     [Subset {j} of Dataset {i}]
                     image_dir: "{subset.image_dir}"
                     image_count: {subset.img_count}
-                    num_repeats: {subset.num_repeats}
                     shuffle_caption: {subset.shuffle_caption}
                     keep_tokens: {subset.keep_tokens}
                     keep_tokens_separator: {subset.keep_tokens_separator}
-                    caption_dropout_rate: {subset.caption_dropout_rate}
-                    caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
-                    caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
                     caption_prefix: {subset.caption_prefix}
                     caption_suffix: {subset.caption_suffix}
-                    color_aug: {subset.color_aug}
-                    flip_aug: {subset.flip_aug}
-                    face_crop_aug_range: {subset.face_crop_aug_range}
-                    random_crop: {subset.random_crop}
                     token_warmup_min: {subset.token_warmup_min},
                     token_warmup_step: {subset.token_warmup_step},
                     """
@@ -688,7 +680,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
         dataset.set_seed(seed)
 
     for i, dataset in enumerate(val_datasets):
-        print(f"[Validation Dataset {i}]")
+        logger.info(f"[Validation Dataset {i}]")
         dataset.make_buckets()
         dataset.set_seed(seed)
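The load_dreambooth_dir changes above route image paths through split_train_val(img_paths, is_train, validation_split, validation_seed); the helper's definition sits elsewhere in train_util.py and is not part of this excerpt. A plausible minimal reconstruction, inferred from the call signature alone and therefore an assumption rather than the repository's actual code:

    import random
    from typing import List, Optional

    def split_train_val(paths: List[str], is_train: bool,
                        validation_split: float, validation_seed: Optional[int]) -> List[str]:
        # The shuffle must be deterministic in validation_seed: the training
        # dataset and the validation dataset are constructed independently,
        # so only the shared seed keeps the two subsets disjoint.
        paths = sorted(paths)
        random.Random(validation_seed).shuffle(paths)
        n_train = int(len(paths) * (1.0 - validation_split))
        return paths[:n_train] if is_train else paths[n_train:]

Whatever the exact implementation, determinism is the property the surrounding code relies on.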
From 31507b9901d1d9ab65ba79ebd747b7f35c7e0fc1 Mon Sep 17 00:00:00 2001
From: gesen2egee <79357052+gesen2egee@users.noreply.github.com>
Date: Fri, 2 Aug 2024 13:15:21 +0800
Subject: [PATCH 27/76] Remove unnecessary is_train changes and use
 apply_debiased_estimation to calculate the validation loss. This balances the
 influence of different timesteps on the reported validation loss (without
 affecting actual training results).

---
 train_network.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/train_network.py b/train_network.py
index 2a3a44824..4a5940cd5 100644
--- a/train_network.py
+++ b/train_network.py
@@ -135,7 +135,7 @@ def all_reduce_network(self, accelerator, network):
     def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
         train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
 
-    def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
+    def process_val_batch(self, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
         total_loss = 0.0
         timesteps_list = [10, 350, 500, 650, 990]
 
@@ -153,7 +153,7 @@ def process_val_batch(self, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
             latents = latents * self.vae_scale_factor
         b_size = latents.shape[0]
 
-        with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
+        with torch.set_grad_enabled(False), accelerator.autocast():
             # Get the text embedding for conditioning
             if args.weighted_captions:
                 text_encoder_conds = get_weighted_text_embeddings(
@@ -173,7 +173,7 @@ def process_val_batch(self, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
             # with noise offset and/or multires noise if specified
 
             for fixed_timesteps in timesteps_list:
-                with torch.set_grad_enabled(is_train), accelerator.autocast():
+                with torch.set_grad_enabled(False), accelerator.autocast():
                     noise = torch.randn_like(latents, device=latents.device)
                     b_size = latents.shape[0]
                     timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device)
@@ -189,6 +189,7 @@ def process_val_batch(self, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
 
                     loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                     loss = loss.mean([1, 2, 3])
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
                     loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
                     total_loss += loss
 
@@ -885,8 +886,7 @@ def remove_model(old_ckpt_name):
             for step, batch in enumerate(train_dataloader):
                 current_step.value = global_step
                 with accelerator.accumulate(training_model):
-                    on_step_start(text_encoder, unet)
-                    is_train = True
+                    on_step_start(text_encoder, unet)
                     if "latents" in batch and batch["latents"] is not None:
                         latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
                     else:
@@ -911,7 +911,7 @@ def remove_model(old_ckpt_name):
                         # print(f"set multiplier: {multipliers}")
                         accelerator.unwrap_model(network).set_multiplier(multipliers)
 
-                    with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
+                    with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
                         # Get the text embedding for conditioning
                         if args.weighted_captions:
                             text_encoder_conds = get_weighted_text_embeddings(
@@ -941,7 +941,7 @@ def remove_model(old_ckpt_name):
                                 t.requires_grad_(True)
 
                     # Predict the noise residual
-                    with torch.set_grad_enabled(is_train), accelerator.autocast():
+                    with accelerator.autocast():
                         noise_pred = self.call_unet(
                             args,
                             accelerator,
@@ -1040,10 +1040,9 @@ def remove_model(old_ckpt_name):
                         total_loss = 0.0
                         with torch.no_grad():
                             validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
-                            for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
-                                is_train = False
+                            for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
                                 batch = next(cyclic_val_dataloader)
-                                loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
+                                loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
                                 total_loss += loss.detach().item()
                             current_loss = total_loss / validation_steps
                             val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss)
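apply_debiased_estimation comes from the repository's custom loss functions and reweights the per-sample MSE so that low-noise timesteps, whose raw error is naturally tiny, do not drown out the informative high-noise ones when the five fixed timesteps are averaged. A rough sketch of the underlying 1/sqrt(SNR) weighting, shown for illustration and not as the repository's exact implementation:

    import torch

    def debiased_weight(timesteps: torch.Tensor, alphas_cumprod: torch.Tensor,
                        max_snr: float = 1000.0) -> torch.Tensor:
        # SNR(t) = alpha_bar_t / (1 - alpha_bar_t); scaling the loss by
        # 1 / sqrt(SNR) roughly equalizes each timestep's contribution.
        alpha_bar = alphas_cumprod[timesteps]
        snr = alpha_bar / (1.0 - alpha_bar)
        return 1.0 / torch.sqrt(torch.clamp(snr, max=max_snr))

    # usage sketch:
    # loss = loss * debiased_weight(timesteps, noise_scheduler.alphas_cumprod.to(loss.device))

Because the weighting is a deterministic function of the timestep, it only changes how the validation numbers are scaled and aggregated, which is exactly the "without affecting actual training results" caveat in the commit message.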
From 1db495127f25c1b17694780f635a4760b4e345d0 Mon Sep 17 00:00:00 2001
From: gesen2egee <79357052+gesen2egee@users.noreply.github.com>
Date: Sun, 4 Aug 2024 14:53:46 +0800
Subject: [PATCH 28/76] Update train_db.py

---
 train_db.py | 132 +++++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 126 insertions(+), 6 deletions(-)

diff --git a/train_db.py b/train_db.py
index 1de504ed8..9f8ec777c 100644
--- a/train_db.py
+++ b/train_db.py
@@ -2,7 +2,6 @@
 # XXX dropped option: fine_tune
 
 import argparse
-import itertools
 import math
 import os
 from multiprocessing import Value
@@ -41,11 +40,73 @@
 setup_logging()
 import logging
+import itertools
 
 logger = logging.getLogger(__name__)
 
 # perlin_noise,
-
+def process_val_batch(batch, tokenizer, text_encoder, unet, vae, noise_scheduler, weight_dtype, accelerator, args):
+    total_loss = 0.0
+    timesteps_list = [10, 350, 500, 650, 990]
+
+    with torch.no_grad():
+        # latentに変換
+        if args.cache_latents:
+            latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
+        else:
+            latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
+            latents = latents * 0.18215
+        b_size = latents.shape[0]
+
+        with accelerator.autocast():
+            # Get the text embedding for conditioning
+            if args.weighted_captions:
+                encoder_hidden_states = get_weighted_text_embeddings(
+                    tokenizer,
+                    text_encoder,
+                    batch["captions"],
+                    accelerator.device,
+                    args.max_token_length // 75 if args.max_token_length else 1,
+                    clip_skip=args.clip_skip,
+                )
+            else:
+                input_ids = batch["input_ids"].to(accelerator.device)
+                encoder_hidden_states = train_util.get_hidden_states(
+                    args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
+                )
+
+        for fixed_timesteps in timesteps_list:
+            with accelerator.autocast():
+                # Sample noise and add it to the latents at one fixed timestep per pass
+                noise = torch.randn_like(latents, device=latents.device)
+                timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device)
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                # Predict the noise residual
+                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                if args.v_parameterization:
+                    # v-parameterization training
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    target = noise
+
+                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                if args.masked_loss:
+                    loss = apply_masked_loss(loss, batch)
+                loss = loss.mean([1, 2, 3])
+                loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
+                total_loss += loss
+
+    average_loss = total_loss / len(timesteps_list)
+    return average_loss
 
 def train(args):
     train_util.verify_training_args(args)
@@ -81,9 +142,10 @@ def train(args):
             }
 
         blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
-        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+        train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
     else:
         train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
+        val_dataset_group = None
 
     current_epoch = Value("i", 0)
     current_step = Value("i", 0)
@@ -148,6 +210,9 @@ def train(args):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
+        if val_dataset_group is not None:
+            print("Cache validation latents...")
+            val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         clean_memory_on_device(accelerator.device)
 
     accelerator.wait_for_everyone()
@@ -195,6 +260,15 @@ def train(args):
         num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
     )
+    val_dataloader = torch.utils.data.DataLoader(
+        val_dataset_group if val_dataset_group is not None else [],
+        shuffle=False,
+        batch_size=1,
+        collate_fn=collator,
+        num_workers=n_workers,
+        persistent_workers=args.persistent_data_loader_workers,
+    )
+    cyclic_val_dataloader = itertools.cycle(val_dataloader)
 
     # 学習ステップ数を計算する
     if args.max_train_epochs is not None:
@@ -296,6 +370,8 @@ def train(args):
     train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
 
     loss_recorder = train_util.LossRecorder()
+    val_loss_recorder = train_util.LossRecorder()
+
     for epoch in range(num_train_epochs):
        accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
@@ -427,12 +503,33 @@ def train(args):
                 avr_loss: float = loss_recorder.moving_average
                 logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
                 progress_bar.set_postfix(**logs)
 
+            if len(val_dataloader) > 0:
+                if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps:
+                    accelerator.print("Validating バリデーション処理...")
+                    total_loss = 0.0
+                    with torch.no_grad():
+                        validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
+                        for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
+                            batch = next(cyclic_val_dataloader)
+                            loss = process_val_batch(batch, tokenizer, text_encoder, unet, vae, noise_scheduler, weight_dtype, accelerator, args)
+                            total_loss += loss.detach().item()
+                        current_loss = total_loss / validation_steps
+                        val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss)
+
+                    if args.logging_dir is not None:
+                        logs = {"loss/current_val_loss": current_loss}
+                        accelerator.log(logs, step=global_step)
+                        avr_loss: float = val_loss_recorder.moving_average
+                        logs = {"loss/average_val_loss": avr_loss}
+                        accelerator.log(logs, step=global_step)
+
+            if global_step >= args.max_train_steps:
+                break
 
         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_recorder.moving_average}
+            logs = {"loss/epoch_average": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)
 
     accelerator.wait_for_everyone()
@@ -515,7 +612,30 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
     )
-
+    parser.add_argument(
+        "--validation_seed",
+        type=int,
+        default=None,
+        help="Random seed used for the train / validation split"
+    )
+    parser.add_argument(
+        "--validation_split",
+        type=float,
+        default=0.0,
+        help="Fraction of the training dataset to hold out as validation images"
+    )
+    parser.add_argument(
+        "--validation_every_n_step",
+        type=int,
+        default=None,
+        help="Run validation every N training steps. By default, validation runs once per training epoch"
+    )
+    parser.add_argument(
+        "--max_validation_steps",
+        type=int,
+        default=None,
+        help="Maximum number of validation batches per validation run. By default, the entire validation dataset is used"
+    )
     return parser
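The compound trigger condition introduced above packs three cases into one expression. Factored out for readability, the logic reads as follows; this is a purely illustrative helper, and the name should_validate is invented here rather than taken from the patch:

    def should_validate(args, global_step: int, step: int, steps_per_epoch: int) -> bool:
        if args.validation_every_n_step is not None:
            # A fixed interval was requested on the command line.
            return global_step % args.validation_every_n_step == 0 or global_step >= args.max_train_steps
        # Otherwise validate at the end of each epoch, and always on the final step.
        return step == steps_per_epoch - 1 or global_step >= args.max_train_steps

itertools.cycle around val_dataloader complements this: successive validation runs keep walking through the validation set instead of re-reading the same leading batches, and next() never raises StopIteration.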
From 68162172ebf9afa21ad526fc833fcc04f74aeb5f Mon Sep 17 00:00:00 2001
From: gesen2egee <79357052+gesen2egee@users.noreply.github.com>
Date: Sun, 4 Aug 2024 15:03:56 +0800
Subject: [PATCH 29/76] Update train_db.py

---
 train_db.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/train_db.py b/train_db.py
index 9f8ec777c..e98434dba 100644
--- a/train_db.py
+++ b/train_db.py
@@ -209,10 +209,10 @@ def train(args):
         vae.eval()
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
-        vae.to("cpu")
         if val_dataset_group is not None:
             print("Cache validation latents...")
-            val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
+            val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
+        vae.to("cpu")
         clean_memory_on_device(accelerator.device)
 
     accelerator.wait_for_everyone()

From 96eb74f0cba3253ba29c8e87d7479c355916cca5 Mon Sep 17 00:00:00 2001
From: gesen2egee <79357052+gesen2egee@users.noreply.github.com>
Date: Sun, 4 Aug 2024 15:06:05 +0800
Subject: [PATCH 30/76] Update train_db.py

---
 train_db.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/train_db.py b/train_db.py
index e98434dba..80fdff3e7 100644
--- a/train_db.py
+++ b/train_db.py
@@ -210,8 +210,8 @@ def train(args):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         if val_dataset_group is not None:
-            print("Cache validation latents...")
-            val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
+            print("Cache validation latents...")
+            val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")

From b9bdd101296b8dc3c60b25e31d04d39b57eaee71 Mon Sep 17 00:00:00 2001
From: gesen2egee <79357052+gesen2egee@users.noreply.github.com>
Date: Sun, 4 Aug 2024 15:11:26 +0800
Subject: [PATCH 31/76] Update train_network.py

---
 train_network.py | 38 +++++++++++++++++++-------------------
 1 file changed, 19 insertions(+), 19 deletions(-)

diff --git a/train_network.py b/train_network.py
index 4a5940cd5..d7b24dae9 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1034,25
+1034,25 @@ def remove_model(old_ckpt_name): logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break From 3d68754defde57b10f96d9c934dd78bf25c39235 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 15:15:42 +0800 Subject: [PATCH 32/76] Update train_db.py --- train_db.py | 38 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/train_db.py b/train_db.py index 80fdff3e7..800a157bf 100644 --- a/train_db.py +++ b/train_db.py @@ -503,28 +503,26 @@ def train(args): avr_loss: float = loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if len(val_dataloader) > 0: if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = 
min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) - - + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break From a593e837f36b6299101dc85a367c0986501ecc0a Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 15:17:30 +0800 Subject: [PATCH 33/76] Update train_network.py --- train_network.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/train_network.py b/train_network.py index d7b24dae9..7d9134638 100644 --- a/train_network.py +++ b/train_network.py @@ -1034,26 +1034,26 @@ def remove_model(old_ckpt_name): logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, 
step=global_step) - + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break From f6dbf7c419bbcf2e51c82a6bffa8d30cad2e3512 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 15:18:53 +0800 Subject: [PATCH 34/76] Update train_network.py --- train_network.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/train_network.py b/train_network.py index 7d9134638..fa6407eef 100644 --- a/train_network.py +++ b/train_network.py @@ -1034,26 +1034,26 @@ def remove_model(old_ckpt_name): logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + 
current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) - + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break From aa850aa531b0e396b6f2fbd68cd1e6f1319d1d0b Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 17:34:20 +0800 Subject: [PATCH 35/76] Update train_network.py --- train_network.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/train_network.py b/train_network.py index fa6407eef..938e41938 100644 --- a/train_network.py +++ b/train_network.py @@ -1034,25 +1034,25 @@ def remove_model(old_ckpt_name): logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = 
val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break From cdb2d9c516fbffe0faa9788b8174e5d418fb766b Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 17:36:34 +0800 Subject: [PATCH 36/76] Update train_network.py --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 938e41938..e10c17c0c 100644 --- a/train_network.py +++ b/train_network.py @@ -192,7 +192,7 @@ def process_val_batch(self, batch, tokenizers, text_encoders, unet, vae, noise_s loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし total_loss += loss - + average_loss = total_loss / len(timesteps_list) return average_loss From 3028027e074c891f33d45fff27068b490a408329 Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Fri, 4 Oct 2024 16:41:41 +0800 Subject: [PATCH 37/76] Update train_network.py --- train_network.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/train_network.py b/train_network.py index e10c17c0c..c0239a6da 100644 --- a/train_network.py +++ b/train_network.py @@ -1034,26 +1034,26 @@ def remove_model(old_ckpt_name): logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) - + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) 
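#                         Note: the `next(cyclic_val_dataloader)` call above relies on the
#                         validation loader being wrapped in an endless iterator; this series
#                         builds it with `itertools.cycle(val_dataloader)` in train_network.py.
#                         A minimal sketch of the pattern, assuming an ordinary PyTorch
#                         DataLoader named `val_dataloader`:
#
#                             import itertools
#                             cyclic_val_dataloader = itertools.cycle(val_dataloader)
#                             batch = next(cyclic_val_dataloader)  # wraps to the first batch instead of raising StopIteration
#
#                         itertools.cycle caches the items from its first pass and replays
#                         them, which fits the shuffle=False validation loader used here
#                         (batch order is deterministic anyway) but does keep one full pass
#                         of batches in memory.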
+ + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break From dece2c388f1c39e7baca201b4bf4e61d9f67a219 Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Fri, 4 Oct 2024 16:43:07 +0800 Subject: [PATCH 38/76] Update train_db.py --- train_db.py | 164 ++++++++++++++++++++++++++-------------------------- 1 file changed, 82 insertions(+), 82 deletions(-) diff --git a/train_db.py b/train_db.py index 800a157bf..2c17e521f 100644 --- a/train_db.py +++ b/train_db.py @@ -46,67 +46,67 @@ # perlin_noise, def process_val_batch(*training_models, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args): - total_loss = 0.0 - timesteps_list = [10, 350, 500, 650, 990] - - with accelerator.accumulate(*training_models): - with torch.no_grad(): - # latentに変換 - if cache_latents: - latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) - else: - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - with torch.set_grad_enabled(False), accelerator.autocast(): - if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - - for fixed_timesteps in timesteps_list: - with torch.set_grad_enabled(False), accelerator.autocast(): - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise = torch.randn_like(latents, device=latents.device) - b_size = latents.shape[0] - timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - if args.masked_loss: - loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - total_loss += loss - - average_loss = total_loss / len(timesteps_list) - return average_loss + total_loss = 0.0 + timesteps_list = [10, 350, 500, 650, 990] + + with accelerator.accumulate(*training_models): + with torch.no_grad(): + # latentに変換 + if cache_latents: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = 
latents.shape[0] + + with torch.set_grad_enabled(False), accelerator.autocast(): + if args.weighted_captions: + encoder_hidden_states = get_weighted_text_embeddings( + tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype + ) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + + for fixed_timesteps in timesteps_list: + with torch.set_grad_enabled(False), accelerator.autocast(): + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise = torch.randn_like(latents, device=latents.device) + b_size = latents.shape[0] + timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + if args.masked_loss: + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + total_loss += loss + + average_loss = total_loss / len(timesteps_list) + return average_loss def train(args): train_util.verify_training_args(args) @@ -210,8 +210,8 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) if val_dataset_group is not None: - print("Cache validation latents...") - val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + print("Cache validation latents...") + val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -503,25 +503,25 @@ def train(args): avr_loss: float = loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - 
val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) + if len(val_dataloader) > 0: + if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: + accelerator.print("Validating バリデーション処理...") + total_loss = 0.0 + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc='Validation Steps'): + batch = next(cyclic_val_dataloader) + loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + total_loss += loss.detach().item() + current_loss = total_loss / validation_steps + val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/average_val_loss": avr_loss} + accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break From 05bb9183fae18c62a1730fe5060f80c0b99a21f3 Mon Sep 17 00:00:00 2001 From: Hina Chen Date: Fri, 27 Dec 2024 16:47:59 +0800 Subject: [PATCH 39/76] Add Validation loss for LoRA training --- library/config_util.py | 78 +++++++++++++++++++++++- library/train_util.py | 54 ++++++++++++++++- train_network.py | 131 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 257 insertions(+), 6 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 12d0be173..a57cd36f0 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -73,6 +73,8 @@ class BaseSubsetParams: token_warmup_min: int = 1 token_warmup_step: float = 0 custom_attributes: Optional[Dict[str, Any]] = None + validation_seed: int = 0 + validation_split: float = 0.0 @dataclass @@ -102,6 +104,8 @@ class BaseDatasetParams: resolution: Optional[Tuple[int, int]] = None network_multiplier: float = 1.0 debug_dataset: bool = False + validation_seed: Optional[int] = None + validation_split: float = 0.0 @dataclass @@ -478,9 +482,27 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params)) datasets.append(dataset) + val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.params.validation_split <= 0.0: + continue + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = 
[subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) + val_datasets.append(dataset) + # print info info = "" for i, dataset in enumerate(datasets): @@ -566,6 +588,50 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu logger.info(f"{info}") + if len(val_datasets) > 0: + info = "" + + for i, dataset in enumerate(val_datasets): + info += dedent( + f"""\ + [Validation Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + network_multiplier: {dataset.network_multiplier} + """ + ) + + if dataset.enable_bucket: + info += indent( + dedent( + f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n""" + ), + " ", + ) + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): + info += indent( + dedent( + f"""\ + [Subset {j} of Validation Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + """ + ), + " ", + ) + + logger.info(f"{info}") + # make buckets first because it determines the length of dataset # and set the same seed for all datasets seed = random.randint(0, 2**31) # actual seed is seed + epoch_no @@ -574,7 +640,15 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset.make_buckets() dataset.set_seed(seed) - return DatasetGroup(datasets) + for i, dataset in enumerate(val_datasets): + logger.info(f"[Validation Dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) + + return ( + DatasetGroup(datasets), + DatasetGroup(val_datasets) if val_datasets else None + ) def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): diff --git a/library/train_util.py b/library/train_util.py index 72b5b24db..a3fa98e99 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -145,6 +145,17 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" +def split_train_val(paths: List[str], validation_split: float, validation_seed: int) -> List[str]: + if validation_seed is not None: + print(f"Using validation seed: {validation_seed}") + prevstate = random.getstate() + random.seed(validation_seed) + random.shuffle(paths) + random.setstate(prevstate) + else: + random.shuffle(paths) + + return paths[len(paths) - round(len(paths) * validation_split):] class ImageInfo: def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: @@ -397,6 +408,8 @@ def __init__( token_warmup_min: int, token_warmup_step: Union[float, int], custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: self.image_dir = image_dir self.alpha_mask = alpha_mask if alpha_mask is not None else False @@ -424,6 +437,9 @@ def __init__( self.img_count = 0 + self.validation_seed = validation_seed + self.validation_split = validation_split + class DreamBoothSubset(BaseSubset): def __init__( @@ -453,6 +469,8 @@ def __init__( token_warmup_min, token_warmup_step, custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: 
Optional[float] = 0.0, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -478,6 +496,8 @@ def __init__( token_warmup_min, token_warmup_step, custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.is_reg = is_reg @@ -518,6 +538,8 @@ def __init__( token_warmup_min, token_warmup_step, custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -543,6 +565,8 @@ def __init__( token_warmup_min, token_warmup_step, custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.metadata_file = metadata_file @@ -579,6 +603,8 @@ def __init__( token_warmup_min, token_warmup_step, custom_attributes: Optional[Dict[str, Any]] = None, + validation_seed: Optional[int] = None, + validation_split: Optional[float] = 0.0, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -604,6 +630,8 @@ def __init__( token_warmup_min, token_warmup_step, custom_attributes=custom_attributes, + validation_seed=validation_seed, + validation_split=validation_split, ) self.conditioning_data_dir = conditioning_data_dir @@ -1799,6 +1827,9 @@ def __init__( bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset: bool, + is_train: bool, + validation_seed: int, + validation_split: float, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -1808,6 +1839,9 @@ def __init__( self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight self.latents_cache = None + self.is_train = is_train + self.validation_seed = validation_seed + self.validation_split = validation_split self.enable_bucket = enable_bucket if self.enable_bucket: @@ -1992,6 +2026,9 @@ def load_dreambooth_dir(subset: DreamBoothSubset): ) continue + if self.is_train == False: + img_paths = split_train_val(img_paths, self.validation_split, self.validation_seed) + if subset.is_reg: num_reg_images += subset.num_repeats * len(img_paths) else: @@ -2009,7 +2046,11 @@ def load_dreambooth_dir(subset: DreamBoothSubset): subset.img_count = len(img_paths) self.subsets.append(subset) - logger.info(f"{num_train_images} train images with repeating.") + if self.is_train: + logger.info(f"{num_train_images} train images with repeating.") + else: + logger.info(f"{num_train_images} validation images with repeating.") + self.num_train_images = num_train_images logger.info(f"{num_reg_images} reg images.") @@ -2050,6 +2091,9 @@ def __init__( bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset: bool, + is_train: bool, + validation_seed: int, + validation_split: float, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2276,6 +2320,9 @@ def __init__( bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset: float, + is_train: bool, + validation_seed: int, + validation_split: float, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2324,6 +2371,9 @@ def __init__( bucket_no_upscale, 1.0, debug_dataset, + is_train, + validation_seed, + validation_split, ) # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) @@ -4887,7 +4937,7 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]: import schedulefree as sf except ImportError: raise ImportError("No schedulefree / 
schedulefreeがインストールされていないようです") - + if optimizer_type == "RAdamScheduleFree".lower(): optimizer_class = sf.RAdamScheduleFree logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}") diff --git a/train_network.py b/train_network.py index 5e82b307c..776feaf76 100644 --- a/train_network.py +++ b/train_network.py @@ -9,6 +9,7 @@ from multiprocessing import Value from typing import Any, List import toml +import itertools from tqdm import tqdm @@ -114,7 +115,7 @@ def generate_step_logs( ) if ( args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None - ): + ): logs[f"lr/d*lr/group{i}"] = ( optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] ) @@ -373,10 +374,11 @@ def train(self, args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -398,6 +400,11 @@ def train(self, args): train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + if val_dataset_group is not None: + assert ( + val_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + self.assert_extra_args(args, train_dataset_group) # may change some args # acceleratorを準備する @@ -444,6 +451,8 @@ def train(self, args): vae.eval() train_dataset_group.new_cache_latents(vae, accelerator) + if val_dataset_group is not None: + val_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -459,6 +468,8 @@ def train(self, args): if text_encoder_outputs_caching_strategy is not None: strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy) self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype) + if val_dataset_group is not None: + self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype) # prepare network net_kwargs = {} @@ -567,6 +578,8 @@ def train(self, args): # strategies are set here because they cannot be referenced in another process. 
Copy them with the dataset # some strategies can be None train_dataset_group.set_current_strategies() + if val_dataset_group is not None: + val_dataset_group.set_current_strategies() # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -580,6 +593,17 @@ def train(self, args): persistent_workers=args.persistent_data_loader_workers, ) + val_dataloader = torch.utils.data.DataLoader( + val_dataset_group if val_dataset_group is not None else [], + batch_size=1, + shuffle=False, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + cyclic_val_dataloader = itertools.cycle(val_dataloader) + # 学習ステップ数を計算する if args.max_train_epochs is not None: args.max_train_steps = args.max_train_epochs * math.ceil( @@ -592,6 +616,10 @@ def train(self, args): # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) + # Not for sure here. + # if val_dataset_group is not None: + # val_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -1064,7 +1092,11 @@ def load_model_hook(models, input_dir): ) loss_recorder = train_util.LossRecorder() + # val_loss_recorder = train_util.LossRecorder() + del train_dataset_group + if val_dataset_group is not None: + del val_dataset_group # callback for step start if hasattr(accelerator.unwrap_model(network), "on_step_start"): @@ -1308,6 +1340,77 @@ def remove_model(old_ckpt_name): ) accelerator.log(logs, step=global_step) + if len(val_dataloader) > 0: + if ((args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps): + accelerator.print("\nValidating バリデーション処理...") + + total_loss = 0.0 + + with torch.no_grad(): + validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + for val_step in tqdm(range(validation_steps), desc="Validation Steps バリデーションテップ"): + batch = next(cyclic_val_dataloader) + + timesteps_list = [10, 350, 500, 650, 990] + + val_loss = 0.0 + + for fixed_timesteps in timesteps_list: + with torch.set_grad_enabled(False), accelerator.autocast(): + noise = torch.randn_like(latents, device=latents.device) + b_size = latents.shape[0] + + timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device="cpu") + timesteps = timesteps.long().to(latents.device) + + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + with accelerator.autocast(): + noise_pred = self.call_unet( + args, + accelerator, + unet, + noisy_latents.requires_grad_(False), + timesteps, + text_encoder_conds, + batch, + weight_dtype, + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + # min snr gamma, scale v pred loss like noise pred, 
v pred like loss, debiased estimation etc. + loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + val_loss += loss / len(timesteps_list) + + total_loss += val_loss.detach().item() + + current_val_loss = total_loss / validation_steps + # val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_val_loss) + + if len(accelerator.trackers) > 0: + logs = {"loss/current_val_loss": current_val_loss} + accelerator.log(logs, step=global_step) + + # avr_loss: float = val_loss_recorder.moving_average + # logs = {"loss/average_val_loss": avr_loss} + # accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break @@ -1496,6 +1599,30 @@ def setup_parser() -> argparse.ArgumentParser: help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch." + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする", ) + parser.add_argument( + "--validation_seed", + type=int, + default=None, + help="Validation seed / 検証シード" + ) + parser.add_argument( + "--validation_split", + type=float, + default=0.0, + help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合" + ) + parser.add_argument( + "--validation_every_n_step", + type=int, + default=None, + help="Number of train steps for counting validation loss. By default, validation per train epoch is performed / 学習エポックごとに検証を行う場合はNoneを指定する" + ) + parser.add_argument( + "--max_validation_steps", + type=int, + default=None, + help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset / 検証データセット全体を検証する場合はNoneを指定する" + ) # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") From 62164e57925125ed6268983ffa441f1ffecc0e6d Mon Sep 17 00:00:00 2001 From: Hina Chen Date: Fri, 27 Dec 2024 17:28:05 +0800 Subject: [PATCH 40/76] Change val loss calculate method --- train_network.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/train_network.py b/train_network.py index 776feaf76..5fd1b212f 100644 --- a/train_network.py +++ b/train_network.py @@ -1383,16 +1383,20 @@ def remove_model(old_ckpt_name): else: target = noise - huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) - loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) - if weighting is not None: - loss = loss * weighting - if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): - loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) + # huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + # loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + # if weighting is not None: + # loss = loss * weighting + # if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + # loss = apply_masked_loss(loss, batch) + # loss = loss.mean([1, 2, 3]) # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc. 
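#                                     Note on the change below: validation switches from
#                                     self.post_process_loss and the configurable training-loss
#                                     pipeline (huber/conditional loss, per-sample weighting,
#                                     masked loss, min-SNR, etc.) to plain MSE plus debiased
#                                     estimation, presumably so the logged validation numbers stay
#                                     comparable across runs that train with different loss options.
#                                     A rough sketch of the debiased estimation step (the real
#                                     apply_debiased_estimation lives in
#                                     library/custom_train_functions.py; compute_snr here is an
#                                     assumed helper and the exact clamping details may differ):
#
#                                         snr_t = compute_snr(noise_scheduler, timesteps)  # per-sample SNR at each sampled timestep
#                                         weight = 1.0 / torch.sqrt(snr_t)                 # epsilon-prediction form; v-prediction uses 1 / (snr_t + 1)
#                                         loss = loss * weight                             # down-weights high-SNR (low-noise) timesteps
#
#                                     i.e. the per-sample MSE is rescaled before the final mean
#                                     over the batch.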
- loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + # loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし From 64bd5317dc9cb39d69ab7728f36b03157c9b341f Mon Sep 17 00:00:00 2001 From: Hina Chen Date: Sat, 28 Dec 2024 11:42:15 +0800 Subject: [PATCH 41/76] Split val latents/batch and pick up val latents shape size which equal to training batch. --- train_network.py | 45 +++++++++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/train_network.py b/train_network.py index 5fd1b212f..6bce9e964 100644 --- a/train_network.py +++ b/train_network.py @@ -1349,7 +1349,27 @@ def remove_model(old_ckpt_name): with torch.no_grad(): validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) for val_step in tqdm(range(validation_steps), desc="Validation Steps バリデーションテップ"): - batch = next(cyclic_val_dataloader) + + while True: + val_batch = next(cyclic_val_dataloader) + + if "latents" in val_batch and val_batch["latents"] is not None: + val_latents = val_batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + with torch.no_grad(): + # latentに変換 + val_latents = self.encode_images_to_latents(args, accelerator, vae, val_batch["images"].to(vae_dtype)) + val_latents = val_latents.to(dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(val_latents)): + accelerator.print("NaN found in validation latents, replacing with zeros") + val_latents = torch.nan_to_num(val_latents, 0, out=val_latents) + + val_latents = self.shift_scale_latents(args, val_latents) + + if val_latents.shape == latents.shape: + break timesteps_list = [10, 350, 500, 650, 990] @@ -1357,13 +1377,13 @@ def remove_model(old_ckpt_name): for fixed_timesteps in timesteps_list: with torch.set_grad_enabled(False), accelerator.autocast(): - noise = torch.randn_like(latents, device=latents.device) - b_size = latents.shape[0] + noise = torch.randn_like(val_latents, device=val_latents.device) + b_size = val_latents.shape[0] timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device="cpu") - timesteps = timesteps.long().to(latents.device) + timesteps = timesteps.long().to(val_latents.device) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(val_latents, noise, timesteps) with accelerator.autocast(): noise_pred = self.call_unet( @@ -1373,27 +1393,16 @@ def remove_model(old_ckpt_name): noisy_latents.requires_grad_(False), timesteps, text_encoder_conds, - batch, + val_batch, weight_dtype, ) if args.v_parameterization: # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) + target = noise_scheduler.get_velocity(val_latents, noise, timesteps) else: target = noise - # huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) - # loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) - # if weighting is not None: - # loss = loss * weighting - # if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): - # loss = apply_masked_loss(loss, batch) - # loss = loss.mean([1, 2, 3]) - - # 
min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc. - # loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) From cb89e0284e1a25b41401861107159e6b943ee387 Mon Sep 17 00:00:00 2001 From: Hina Chen Date: Sat, 28 Dec 2024 11:57:04 +0800 Subject: [PATCH 42/76] Change val latent loss compare --- train_network.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/train_network.py b/train_network.py index 6bce9e964..7276d5dc0 100644 --- a/train_network.py +++ b/train_network.py @@ -1350,6 +1350,8 @@ def remove_model(old_ckpt_name): validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) for val_step in tqdm(range(validation_steps), desc="Validation Steps バリデーションテップ"): + val_latents = None + while True: val_batch = next(cyclic_val_dataloader) @@ -1371,19 +1373,22 @@ def remove_model(old_ckpt_name): if val_latents.shape == latents.shape: break + if val_latents is not None: + del val_latents + timesteps_list = [10, 350, 500, 650, 990] val_loss = 0.0 for fixed_timesteps in timesteps_list: with torch.set_grad_enabled(False), accelerator.autocast(): - noise = torch.randn_like(val_latents, device=val_latents.device) - b_size = val_latents.shape[0] + noise = torch.randn_like(latents, device=latents.device) + b_size = latents.shape[0] timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device="cpu") - timesteps = timesteps.long().to(val_latents.device) + timesteps = timesteps.long().to(latents.device) - noisy_latents = noise_scheduler.add_noise(val_latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) with accelerator.autocast(): noise_pred = self.call_unet( @@ -1399,7 +1404,7 @@ def remove_model(old_ckpt_name): if args.v_parameterization: # v-parameterization training - target = noise_scheduler.get_velocity(val_latents, noise, timesteps) + target = noise_scheduler.get_velocity(latents, noise, timesteps) else: target = noise From 874353296304c753b452511a412472f8a3e4ba09 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 10 Mar 2024 04:37:16 +0800 Subject: [PATCH 43/76] val --- library/config_util.py | 32 +++++++------ library/train_util.py | 20 ++++++-- train_network.py | 104 +++++++++++++++++++++++++++-------------- 3 files changed, 103 insertions(+), 53 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 1bf7ed955..cb2c5b68f 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -81,23 +81,24 @@ class ControlNetSubsetParams(BaseSubsetParams): @dataclass class BaseDatasetParams: - tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None - max_token_length: int = None - resolution: Optional[Tuple[int, int]] = None - debug_dataset: bool = False - validation_seed: Optional[int] = None - validation_split: float = 0.0 + tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None + max_token_length: int = None + resolution: Optional[Tuple[int, int]] = None + network_multiplier: float = 1.0 + debug_dataset: bool = False + validation_seed: Optional[int] = None + validation_split: float = 0.0 @dataclass class DreamBoothDatasetParams(BaseDatasetParams): - batch_size: int = 1 - 
enable_bucket: bool = False - min_bucket_reso: int = 256 - max_bucket_reso: int = 1024 - bucket_reso_steps: int = 64 - bucket_no_upscale: bool = False - prior_loss_weight: float = 1.0 - + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + prior_loss_weight: float = 1.0 + @dataclass class FineTuningDatasetParams(BaseDatasetParams): batch_size: int = 1 @@ -203,8 +204,9 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "max_bucket_reso": int, "min_bucket_reso": int, "validation_seed": int, - "validation_split": float, + "validation_split": float, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), + "network_multiplier": float, } # options handled by argparse but not handled by user config diff --git a/library/train_util.py b/library/train_util.py index 1979207b0..2364d62b3 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -122,6 +122,20 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" +def split_train_val(paths, is_train, validation_split, validation_seed): + if validation_seed is not None: + print(f"Using validation seed: {validation_seed}") + prevstate = random.getstate() + random.seed(validation_seed) + random.shuffle(paths) + random.setstate(prevstate) + else: + random.shuffle(paths) + + if is_train: + return paths[0:math.ceil(len(paths) * (1 - validation_split))] + else: + return paths[len(paths) - round(len(paths) * validation_split):] def split_train_val(paths, is_train, validation_split, validation_seed): if validation_seed is not None: @@ -1352,7 +1366,6 @@ def __init__( self.is_train = is_train self.validation_split = validation_split self.validation_seed = validation_seed - self.batch_size = batch_size self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight @@ -1405,10 +1418,9 @@ def load_dreambooth_dir(subset: DreamBoothSubset): return [], [] img_paths = glob_images(subset.image_dir, "*") - if self.validation_split > 0.0: - img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) - print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") + img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う captions = [] diff --git a/train_network.py b/train_network.py index edd3ff944..48885503f 100644 --- a/train_network.py +++ b/train_network.py @@ -130,7 +130,9 @@ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_cond def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) - def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True): + def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True, timesteps_list=None): + total_loss = 0.0 + with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) @@ -167,37 +169,40 @@ def process_batch(self, batch, 
is_train, tokenizers, text_encoders, unet, vae, n args, noise_scheduler, latents ) - # Predict the noise residual - with torch.set_grad_enabled(is_train), accelerator.autocast(): - noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype - ) - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise + # Use input timesteps_list or use described timesteps above + timesteps_list = timesteps_list or [timesteps] + for timesteps in timesteps_list: + # Predict the noise residual + with torch.set_grad_enabled(is_train), accelerator.autocast(): + noise_pred = self.call_unet( + args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + ) - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise - loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight - loss = loss * loss_weights + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight + loss = loss * loss_weights - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - return loss + total_loss += loss.mean() # 平均なのでbatch_sizeで割る必要なし + average_loss = total_loss / len(timesteps_list) + return average_loss def train(self, args): session_id = random.randint(0, 2**32) @@ -283,10 +288,10 @@ def train(self, args): train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" if val_dataset_group is not None: - assert ( - val_dataset_group.is_latent_cacheable() - ), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - + assert ( + val_dataset_group.is_latent_cacheable() + ), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + self.assert_extra_args(args, train_dataset_group) # acceleratorを準備する @@ -430,6 +435,15 @@ def train(self, args): num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) + + val_dataloader = torch.utils.data.DataLoader( + val_dataset_group if val_dataset_group is not None else [], + 
shuffle=False, + batch_size=1, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) val_dataloader = torch.utils.data.DataLoader( val_dataset_group if val_dataset_group is not None else [], @@ -798,7 +812,6 @@ def train(self, args): loss_recorder = train_util.LossRecorder() val_loss_recorder = train_util.LossRecorder() - del train_dataset_group # callback for step start @@ -848,7 +861,6 @@ def remove_model(old_ckpt_name): on_step_start(text_encoder, unet) is_train = True loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=train_text_encoder) - accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: params_to_clip = network.get_trainable_params() @@ -900,7 +912,25 @@ def remove_model(old_ckpt_name): if args.logging_dir is not None: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) accelerator.log(logs, step=global_step) - + + if global_step % 25 == 0: + if len(val_dataloader) > 0: + print("Validating バリデーション処理...") + + with torch.no_grad(): + val_dataloader_iter = iter(val_dataloader) + batch = next(val_dataloader_iter) + is_train = False + loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, timesteps_list=[10, 350, 500, 650, 990]) + + current_loss = loss.detach().item() + val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) + + if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/validation_current": current_loss} + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break @@ -912,7 +942,7 @@ def remove_model(old_ckpt_name): with torch.no_grad(): for val_step, batch in enumerate(val_dataloader): is_train = False - loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) + loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, timesteps_list=[10, 350, 500, 650, 990]) current_loss = loss.detach().item() val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) @@ -933,6 +963,12 @@ def remove_model(old_ckpt_name): logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) + if len(val_dataloader) > 0: + if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/validation_epoch_average": avr_loss} + accelerator.log(logs, step=epoch + 1) + accelerator.wait_for_everyone() # 指定エポックごとにモデルを保存 From 449c1c5c502375713e609ad9e00e747b4013063a Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 2 Jan 2025 15:59:20 -0500 Subject: [PATCH 44/76] Adding modified train_util and config_util --- library/config_util.py | 1 - library/train_util.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index cb2c5b68f..727e1a409 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -84,7 +84,6 @@ class BaseDatasetParams: tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None max_token_length: int = None resolution: Optional[Tuple[int, int]] = None - network_multiplier: float = 1.0 debug_dataset: bool = False validation_seed: 
Optional[int] = None validation_split: float = 0.0 diff --git a/library/train_util.py b/library/train_util.py index 2364d62b3..394337397 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1420,7 +1420,7 @@ def load_dreambooth_dir(subset: DreamBoothSubset): img_paths = glob_images(subset.image_dir, "*") if self.validation_split > 0.0: img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed) - logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") + print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う captions = [] From 7470173044ca5b700bc4723709bd9c012e2216f3 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 01:13:57 -0500 Subject: [PATCH 45/76] Remove defunct code for train_controlnet.py --- train_controlnet.py | 569 -------------------------------------------- 1 file changed, 569 deletions(-) diff --git a/train_controlnet.py b/train_controlnet.py index 09a911a00..365e35c8c 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -6,577 +6,8 @@ logger = logging.getLogger(__name__) -<<<<<<< HEAD -# TODO 他のスクリプトと共通化する -def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): - logs = { - "loss/current": current_loss, - "loss/average": avr_loss, - "lr": lr_scheduler.get_last_lr()[0], - } - - if args.optimizer_type.lower().startswith("DAdapt".lower()): - logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] - - return logs - - -def train(args): - # session_id = random.randint(0, 2**32) - # training_started_at = time.time() - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) - setup_logging(args, reset=True) - - cache_latents = args.cache_latents - use_user_config = args.dataset_config is not None - - if args.seed is None: - args.seed = random.randint(0, 2**32) - set_seed(args.seed) - - tokenizer = train_util.load_tokenizer(args) - - # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) - if use_user_config: - logger.info(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "conditioning_data_dir"] - if any(getattr(args, attr) is not None for attr in ignored): - logger.warning( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - user_config = { - "datasets": [ - { - "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( - args.train_data_dir, - args.conditioning_data_dir, - args.caption_extension, - ) - } - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group) - return - if len(train_dataset_group) == 0: - logger.error( - "No data found. 
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" - ) - return - - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - # acceleratorを準備する - logger.info("prepare accelerator") - accelerator = train_util.prepare_accelerator(args) - is_main_process = accelerator.is_main_process - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) - - # モデルを読み込む - text_encoder, vae, unet, _ = train_util.load_target_model( - args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=True - ) - - # DiffusersのControlNetが使用するデータを準備する - if args.v2: - unet.config = { - "act_fn": "silu", - "attention_head_dim": [5, 10, 20, 20], - "block_out_channels": [320, 640, 1280, 1280], - "center_input_sample": False, - "cross_attention_dim": 1024, - "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], - "downsample_padding": 1, - "dual_cross_attention": False, - "flip_sin_to_cos": True, - "freq_shift": 0, - "in_channels": 4, - "layers_per_block": 2, - "mid_block_scale_factor": 1, - "norm_eps": 1e-05, - "norm_num_groups": 32, - "num_class_embeds": None, - "only_cross_attention": False, - "out_channels": 4, - "sample_size": 96, - "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], - "use_linear_projection": True, - "upcast_attention": True, - "only_cross_attention": False, - "downsample_padding": 1, - "use_linear_projection": True, - "class_embed_type": None, - "num_class_embeds": None, - "resnet_time_scale_shift": "default", - "projection_class_embeddings_input_dim": None, - } - else: - unet.config = { - "act_fn": "silu", - "attention_head_dim": 8, - "block_out_channels": [320, 640, 1280, 1280], - "center_input_sample": False, - "cross_attention_dim": 768, - "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], - "downsample_padding": 1, - "flip_sin_to_cos": True, - "freq_shift": 0, - "in_channels": 4, - "layers_per_block": 2, - "mid_block_scale_factor": 1, - "norm_eps": 1e-05, - "norm_num_groups": 32, - "out_channels": 4, - "sample_size": 64, - "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], - "only_cross_attention": False, - "downsample_padding": 1, - "use_linear_projection": False, - "class_embed_type": None, - "num_class_embeds": None, - "upcast_attention": False, - "resnet_time_scale_shift": "default", - "projection_class_embeddings_input_dim": None, - } - unet.config = SimpleNamespace(**unet.config) - - controlnet = ControlNetModel.from_unet(unet) - - if args.controlnet_model_name_or_path: - filename = args.controlnet_model_name_or_path - if os.path.isfile(filename): - if os.path.splitext(filename)[1] == ".safetensors": - state_dict = load_file(filename) - else: - state_dict = torch.load(filename) - state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict) - controlnet.load_state_dict(state_dict) - elif os.path.isdir(filename): - controlnet = ControlNetModel.from_pretrained(filename) - - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - - # 学習を準備する - if cache_latents: - 
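The removed branch continues below with the usual freeze-encode-release pattern for latent caching. A minimal standalone sketch of that pattern, assuming a diffusers-style `vae` whose `encode()` returns a `latent_dist`, and using the SD1.x scale factor 0.18215 that appears in this script; the helper name is illustrative, not part of the patch:

import torch

def cache_latents_sketch(vae, images, device, scale=0.18215):
    # Freeze the VAE and encode every image once; training then reuses the latents.
    vae.to(device).requires_grad_(False)
    vae.eval()
    cached = []
    with torch.no_grad():
        for img in images:  # img: (C, H, W) tensor scaled to [-1, 1]
            latent = vae.encode(img.unsqueeze(0).to(device)).latent_dist.sample()
            cached.append((latent * scale).cpu())
    vae.to("cpu")  # release VRAM for the UNet once caching is done
    return cached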
vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents( - vae, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - ) - vae.to("cpu") - clean_memory_on_device(accelerator.device) - - accelerator.wait_for_everyone() - - if args.gradient_checkpointing: - controlnet.enable_gradient_checkpointing() - - # 学習に必要なクラスを準備する - accelerator.print("prepare optimizer, data loader etc.") - - trainable_params = controlnet.parameters() - - _, _, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 - n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps - ) - accelerator.print( - f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" - ) - - # データセット側にも学習ステップを送信 - train_dataset_group.set_max_train_steps(args.max_train_steps) - - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - - # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする - if args.full_fp16: - assert ( - args.mixed_precision == "fp16" - ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - accelerator.print("enable full fp16 training.") - controlnet.to(weight_dtype) - - # acceleratorがなんかよろしくやってくれるらしい - controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - controlnet, optimizer, train_dataloader, lr_scheduler - ) - - unet.requires_grad_(False) - text_encoder.requires_grad_(False) - unet.to(accelerator.device) - text_encoder.to(accelerator.device) - - # transform DDP after prepare - controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet - - controlnet.train() - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - # TODO: find a way to handle total batch size when there are multiple datasets - accelerator.print("running training / 学習開始") - accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print( - f" batch size per device / 
バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" - ) - # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - progress_bar = tqdm( - range(args.max_train_steps), - smoothing=0, - disable=not accelerator.is_local_main_process, - desc="steps", - ) - global_step = 0 - - noise_scheduler = DDPMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=1000, - clip_sample=False, - ) - if accelerator.is_main_process: - init_kwargs = {} - if args.wandb_run_name: - init_kwargs["wandb"] = {"name": args.wandb_run_name} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers( - "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs - ) - - loss_recorder = train_util.LossRecorder() - del train_dataset_group - - # function for saving/removing - def save_model(ckpt_name, model, force_sync_upload=False): - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - accelerator.print(f"\nsaving checkpoint: {ckpt_file}") - - state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) - - if save_dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v - - if os.path.splitext(ckpt_file)[1] == ".safetensors": - from safetensors.torch import save_file - - save_file(state_dict, ckpt_file) - else: - torch.save(state_dict, ckpt_file) - - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) - - def remove_model(old_ckpt_name): - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - accelerator.print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - # For --sample_at_first - train_util.sample_images( - accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet - ) - - # training loop - for epoch in range(num_train_epochs): - if is_main_process: - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 - - for step, batch in enumerate(train_dataloader): - current_step.value = global_step - with accelerator.accumulate(controlnet): - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) - elif args.multires_noise_iterations: - noise = pyramid_noise_like( - noise, - latents.device, - args.multires_noise_iterations, - 
args.multires_noise_discount, - ) - - # Sample a random timestep for each image - timesteps = train_util.get_timesteps(args, 0, noise_scheduler.config.num_train_timesteps, b_size) - huber_c = train_util.get_huber_c(args, noise_scheduler, timesteps.item(), latents.device) - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) - - with accelerator.autocast(): - down_block_res_samples, mid_block_res_sample = controlnet( - noisy_latents, - timesteps, - encoder_hidden_states=encoder_hidden_states, - controlnet_cond=controlnet_image, - return_dict=False, - ) - - # Predict the noise residual - noise_pred = unet( - noisy_latents, - timesteps, - encoder_hidden_states, - down_block_additional_residuals=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples], - mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), - ).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = controlnet.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - train_util.sample_images( - accelerator, - args, - None, - global_step, - accelerator.device, - vae, - tokenizer, - text_encoder, - unet, - controlnet=controlnet, - ) - - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model( - ckpt_name, - accelerator.unwrap_model(controlnet), - ) - - if args.save_state: - train_util.save_and_remove_state_stepwise(args, accelerator, global_step) - - remove_step_no = train_util.get_remove_step_no(args, global_step) - if remove_step_no is not None: - remove_ckpt_name = train_util.get_step_ckpt_name(args, "." 
+ args.save_model_as, remove_step_no) - remove_model(remove_ckpt_name) - - current_loss = loss.detach().item() - loss_recorder.add(epoch=epoch, step=step, loss=current_loss) - avr_loss: float = loss_recorder.moving_average - logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if args.logging_dir is not None: - logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) - accelerator.log(logs, step=global_step) - - if global_step >= args.max_train_steps: - break - - if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) - - accelerator.wait_for_everyone() - - # 指定エポックごとにモデルを保存 - if args.save_every_n_epochs is not None: - saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs - if is_main_process and saving: - ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) - save_model(ckpt_name, accelerator.unwrap_model(controlnet)) - - remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) - if remove_epoch_no is not None: - remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) - remove_model(remove_ckpt_name) - - if args.save_state: - train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - - train_util.sample_images( - accelerator, - args, - epoch + 1, - global_step, - accelerator.device, - vae, - tokenizer, - text_encoder, - unet, - controlnet=controlnet, - ) - - # end of epoch - if is_main_process: - controlnet = accelerator.unwrap_model(controlnet) - - accelerator.end_training() - - if is_main_process and (args.save_state or args.save_state_on_train_end): - train_util.save_state_on_train_end(args, accelerator) - - # del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく - - if is_main_process: - ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) - save_model(ckpt_name, controlnet, force_sync_upload=True) - - logger.info("model saved.") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - add_logging_arguments(parser) - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, False, True, True) - train_util.add_training_arguments(parser, False) - deepspeed_utils.add_deepspeed_arguments(parser) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser) - - parser.add_argument( - "--save_model_as", - type=str, - default="safetensors", - choices=[None, "ckpt", "pt", "safetensors"], - help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", - ) - parser.add_argument( - "--controlnet_model_name_or_path", - type=str, - default=None, - help="controlnet model name or path / controlnetのモデル名またはパス", - ) - parser.add_argument( - "--conditioning_data_dir", - type=str, - default=None, - help="conditioning data directory / 条件付けデータのディレクトリ", - ) - - return parser - -======= from library import train_util from train_control_net import setup_parser, train ->>>>>>> hina/feature/val-loss if __name__ == "__main__": logger.warning( From 534059dea517d44de387e7d467d64209f9dcfba2 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 01:18:15 -0500 Subject: [PATCH 46/76] Typos and lingering is_train --- library/config_util.py | 2 +- library/train_util.py | 4 ---- train_network.py | 6 +++--- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index a09d2c7ca..418c179dc 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -535,7 +535,7 @@ def print_info(_datasets): shuffle_caption: {subset.shuffle_caption} keep_tokens: {subset.keep_tokens} caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs} caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} caption_prefix: {subset.caption_prefix} caption_suffix: {subset.caption_suffix} diff --git a/library/train_util.py b/library/train_util.py index bf1b6731c..220d4702b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2092,7 +2092,6 @@ def __init__( bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset: bool, - is_train: bool, validation_seed: int, validation_split: float, ) -> None: @@ -2312,7 +2311,6 @@ class ControlNetDataset(BaseDataset): def __init__( self, subsets: Sequence[ControlNetSubset], - is_train: bool, batch_size: int, resolution, network_multiplier: float, @@ -2362,7 +2360,6 @@ def __init__( self.dreambooth_dataset_delegate = DreamBoothDataset( db_subsets, - is_train, batch_size, resolution, network_multiplier, @@ -2382,7 +2379,6 @@ def __init__( self.batch_size = batch_size self.num_train_images = self.dreambooth_dataset_delegate.num_train_images self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images - self.is_train = is_train self.validation_split = validation_split self.validation_seed = validation_seed diff --git a/train_network.py b/train_network.py index 99b9717a5..4bcfc0ac7 100644 --- a/train_network.py +++ b/train_network.py @@ -380,11 +380,11 @@ def pick_timesteps_list() -> torch.IntTensor: else: return typing.cast(torch.IntTensor, torch.tensor(timesteps_list).unsqueeze(1).repeat(1, batch_size).to(latents.device)) - 
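The loop renamed just below averages the loss over a fixed set of timesteps so that validation numbers are comparable between runs. A standalone sketch of that averaging, with a hypothetical `loss_at()` standing in for the real noise-prediction loss; the fixed list matches the one used by the validation calls in this series:

import torch

def fixed_timestep_validation_loss(loss_at, batch_size, timesteps_list=(10, 350, 500, 650, 990)):
    # Evaluating the same timesteps every run removes the variance that random
    # timestep sampling would otherwise add to the validation metric.
    total = torch.zeros((batch_size, 1))
    for t in timesteps_list:
        timesteps = torch.full((batch_size,), t, dtype=torch.long)
        total += loss_at(timesteps)  # per-sample loss at this noise level
    return total / len(timesteps_list)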
choosen_timesteps_list = pick_timesteps_list() + chosen_timesteps_list = pick_timesteps_list() total_loss = torch.zeros((batch_size, 1)).to(latents.device) # Use input timesteps_list or use described timesteps above - for fixed_timestep in choosen_timesteps_list: + for fixed_timestep in chosen_timesteps_list: fixed_timestep = typing.cast(torch.IntTensor, fixed_timestep) # Predict the noise residual @@ -447,7 +447,7 @@ def pick_timesteps_list() -> torch.IntTensor: total_loss += loss - return total_loss / len(choosen_timesteps_list) + return total_loss / len(chosen_timesteps_list) def train(self, args): session_id = random.randint(0, 2**32) From c8c3569df292109fe3be4d209c9f6131afe2ba5f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 01:26:45 -0500 Subject: [PATCH 47/76] Cleanup order, types, print to logger --- library/config_util.py | 7 +++---- library/train_util.py | 6 +++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 418c179dc..5a4d3aa2d 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -485,7 +485,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) datasets.append(dataset) - val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: if dataset_blueprint.params.validation_split <= 0.0: continue @@ -503,7 +503,6 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) - # print info def print_info(_datasets): info = "" for i, dataset in enumerate(_datasets): @@ -565,7 +564,7 @@ def print_info(_datasets): print_info(datasets) if len(val_datasets) > 0: - print("Validation dataset") + logger.info("Validation dataset") print_info(val_datasets) if len(val_datasets) > 0: @@ -610,7 +609,7 @@ def print_info(_datasets): " ", ) - logger.info(f"{info}") + logger.info(info) # make buckets first because it determines the length of dataset # and set the same seed for all datasets diff --git a/library/train_util.py b/library/train_util.py index 220d4702b..782f57e8f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1833,9 +1833,9 @@ def __init__( bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, + debug_dataset: bool, validation_split: float, validation_seed: Optional[int], - debug_dataset, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2319,9 +2319,9 @@ def __init__( max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, + debug_dataset: bool, validation_split: float, validation_seed: Optional[int], - debug_dataset: float, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset) @@ -2369,9 +2369,9 @@ def __init__( bucket_reso_steps, bucket_no_upscale, 1.0, + debug_dataset, validation_split, validation_seed, - debug_dataset ) # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) From fbfc2753eb7fa57724eb525ee65d851b5e80b8ea Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 01:53:12 -0500 Subject: [PATCH 48/76] Update text for train/reg with repeats --- library/train_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py 
b/library/train_util.py index 782f57e8f..77a6a9f9a 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2050,11 +2050,11 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
         subset.img_count = len(img_paths)
         self.subsets.append(subset)
 
-        logger.info(f"{num_train_images} images with repeating.")
+        logger.info(f"{num_train_images} train images with repeats.")
 
         self.num_train_images = num_train_images
 
-        logger.info(f"{num_reg_images} reg images.")
+        logger.info(f"{num_reg_images} reg images with repeats.")
 
         if num_train_images < num_reg_images:
             logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")

From 58bfa36d0275d864d5a2d64c51632e808f789ddd Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Fri, 3 Jan 2025 02:00:28 -0500
Subject: [PATCH 49/76] Add seed help clarifying info

---
 train_network.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train_network.py b/train_network.py
index 4bcfc0ac7..7d064d210 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1639,7 +1639,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "--validation_seed",
         type=int,
         default=None,
-        help="Validation seed / 検証シード"
+        help="Validation seed for shuffling the validation dataset; otherwise the training `--seed` is used / 検証シード"
     )
     parser.add_argument(
         "--validation_split",

From 6604b36044a83f3531faed508096f3e6bfe48fc9 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Fri, 3 Jan 2025 02:04:59 -0500
Subject: [PATCH 50/76] Remove duplicate assignment

---
 library/train_util.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 77a6a9f9a..3710c865d 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -86,8 +86,6 @@
 import library.deepspeed_utils as deepspeed_utils
 from library.utils import setup_logging, pil_resize
 
-
-
 setup_logging()
 
 import logging
@@ -1841,8 +1839,6 @@ def __init__(
 
         assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
 
-        self.validation_split = validation_split
-        self.validation_seed = validation_seed
         self.batch_size = batch_size
         self.size = min(self.width, self.height)  # 短いほう
         self.prior_loss_weight = prior_loss_weight

From 0522070d197d92745dbdb408d74c9c3f869bff76 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Fri, 3 Jan 2025 15:20:25 -0500
Subject: [PATCH 51/76] Fix training, validation split, revert to using
 upstream implementation

---
 library/config_util.py            |  67 +++-----------
 library/custom_train_functions.py |   6 +-
 library/strategy_sd.py            |   2 +-
 library/train_util.py             | 143 +++++++++++++++++-------------
 train_network.py                  |  94 ++++++++++++--------
 5 files changed, 152 insertions(+), 160 deletions(-)

diff --git a/library/config_util.py b/library/config_util.py
index 5a4d3aa2d..63d28c969 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -482,7 +482,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
             dataset_klass = FineTuningDataset
 
         subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
-        dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
+        dataset = dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params))
         datasets.append(dataset)
 
     val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
     for dataset_blueprint in dataset_group_blueprint.datasets:
         if dataset_blueprint.params.validation_split <= 0.0:
             continue
@@ -500,16 +500,16 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
             dataset_klass = FineTuningDataset
 
         subsets = 
[subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, is_training_dataset=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset) - def print_info(_datasets): + def print_info(_datasets, dataset_type: str): info = "" for i, dataset in enumerate(_datasets): is_dreambooth = isinstance(dataset, DreamBoothDataset) is_controlnet = isinstance(dataset, ControlNetDataset) info += dedent(f"""\ - [Dataset {i}] + [{dataset_type} {i}] batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} enable_bucket: {dataset.enable_bucket} @@ -527,7 +527,7 @@ def print_info(_datasets): for j, subset in enumerate(dataset.subsets): info += indent(dedent(f"""\ - [Subset {j} of Dataset {i}] + [Subset {j} of {dataset_type} {i}] image_dir: "{subset.image_dir}" image_count: {subset.img_count} num_repeats: {subset.num_repeats} @@ -544,8 +544,8 @@ def print_info(_datasets): random_crop: {subset.random_crop} token_warmup_min: {subset.token_warmup_min}, token_warmup_step: {subset.token_warmup_step}, - alpha_mask: {subset.alpha_mask} - custom_attributes: {subset.custom_attributes} + alpha_mask: {subset.alpha_mask} + custom_attributes: {subset.custom_attributes} """), " ") if is_dreambooth: @@ -561,67 +561,22 @@ def print_info(_datasets): logger.info(info) - print_info(datasets) + print_info(datasets, "Dataset") if len(val_datasets) > 0: - logger.info("Validation dataset") - print_info(val_datasets) - - if len(val_datasets) > 0: - info = "" - - for i, dataset in enumerate(val_datasets): - info += dedent( - f"""\ - [Validation Dataset {i}] - batch_size: {dataset.batch_size} - resolution: {(dataset.width, dataset.height)} - enable_bucket: {dataset.enable_bucket} - network_multiplier: {dataset.network_multiplier} - """ - ) - - if dataset.enable_bucket: - info += indent( - dedent( - f"""\ - min_bucket_reso: {dataset.min_bucket_reso} - max_bucket_reso: {dataset.max_bucket_reso} - bucket_reso_steps: {dataset.bucket_reso_steps} - bucket_no_upscale: {dataset.bucket_no_upscale} - \n""" - ), - " ", - ) - else: - info += "\n" - - for j, subset in enumerate(dataset.subsets): - info += indent( - dedent( - f"""\ - [Subset {j} of Validation Dataset {i}] - image_dir: "{subset.image_dir}" - image_count: {subset.img_count} - num_repeats: {subset.num_repeats} - """ - ), - " ", - ) - - logger.info(info) + print_info(val_datasets, "Validation Dataset") # make buckets first because it determines the length of dataset # and set the same seed for all datasets seed = random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): - logger.info(f"[Dataset {i}]") + logger.info(f"[Prepare dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) for i, dataset in enumerate(val_datasets): - logger.info(f"[Validation Dataset {i}]") + logger.info(f"[Prepare validation dataset {i}]") dataset.make_buckets() dataset.set_seed(seed) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 9a7c21a3e..ad3e69ffb 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -455,7 +455,7 @@ def get_weighted_text_embeddings( # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2 -def pyramid_noise_like(noise, device, iterations=6, discount=0.4): +def pyramid_noise_like(noise, device, iterations=6, 
discount=0.4) -> torch.FloatTensor:
     b, c, w, h = noise.shape  # EDIT: w and h get over-written, rename for a different variant!
     u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
     for i in range(iterations):
@@ -468,7 +468,7 @@ def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
 
 
 # https://www.crosslabs.org//blog/diffusion-with-offset-noise
-def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
+def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale) -> torch.FloatTensor:
     if noise_offset is None:
         return noise
     if adaptive_noise_scale is not None:
@@ -484,7 +484,7 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
     return noise
 
 
-def apply_masked_loss(loss, batch):
+def apply_masked_loss(loss, batch) -> torch.FloatTensor:
     if "conditioning_images" in batch:
         # conditioning image is -1 to 1. we need to convert it to 0 to 1
         mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)  # use R channel

diff --git a/library/strategy_sd.py b/library/strategy_sd.py
index d0a3a68bf..a44fc4092 100644
--- a/library/strategy_sd.py
+++ b/library/strategy_sd.py
@@ -40,7 +40,7 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
         text = [text] if isinstance(text, str) else text
         return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
 
-    def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
+    def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         text = [text] if isinstance(text, str) else text
         tokens_list = []
         weights_list = []

diff --git a/library/train_util.py b/library/train_util.py
index 3710c865d..0f16a4f31 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -146,7 +146,15 @@
 TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
 TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
 
-def split_train_val(paths: List[str], is_train: bool, validation_split: float, validation_seed: int) -> List[str]:
+def split_train_val(paths: List[str], is_training_dataset: bool, validation_split: float, validation_seed: int) -> List[str]:
+    """
+    Split the dataset into train and validation.
+
+    Shuffle the dataset based on the validation_seed or the current random seed.
+    For example, a validation_split of 0.2 on 100 images gives:
+ [0:79] = 80 training images + [80:] = 20 validation images + """ if validation_seed is not None: print(f"Using validation seed: {validation_seed}") prevstate = random.getstate() @@ -156,9 +164,12 @@ def split_train_val(paths: List[str], is_train: bool, validation_split: float, v else: random.shuffle(paths) - if is_train: + # Split the dataset between training and validation + if is_training_dataset: + # Training dataset we split to the first part return paths[0:math.ceil(len(paths) * (1 - validation_split))] else: + # Validation dataset we split to the second part return paths[len(paths) - round(len(paths) * validation_split):] @@ -1822,6 +1833,7 @@ class DreamBoothDataset(BaseDataset): def __init__( self, subsets: Sequence[DreamBoothSubset], + is_training_dataset: bool, batch_size: int, resolution, network_multiplier: float, @@ -1843,6 +1855,7 @@ def __init__( self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight self.latents_cache = None + self.is_training_dataset = is_training_dataset self.validation_seed = validation_seed self.validation_split = validation_split @@ -1952,6 +1965,9 @@ def load_dreambooth_dir(subset: DreamBoothSubset): size_set_count += 1 logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") + if self.validation_split > 0.0: + img_paths = split_train_val(img_paths, self.is_training_dataset, self.validation_split, self.validation_seed) + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") if use_cached_info_for_subset: @@ -2046,7 +2062,8 @@ def load_dreambooth_dir(subset: DreamBoothSubset): subset.img_count = len(img_paths) self.subsets.append(subset) - logger.info(f"{num_train_images} train images with repeats.") + images_split_name = "train" if self.is_training_dataset else "validation" + logger.info(f"{num_train_images} {images_split_name} images with repeats.") self.num_train_images = num_train_images @@ -2411,8 +2428,12 @@ def __init__( conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair]) - assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" - assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" + assert ( + len(missing_imgs) == 0 + ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}" + assert ( + len(extra_imgs) == 0 + ), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}" self.conditioning_image_transforms = IMAGE_TRANSFORMS @@ -4586,7 +4607,6 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar config_args = argparse.Namespace(**ignore_nesting_dict) args = parser.parse_args(namespace=config_args) args.config_file = os.path.splitext(args.config_file)[0] - logger.info(args.config_file) return args @@ -5880,55 +5900,35 @@ def save_sd_model_on_train_end_common( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def get_random_timesteps(args, min_timestep: int, max_timestep: int, batch_size: int, device: torch.device) -> torch.IntTensor: - """ - Get a random timestep between the min and max timesteps - Can error (NotImplementedError) if the loss type is not supported - """ - # TODO: if a huber loss is selected, it will use constant timesteps for each batch - # as. 
In the future there may be a smarter way - if args.loss_type == "huber" or args.loss_type == "smooth_l1": - timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu") - timesteps = timesteps.repeat(batch_size).to(device) - elif args.loss_type == "l2": - timesteps = torch.randint(min_timestep, max_timestep, (batch_size,), device=device) - else: - raise NotImplementedError(f"Unknown loss type {args.loss_type}") - - return typing.cast(torch.IntTensor, timesteps) - +def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.IntTensor: + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) + return timesteps -def get_huber_c(args, noise_scheduler: DDPMScheduler, timesteps: torch.IntTensor) -> Optional[float]: - """ - Calculate the Huber convolution (huber_c) value - Huber loss is a loss function used in robust regression, that is less sensitive - to outliers in data than the squared error loss. - https://en.wikipedia.org/wiki/Huber_loss - """ - if args.loss_type == "huber" or args.loss_type == "smooth_l1": - if args.huber_schedule == "exponential": - alpha = -math.log(args.huber_c) / noise_scheduler.config.get('num_train_timesteps', 1000) - huber_c = math.exp(-alpha * timesteps.item()) - elif args.huber_schedule == "snr": - if not hasattr(noise_scheduler, "alphas_cumprod"): - raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") - alphas_cumprod = noise_scheduler.alphas_cumprod.index_select(0, timesteps) - sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 - huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c - elif args.huber_schedule == "constant": - huber_c = args.huber_c - else: - raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - elif args.loss_type == "l2": +def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]: + if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"): return None + + b_size = timesteps.shape[0] + if args.huber_schedule == "exponential": + alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps + result = torch.exp(-alpha * timesteps) * args.huber_scale + elif args.huber_schedule == "snr": + if not hasattr(noise_scheduler, "alphas_cumprod"): + raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") + alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) + sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 + result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c + result = result.to(timesteps.device) + elif args.huber_schedule == "constant": + result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device) else: - raise NotImplementedError(f"Unknown loss type {args.loss_type}") + raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - return huber_c + return result -def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor): +def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor) -> torch.FloatTensor: """ Apply noise modifications like noise offset and multires noise """ @@ -5964,27 +5964,44 @@ def make_random_timesteps(args, noise_scheduler: DDPMScheduler, batch_size: int, max_timestep = noise_scheduler.config.get('num_train_timesteps', 1000) if args.max_timestep is None else args.max_timestep # Sample a random timestep for each image - 
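The unified `get_timesteps` added above gives every loss type one independent uniform draw per sample; the per-loss-type branching below is what gets deleted. A sketch mirroring the helper added in this patch:

import torch

def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device="cpu") -> torch.Tensor:
    # One uniform integer timestep per sample; Huber scheduling now happens at
    # loss time, so the sampler no longer needs a constant-per-batch special case.
    return torch.randint(min_timestep, max_timestep, (b_size,), device=device)

timesteps = get_timesteps(0, 1000, 4)  # four independent draws from [0, 1000)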
timesteps = get_random_timesteps(args, min_timestep, max_timestep, batch_size, device) + timesteps = get_timesteps(min_timestep, max_timestep, batch_size, device) return timesteps -def get_noise_noisy_latents_and_timesteps(args, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor, Optional[float]]: - """ - Unified noise, noisy_latents, timesteps and huber loss convolution calculations - """ - batch_size = latents.shape[0] +def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]: + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + if args.noise_offset_random_strength: + noise_offset = torch.rand(1, device=latents.device) * args.noise_offset + else: + noise_offset = args.noise_offset + noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale) + if args.multires_noise_iterations: + noise = custom_train_functions.pyramid_noise_like( + noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount + ) + + # Sample a random timestep for each image + b_size = latents.shape[0] min_timestep = 0 if args.min_timestep is None else args.min_timestep - max_timestep = noise_scheduler.config.get("num_train_timesteps", 1000) if args.max_timestep is None else args.max_timestep + max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep - # A random timestep for each image in the batch - timesteps = get_random_timesteps(args, min_timestep, max_timestep, batch_size, latents.device) - huber_c = get_huber_c(args, noise_scheduler, timesteps) + timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device) - noise = make_noise(args, latents) - noisy_latents = get_noisy_latents(args, noise, noise_scheduler, latents, timesteps) + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + if args.ip_noise_gamma_random_strength: + strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma + else: + strength = args.ip_noise_gamma + noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps) + else: + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - return noise, noisy_latents, timesteps, huber_c + return noise, noisy_latents, timesteps def get_noisy_latents(args, noise: torch.FloatTensor, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor: @@ -6015,6 +6032,8 @@ def conditional_loss( elif loss_type == "l1": loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction) elif loss_type == "huber": + if huber_c is None: + raise NotImplementedError("huber_c not implemented correctly") huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": @@ -6022,6 +6041,8 @@ def conditional_loss( elif reduction == "sum": loss = torch.sum(loss) elif loss_type == "smooth_l1": + if huber_c is None: + raise NotImplementedError("huber_c not implemented correctly") huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": diff --git a/train_network.py 
b/train_network.py index 7d064d210..f870734fd 100644 --- a/train_network.py +++ b/train_network.py @@ -205,10 +205,10 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) return noise_scheduler - def encode_images_to_latents(self, args, accelerator, vae, images): + def encode_images_to_latents(self, args, vae: AutoencoderKL, images: torch.FloatTensor) -> torch.FloatTensor: return vae.encode(images).latent_dist.sample() - def shift_scale_latents(self, args, latents): + def shift_scale_latents(self, args, latents: torch.FloatTensor) -> torch.FloatTensor: return latents * self.vae_scale_factor def get_noise_pred_and_target( @@ -280,7 +280,7 @@ def get_noise_pred_and_target( return noise_pred, target, timesteps, None - def post_process_loss(self, loss, args, timesteps, noise_scheduler): + def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor: if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: @@ -317,20 +317,21 @@ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, # endregion - def process_batch(self, batch, tokenizers, text_encoders, unet, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor: + def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor: with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: - latents: torch.Tensor = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) + latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) else: # latentに変換 - latents: torch.Tensor = typing.cast(torch.FloatTensor, typing.cast(AutoencoderKLOutput, vae.encode(batch["images"].to(accelerator.device, dtype=vae_dtype))).latent_dist.sample()) + latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype)) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") - latents = typing.cast(torch.FloatTensor, torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)) - latents = typing.cast(torch.FloatTensor, latents * self.vae_scale_factor) + latents = typing.cast(torch.FloatTensor, torch.nan_to_num(latents, 0, out=latents)) + + latents = self.shift_scale_latents(args, latents) text_encoder_conds = [] @@ -384,22 +385,36 @@ def pick_timesteps_list() -> torch.IntTensor: total_loss = torch.zeros((batch_size, 1)).to(latents.device) # Use input timesteps_list or use described timesteps above - for fixed_timestep in chosen_timesteps_list: - fixed_timestep = typing.cast(torch.IntTensor, fixed_timestep) + for fixed_timesteps in chosen_timesteps_list: + fixed_timesteps = typing.cast(torch.IntTensor, fixed_timesteps) # Predict the noise 
residual # and add noise to the latents # with noise offset and/or multires noise if specified - noisy_latents = train_util.get_noisy_latents(args, noise, noise_scheduler, latents, fixed_timestep) + noisy_latents = train_util.get_noisy_latents(args, noise, noise_scheduler, latents, fixed_timesteps) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + for x in noisy_latents: + x.requires_grad_(True) + for t in text_encoder_conds: + t.requires_grad_(True) with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents.requires_grad_(train_unet), fixed_timestep, text_encoder_conds, batch, weight_dtype + args, + accelerator, + unet, + noisy_latents.requires_grad_(train_unet), + fixed_timesteps, + text_encoder_conds, + batch, + weight_dtype, ) if args.v_parameterization: # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, fixed_timestep) + target = noise_scheduler.get_velocity(latents, noise, fixed_timesteps) else: target = noise @@ -418,7 +433,7 @@ def pick_timesteps_list() -> torch.IntTensor: accelerator, unet, noisy_latents, - timesteps, + fixed_timesteps, text_encoder_conds, batch, weight_dtype, @@ -427,7 +442,8 @@ def pick_timesteps_list() -> torch.IntTensor: network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + huber_c = train_util.get_huber_threshold_if_needed(args, fixed_timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = loss.mean([1, 2, 3]) # 平均なのでbatch_sizeで割る必要なし if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): @@ -436,14 +452,7 @@ def pick_timesteps_list() -> torch.IntTensor: loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight loss = loss * loss_weights - if args.min_snr_gamma: - loss = apply_snr_weight(loss, fixed_timestep, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, fixed_timestep, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, fixed_timestep, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, fixed_timestep, noise_scheduler) + loss = self.post_process_loss(loss, args, fixed_timesteps, noise_scheduler) total_loss += loss @@ -526,8 +535,12 @@ def train(self, args): collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) if args.debug_dataset: - train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly + train_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly train_util.debug_dataset(train_dataset_group) + + if val_dataset_group is not None: + val_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly + train_util.debug_dataset(val_dataset_group) return if len(train_dataset_group) == 0: logger.error( @@ -753,10 +766,6 @@ def train(self, args): # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) - # Not for sure here. 
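The per-sample loss path earlier in this hunk is: elementwise loss, mean over channel and spatial dims down to one scalar per sample, multiply by `loss_weights`, then a final batch mean. A toy sketch of that reduction order, using plain MSE for the elementwise term:

import torch

noise_pred = torch.randn(4, 4, 64, 64)  # (batch, channels, height, width)
target = torch.randn(4, 4, 64, 64)
loss_weights = torch.tensor([1.0, 1.0, 0.5, 2.0])  # e.g. per-sample prior-loss weighting

loss = torch.nn.functional.mse_loss(noise_pred, target, reduction="none")
loss = loss.mean([1, 2, 3])          # shape (4,): one scalar per sample
loss = (loss * loss_weights).mean()  # weight samples first, then average the batch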
- # if val_dataset_group is not None: - # val_dataset_group.set_max_train_steps(args.max_train_steps) - # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -1304,7 +1313,7 @@ def remove_model(old_ckpt_name): clean_memory_on_device(accelerator.device) for epoch in range(epoch_to_start, num_train_epochs): - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") current_epoch.value = epoch + 1 metadata["ss_epoch"] = str(epoch + 1) @@ -1324,7 +1333,7 @@ def remove_model(old_ckpt_name): continue with accelerator.accumulate(training_model): - loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet) accelerator.backward(loss) if accelerator.sync_gradients: self.all_reduce_network(accelerator, network) # sync DDP grad manually @@ -1384,7 +1393,8 @@ def remove_model(old_ckpt_name): logs = self.generate_step_logs( args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm ) - accelerator.log(logs, step=global_step) + # accelerator.log(logs, step=global_step) + accelerator.log(logs) # VALIDATION PER STEP should_validate = (args.validation_every_n_step is not None @@ -1401,7 +1411,7 @@ def remove_model(old_ckpt_name): if val_step >= validation_steps: break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item()) val_progress_bar.update(1) @@ -1409,10 +1419,12 @@ def remove_model(old_ckpt_name): if is_tracking: logs = {"loss/current_val_loss": loss.detach().item()} - accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + accelerator.log(logs) logs = {"loss/average_val_loss": val_loss_recorder.moving_average} - accelerator.log(logs, step=global_step) + # accelerator.log(logs, step=global_step) + accelerator.log(logs) if global_step >= args.max_train_steps: break @@ -1427,7 +1439,7 @@ def remove_model(old_ckpt_name): ) for val_step, batch in enumerate(val_dataloader): - loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) current_loss = loss.detach().item() 
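`val_loss_recorder` here is the same `train_util.LossRecorder` used for the training loss. A simplified stand-in that shows the bookkeeping the validation logs rely on; this is a sketch, not the library class:

class SimpleLossRecorder:
    def __init__(self):
        self.losses = []

    def add(self, *, epoch: int, step: int, loss: float):
        self.losses.append(loss)

    @property
    def moving_average(self) -> float:
        # The real recorder keeps a per-step window; a plain mean is enough to
        # illustrate metrics like "loss/validation_average".
        return sum(self.losses) / max(len(self.losses), 1)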
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) @@ -1437,22 +1449,26 @@ def remove_model(old_ckpt_name): if is_tracking: avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation_current": current_loss} - accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + accelerator.log(logs) if is_tracking: avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation_average": avr_loss} - accelerator.log(logs, step=epoch + 1) + # accelerator.log(logs, step=epoch + 1) + accelerator.log(logs) # END OF EPOCH if is_tracking: logs = {"loss/epoch_average": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) + # accelerator.log(logs, step=epoch + 1) + accelerator.log(logs) if len(val_dataloader) > 0 and is_tracking: avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation_epoch_average": avr_loss} - accelerator.log(logs, step=epoch + 1) + # accelerator.log(logs, step=epoch + 1) + accelerator.log(logs) accelerator.wait_for_everyone() From 695f38962ce279adfee3fabb3479b84b1076b4e8 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 15:25:12 -0500 Subject: [PATCH 52/76] Move get_huber_threshold_if_needed --- library/train_util.py | 44 ++++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 0f16a4f31..0907a8c03 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5905,27 +5905,6 @@ def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: tor return timesteps -def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]: - if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"): - return None - - b_size = timesteps.shape[0] - if args.huber_schedule == "exponential": - alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps - result = torch.exp(-alpha * timesteps) * args.huber_scale - elif args.huber_schedule == "snr": - if not hasattr(noise_scheduler, "alphas_cumprod"): - raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") - alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) - sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 - result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c - result = result.to(timesteps.device) - elif args.huber_schedule == "constant": - result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device) - else: - raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - - return result def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor) -> torch.FloatTensor: @@ -6004,6 +5983,29 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch. 
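The helper relocated in this hunk turns the scalar `huber_c` into a per-timestep threshold. A sketch of just its exponential schedule, assuming num_train_timesteps=1000; the default values here are illustrative, not the script's defaults:

import math
import torch

def exponential_huber_threshold(timesteps, huber_c=0.1, huber_scale=1.0, num_train_timesteps=1000):
    # Decays from huber_scale at t=0 to huber_c * huber_scale at t=T; a larger
    # threshold makes the Huber loss behave more like L2, a smaller one more like L1.
    alpha = -math.log(huber_c) / num_train_timesteps
    return torch.exp(-alpha * timesteps) * huber_scale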
return noise, noisy_latents, timesteps +def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]: + if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"): + return None + + b_size = timesteps.shape[0] + if args.huber_schedule == "exponential": + alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps + result = torch.exp(-alpha * timesteps) * args.huber_scale + elif args.huber_schedule == "snr": + if not hasattr(noise_scheduler, "alphas_cumprod"): + raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") + alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) + sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 + result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c + result = result.to(timesteps.device) + elif args.huber_schedule == "constant": + result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device) + else: + raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") + + return result + + def get_noisy_latents(args, noise: torch.FloatTensor, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor: """ Add noise to the latents according to the noise magnitude at each timestep From 1f9ba40b8b70fd08e6b87a70727d5e789666a925 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 15:32:07 -0500 Subject: [PATCH 53/76] Add step break for validation epoch. Remove unused variable --- train_network.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index f870734fd..ce34f26d3 100644 --- a/train_network.py +++ b/train_network.py @@ -1439,6 +1439,9 @@ def remove_model(old_ckpt_name): ) for val_step, batch in enumerate(val_dataloader): + if val_step >= validation_steps: + break + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) current_loss = loss.detach().item() @@ -1447,7 +1450,6 @@ def remove_model(old_ckpt_name): val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) if is_tracking: - avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation_current": current_loss} # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) accelerator.log(logs) From 1c0ae306e551ede5bd162819debb4d80a7fe620b Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 15:43:02 -0500 Subject: [PATCH 54/76] Add missing functions for training batch --- train_network.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index ce34f26d3..377ddf48e 100644 --- a/train_network.py +++ b/train_network.py @@ -318,7 +318,7 @@ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, # endregion def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor: - + with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = 
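The relocated get_huber_threshold_if_needed above only computes a per-sample threshold; the loss itself is applied later. As a rough sketch of how such a threshold enters a pseudo-Huber loss (this mirrors the smooth-L1 style formulation used by conditional_loss, but the standalone helper below is illustrative, not the repository's API):

import torch

def pseudo_huber_per_sample(pred, target, huber_c):
    # huber_c has shape (batch,) and is broadcast over channel/height/width.
    c = huber_c.view(-1, 1, 1, 1)
    # Quadratic near zero, linear in the tails; larger c behaves closer to MSE.
    loss = 2 * c * (torch.sqrt((pred - target) ** 2 + c ** 2) - c)
    return loss.mean(dim=(1, 2, 3))

pred, target = torch.randn(4, 4, 64, 64), torch.randn(4, 4, 64, 64)
print(pseudo_huber_per_sample(pred, target, torch.full((4,), 0.1)).shape)  # torch.Size([4])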
typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) @@ -1333,6 +1333,11 @@ def remove_model(old_ckpt_name): continue with accelerator.accumulate(training_model): + on_step_start_for_network(text_encoder, unet) + + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet) accelerator.backward(loss) if accelerator.sync_gradients: From bbf6bbd5ea27231066cec98b8bf2a65f162cb18f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 10:48:38 -0500 Subject: [PATCH 55/76] Use self.get_noise_pred_and_target and drop fixed timesteps --- flux_train_network.py | 7 ++- sd3_train_network.py | 3 +- train_network.py | 116 ++++++++++++------------------------------ 3 files changed, 40 insertions(+), 86 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 75e975bae..b3aebecc7 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -339,6 +339,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, + is_train=True ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -375,7 +376,7 @@ def get_noise_pred_and_target( def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): # if not args.split_mode: # normal forward - with accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = unet( img=img, @@ -420,7 +421,9 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t intermediate_txt.requires_grad_(True) vec.requires_grad_(True) pe.requires_grad_(True) - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + + with torch.set_grad_enabled(is_train and train_unet): + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) """ return model_pred diff --git a/sd3_train_network.py b/sd3_train_network.py index fb7711bda..c7417802d 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -312,6 +312,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, + is_train=True ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -339,7 +340,7 @@ def get_noise_pred_and_target( t5_attn_mask = None # call model - with accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): # TODO support attention mask model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled) diff --git a/train_network.py b/train_network.py index 377ddf48e..61e6369ae 100644 --- a/train_network.py +++ b/train_network.py @@ -223,6 +223,7 @@ def get_noise_pred_and_target( network, weight_dtype, train_unet, + is_train=True ): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -236,7 +237,7 @@ def get_noise_pred_and_target( t.requires_grad_(True) # Predict the noise residual - with accelerator.autocast(): + with 
torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): noise_pred = self.call_unet( args, accelerator, @@ -317,7 +318,7 @@ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, # endregion - def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor: + def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True) -> torch.Tensor: with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: @@ -372,91 +373,40 @@ def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: Au batch_size = latents.shape[0] - # Sample noise, - noise = train_util.make_noise(args, latents) - def pick_timesteps_list() -> torch.IntTensor: - if timesteps_list is None or timesteps_list == []: - return typing.cast(torch.IntTensor, train_util.make_random_timesteps(args, noise_scheduler, batch_size, latents.device).unsqueeze(1)) - else: - return typing.cast(torch.IntTensor, torch.tensor(timesteps_list).unsqueeze(1).repeat(1, batch_size).to(latents.device)) - - chosen_timesteps_list = pick_timesteps_list() - total_loss = torch.zeros((batch_size, 1)).to(latents.device) - - # Use input timesteps_list or use described timesteps above - for fixed_timesteps in chosen_timesteps_list: - fixed_timesteps = typing.cast(torch.IntTensor, fixed_timesteps) - - # Predict the noise residual - # and add noise to the latents - # with noise offset and/or multires noise if specified - noisy_latents = train_util.get_noisy_latents(args, noise, noise_scheduler, latents, fixed_timesteps) - - # ensure the hidden state will require grad - if args.gradient_checkpointing: - for x in noisy_latents: - x.requires_grad_(True) - for t in text_encoder_conds: - t.requires_grad_(True) - - with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): - noise_pred = self.call_unet( - args, - accelerator, - unet, - noisy_latents.requires_grad_(train_unet), - fixed_timesteps, - text_encoder_conds, - batch, - weight_dtype, - ) - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, fixed_timesteps) - else: - target = noise - - # differential output preservation - if "custom_attributes" in batch: - diff_output_pr_indices = [] - for i, custom_attributes in enumerate(batch["custom_attributes"]): - if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: - diff_output_pr_indices.append(i) - - if len(diff_output_pr_indices) > 0: - network.set_multiplier(0.0) - with torch.no_grad(), accelerator.autocast(): - noise_pred_prior = self.call_unet( - args, - accelerator, - unet, - noisy_latents, - fixed_timesteps, - text_encoder_conds, - batch, - weight_dtype, - indices=diff_output_pr_indices, - ) - network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step - target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) - - huber_c = 
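Threading is_train into the forward call is what lets a single code path serve both loops: torch.set_grad_enabled(True) changes nothing during training, while validation skips building the autograd graph entirely. A self-contained sketch of the pattern:

import torch

model = torch.nn.Linear(8, 8)

def forward_pass(x, is_train: bool, train_unet: bool = True):
    # Gradients (and activation storage) only when actually training.
    with torch.set_grad_enabled(is_train and train_unet):
        return model(x)

print(forward_pass(torch.randn(2, 8), is_train=True).requires_grad)   # True
print(forward_pass(torch.randn(2, 8), is_train=False).requires_grad)  # False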
train_util.get_huber_threshold_if_needed(args, fixed_timesteps, noise_scheduler) - loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) - loss = loss.mean([1, 2, 3]) # 平均なのでbatch_sizeで割る必要なし + # Predict the noise residual + # and add noise to the latents + # with noise offset and/or multires noise if specified - if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): - loss = apply_masked_loss(loss, batch) + # sample noise, call unet, get target + noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + is_train=is_train + ) - loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight - loss = loss * loss_weights + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) - loss = self.post_process_loss(loss, args, fixed_timesteps, noise_scheduler) + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights - total_loss += loss + loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) - return total_loss / len(chosen_timesteps_list) + return loss.mean() def train(self, args): session_id = random.randint(0, 2**32) @@ -1416,7 +1366,7 @@ def remove_model(old_ckpt_name): if val_step >= validation_steps: break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False) val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item()) val_progress_bar.update(1) @@ -1447,7 +1397,7 @@ def remove_model(old_ckpt_name): if val_step >= validation_steps: break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False) current_loss = loss.detach().item() val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) From f4840ef29ef67878d7c7ccec92bdce89c3b61c6d Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 10:52:07 -0500 Subject: [PATCH 56/76] Revert train_db.py --- train_db.py | 121 ++-------------------------------------------------- 1 file changed, 3 insertions(+), 118 deletions(-) diff --git a/train_db.py b/train_db.py index 398489ffe..ad21f8d1b 100644 --- a/train_db.py +++ b/train_db.py @@ -2,6 +2,7 @@ # XXX dropped option: fine_tune import argparse +import itertools import math import os from multiprocessing import Value @@ -41,73 +42,11 @@ setup_logging() import logging -import 
itertools logger = logging.getLogger(__name__) # perlin_noise, -def process_val_batch(*training_models, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args): - total_loss = 0.0 - timesteps_list = [10, 350, 500, 650, 990] - - with accelerator.accumulate(*training_models): - with torch.no_grad(): - # latentに変換 - if cache_latents: - latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) - else: - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - with torch.set_grad_enabled(False), accelerator.autocast(): - if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - - for fixed_timesteps in timesteps_list: - with torch.set_grad_enabled(False), accelerator.autocast(): - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise = torch.randn_like(latents, device=latents.device) - b_size = latents.shape[0] - timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - if args.masked_loss: - loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - total_loss += loss - - average_loss = total_loss / len(timesteps_list) - return average_loss def train(args): train_util.verify_training_args(args) @@ -150,10 +89,9 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) - val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -274,15 +212,6 @@ def train(args): num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) - val_dataloader = torch.utils.data.DataLoader( - val_dataset_group if val_dataset_group is not None else [], - shuffle=False, - batch_size=1, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - cyclic_val_dataloader = itertools.cycle(val_dataloader) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -393,8 +322,6 @@ def train(args): accelerator.log({}, step=0) 
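The train_db.py code being reverted above drew validation batches through itertools.cycle so that step-based validation never exhausts the loader mid-epoch. A minimal sketch of that pattern:

import itertools
import torch

val_dataloader = torch.utils.data.DataLoader(list(range(5)), batch_size=1)
cyclic_val_dataloader = itertools.cycle(val_dataloader)

# Eight draws from a five-item loader: iteration wraps around
# instead of raising StopIteration.
batches = [next(cyclic_val_dataloader) for _ in range(8)]
print(len(batches))  # 8

Note that cycle caches the first pass, so a shuffling loader would keep repeating its first ordering; that is one reason to keep the validation loader unshuffled.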
loss_recorder = train_util.LossRecorder() - val_loss_recorder = train_util.LossRecorder() - for epoch in range(num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -525,25 +452,6 @@ def train(args): avr_loss: float = loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break @@ -634,30 +542,7 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) - parser.add_argument( - "--validation_seed", - type=int, - default=None, - help="Validation seed" - ) - parser.add_argument( - "--validation_split", - type=float, - default=0.0, - help="Split for validation images out of the training dataset" - ) - parser.add_argument( - "--validation_every_n_step", - type=int, - default=None, - help="Number of train steps for counting validation loss. By default, validation per train epoch is performed" - ) - parser.add_argument( - "--max_validation_steps", - type=int, - default=None, - help="Number of max validation steps for counting validation loss. 
By default, validation will run entire validation dataset" - ) + return parser From 1c63e7cc4979b528417b5bfe181e0a9ac119209c Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 11:07:47 -0500 Subject: [PATCH 57/76] Cleanup unused code and formatting --- train_network.py | 85 +++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 70 insertions(+), 15 deletions(-) diff --git a/train_network.py b/train_network.py index 61e6369ae..5a80d825d 100644 --- a/train_network.py +++ b/train_network.py @@ -318,8 +318,27 @@ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, # endregion - def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True) -> torch.Tensor: - + def process_batch( + self, + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, + tokenize_strategy: strategy_sd.SdTokenizeStrategy, + is_train=True, + train_text_encoder=True, + train_unet=True + ) -> torch.Tensor: + """ + Process a batch for the network + """ with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) @@ -334,7 +353,6 @@ def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: Au latents = self.shift_scale_latents(args, latents) - text_encoder_conds = [] text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: @@ -371,13 +389,6 @@ def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: Au if encoded_text_encoder_conds[i] is not None: text_encoder_conds[i] = encoded_text_encoder_conds[i] - batch_size = latents.shape[0] - - - # Predict the noise residual - # and add noise to the latents - # with noise offset and/or multires noise if specified - # sample noise, call unet, get target noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( args, @@ -1288,7 +1299,23 @@ def remove_model(old_ckpt_name): # temporary, for batch processing self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet) + loss = self.process_batch(batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=True, + train_text_encoder=train_text_encoder, + train_unet=train_unet + ) + accelerator.backward(loss) if accelerator.sync_gradients: self.all_reduce_network(accelerator, network) # sync DDP grad manually @@ -1366,12 +1393,26 @@ def remove_model(old_ckpt_name): if val_step >= validation_steps: break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False) - + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, 
+ vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False + ) + val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item()) val_progress_bar.update(1) val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) - + if is_tracking: logs = {"loss/current_val_loss": loss.detach().item()} # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) @@ -1397,7 +1438,21 @@ def remove_model(old_ckpt_name): if val_step >= validation_steps: break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False) + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False + ) current_loss = loss.detach().item() val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) From c64d1a22fc4ff25625873e50d63d480b297301c6 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 11:30:21 -0500 Subject: [PATCH 58/76] Add validate_every_n_epochs, change name validate_every_n_steps --- train_network.py | 69 ++++++++++++++++++++++++++++++------------------ 1 file changed, 44 insertions(+), 25 deletions(-) diff --git a/train_network.py b/train_network.py index 5a80d825d..f3c8d8c96 100644 --- a/train_network.py +++ b/train_network.py @@ -1199,7 +1199,8 @@ def load_model_hook(models, input_dir): ) loss_recorder = train_util.LossRecorder() - val_loss_recorder = train_util.LossRecorder() + val_step_loss_recorder = train_util.LossRecorder() + val_epoch_loss_recorder = train_util.LossRecorder() del train_dataset_group if val_dataset_group is not None: @@ -1299,7 +1300,8 @@ def remove_model(old_ckpt_name): # temporary, for batch processing self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - loss = self.process_batch(batch, + loss = self.process_batch( + batch, text_encoders, unet, network, @@ -1373,15 +1375,25 @@ def remove_model(old_ckpt_name): if is_tracking: logs = self.generate_step_logs( - args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm + args, + current_loss, + avr_loss, + lr_scheduler, + lr_descriptions, + optimizer, + keys_scaled, + mean_norm, + maximum_norm ) # accelerator.log(logs, step=global_step) accelerator.log(logs) # VALIDATION PER STEP - should_validate = (args.validation_every_n_step is not None - and global_step % args.validation_every_n_step == 0) - if validation_steps > 0 and should_validate: + should_validate_epoch = ( + args.validate_every_n_steps is not None + and global_step % args.validate_every_n_steps == 0 + ) + if validation_steps > 0 and should_validate_epoch: accelerator.print("Validating バリデーション処理...") val_progress_bar = tqdm( @@ -1409,16 +1421,17 @@ def remove_model(old_ckpt_name): is_train=False ) - val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item()) + current_loss = loss.detach().item() + val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) val_progress_bar.update(1) - val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) + val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average }) if is_tracking: - logs = {"loss/current_val_loss": loss.detach().item()} + logs = 
{"loss/step_validation_current": current_loss} # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) accelerator.log(logs) - logs = {"loss/average_val_loss": val_loss_recorder.moving_average} + logs = {"loss/step_validation_average": val_step_loss_recorder.moving_average} # accelerator.log(logs, step=global_step) accelerator.log(logs) @@ -1426,12 +1439,18 @@ def remove_model(old_ckpt_name): break # VALIDATION EPOCH - if len(val_dataloader) > 0: + should_validate_epoch = ( + (epoch + 1) % args.validate_every_n_epochs == 0 + if args.validate_every_n_epochs is not None + else False + ) + + if should_validate_epoch and len(val_dataloader) > 0: accelerator.print("Validating バリデーション処理...") val_progress_bar = tqdm( range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, - desc="validation steps" + desc="epoch validation steps" ) for val_step, batch in enumerate(val_dataloader): @@ -1455,18 +1474,18 @@ def remove_model(old_ckpt_name): ) current_loss = loss.detach().item() - val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) + val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) val_progress_bar.update(1) - val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) + val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average }) if is_tracking: - logs = {"loss/validation_current": current_loss} + logs = {"loss/epoch_validation_current": current_loss} # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) accelerator.log(logs) if is_tracking: - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/validation_average": avr_loss} + avr_loss: float = val_epoch_loss_recorder.moving_average + logs = {"loss/epoch_validation_average": avr_loss} # accelerator.log(logs, step=epoch + 1) accelerator.log(logs) @@ -1475,12 +1494,6 @@ def remove_model(old_ckpt_name): logs = {"loss/epoch_average": loss_recorder.moving_average} # accelerator.log(logs, step=epoch + 1) accelerator.log(logs) - - if len(val_dataloader) > 0 and is_tracking: - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/validation_epoch_average": avr_loss} - # accelerator.log(logs, step=epoch + 1) - accelerator.log(logs) accelerator.wait_for_everyone() @@ -1676,10 +1689,16 @@ def setup_parser() -> argparse.ArgumentParser: help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合" ) parser.add_argument( - "--validation_every_n_step", + "--validate_every_n_steps", + type=int, + default=None, + help="Run validation dataset every N steps" + ) + parser.add_argument( + "--validate_every_n_epochs", type=int, default=None, - help="Number of train steps for counting validation loss. By default, validation per train epoch is performed / 学習エポックごとに検証を行う場合はNoneを指定する" + help="Run validation dataset every N epochs. 
By default, validation will run every epoch if a validation dataset is available" ) parser.add_argument( "--max_validation_steps", From f8850296c83ef2091bf1cb0f6e9ba462adfd9045 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 11:34:10 -0500 Subject: [PATCH 59/76] Fix validate epoch, cleanup imports --- train_network.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/train_network.py b/train_network.py index f3c8d8c96..11bba71e8 100644 --- a/train_network.py +++ b/train_network.py @@ -3,15 +3,13 @@ import math import os import typing -from typing import List, Optional, Union +from typing import Any, List import sys import random import time import json from multiprocessing import Value -from typing import Any, List import toml -import itertools from tqdm import tqdm @@ -23,8 +21,8 @@ from accelerate.utils import set_seed from accelerate import Accelerator -from diffusers import DDPMScheduler, AutoencoderKL -from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers import DDPMScheduler +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from library import deepspeed_utils, model_util, strategy_base, strategy_sd import library.train_util as train_util @@ -49,7 +47,6 @@ setup_logging() import logging -import itertools logger = logging.getLogger(__name__) @@ -1442,7 +1439,7 @@ def remove_model(old_ckpt_name): should_validate_epoch = ( (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None - else False + else True ) if should_validate_epoch and len(val_dataloader) > 0: From fcb2ff010cf2e42c50b3745a17317f2d4b4319d9 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 11:39:32 -0500 Subject: [PATCH 60/76] Clean up some validation help documentation --- train_network.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index 11bba71e8..af180c455 100644 --- a/train_network.py +++ b/train_network.py @@ -1677,7 +1677,7 @@ def setup_parser() -> argparse.ArgumentParser: "--validation_seed", type=int, default=None, - help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証シード" + help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する" ) parser.add_argument( "--validation_split", @@ -1689,19 +1689,19 @@ def setup_parser() -> argparse.ArgumentParser: "--validate_every_n_steps", type=int, default=None, - help="Run validation dataset every N steps" + help="Run validation on validation dataset every N steps if a validation dataset is available / 検証データセットが利用可能な場合は、Nステップごとに検証データセットの検証を実行します" ) parser.add_argument( "--validate_every_n_epochs", type=int, default=None, - help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available" + help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます" ) parser.add_argument( "--max_validation_steps", type=int, default=None, - help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset / 検証データセット全体を検証する場合はNoneを指定する" + help="Max number of validation dataset items processed. 
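Taken together, the two cadence flags behave as follows: step-based validation fires only when explicitly requested, while epoch-based validation defaults to every epoch when a validation dataset exists. A condensed sketch of the gating (a hypothetical helper, not code from the patch):

def should_validate(global_step, epoch, at_epoch_end,
                    validate_every_n_steps=None, validate_every_n_epochs=None):
    # Mirrors the two checks in the training loop above.
    if not at_epoch_end:
        return (validate_every_n_steps is not None
                and global_step % validate_every_n_steps == 0)
    if validate_every_n_epochs is not None:
        return (epoch + 1) % validate_every_n_epochs == 0
    return True  # default: validate at every epoch end

assert should_validate(200, 0, False, validate_every_n_steps=100)
assert not should_validate(150, 0, False, validate_every_n_steps=100)
assert should_validate(999, 4, True)  # epoch end, no flags set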
By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します" ) return parser From 742bee9738e9d190a39f5a36adf4515fa415e9b7 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 17:34:23 -0500 Subject: [PATCH 61/76] Set validation steps in multiple lines for readability --- train_network.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index af180c455..d0596fcae 100644 --- a/train_network.py +++ b/train_network.py @@ -1251,7 +1251,11 @@ def remove_model(old_ckpt_name): # log empty object to commit the sample images to wandb accelerator.log({}, step=0) - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) + validation_steps = ( + min(args.max_validation_steps, len(val_dataloader)) + if args.max_validation_steps is not None + else len(val_dataloader) + ) # training loop if initial_step > 0: # only if skip_until_initial_step is specified @@ -1689,7 +1693,7 @@ def setup_parser() -> argparse.ArgumentParser: "--validate_every_n_steps", type=int, default=None, - help="Run validation on validation dataset every N steps if a validation dataset is available / 検証データセットが利用可能な場合は、Nステップごとに検証データセットの検証を実行します" + help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます" ) parser.add_argument( "--validate_every_n_epochs", From 1231f5114ccd6a0a26a53da82b89083299ccc333 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 7 Jan 2025 22:31:41 -0500 Subject: [PATCH 62/76] Remove unused train_util code, fix accelerate.log for wandb, add init_trackers library code --- library/train_util.py | 70 ++++++++++++++++--------------------------- train_network.py | 66 ++++++++++++++++++++-------------------- 2 files changed, 59 insertions(+), 77 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 0907a8c03..b8894752e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5900,51 +5900,9 @@ def save_sd_model_on_train_end_common( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.IntTensor: +def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.Tensor: timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) - return timesteps - - - - -def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor) -> torch.FloatTensor: - """ - Apply noise modifications like noise offset and multires noise - """ - if args.noise_offset: - if args.noise_offset_random_strength: - noise_offset = torch.rand(1, device=latents.device) * args.noise_offset - else: - noise_offset = args.noise_offset - noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale) - if args.multires_noise_iterations: - noise = custom_train_functions.pyramid_noise_like( - noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount - ) - return noise - - -def make_noise(args, latents: torch.Tensor) -> torch.FloatTensor: - """ - Make a noise tensor to denoise and apply noise modifications (noise offset, multires noise). 
See `modify_noise` - """ - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - noise = modify_noise(args, noise, latents) - - return typing.cast(torch.FloatTensor, noise) - - -def make_random_timesteps(args, noise_scheduler: DDPMScheduler, batch_size: int, device: torch.device) -> torch.IntTensor: - """ - From args, produce random timesteps for each image in the batch - """ - min_timestep = 0 if args.min_timestep is None else args.min_timestep - max_timestep = noise_scheduler.config.get('num_train_timesteps', 1000) if args.max_timestep is None else args.max_timestep - - # Sample a random timestep for each image - timesteps = get_timesteps(min_timestep, max_timestep, batch_size, device) - + timesteps = timesteps.long().to(device) return timesteps @@ -6457,6 +6415,30 @@ def sample_image_inference( wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption +def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str): + """ + Initialize experiment trackers with tracker specific behaviors + """ + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + default_tracker_name if args.log_tracker_name is None else args.log_tracker_name, + config=get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + if "wandb" in [tracker.name for tracker in accelerator.trackers]: + import wandb + wandb_tracker = accelerator.get_tracker("wandb", unwrap=True) + + # Define specific metrics to handle validation and epochs "steps" + wandb_tracker.define_metric("epoch", hidden=True) + wandb_tracker.define_metric("val_step", hidden=True) + # endregion diff --git a/train_network.py b/train_network.py index d0596fcae..199f589b0 100644 --- a/train_network.py +++ b/train_network.py @@ -327,8 +327,8 @@ def process_batch( weight_dtype, accelerator, args, - text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, - tokenize_strategy: strategy_sd.SdTokenizeStrategy, + text_encoding_strategy: strategy_base.TextEncodingStrategy, + tokenize_strategy: strategy_base.TokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True @@ -1183,17 +1183,7 @@ def load_model_hook(models, input_dir): noise_scheduler = self.get_noise_scheduler(args, accelerator.device) - if accelerator.is_main_process: - init_kwargs = {} - if args.wandb_run_name: - init_kwargs["wandb"] = {"name": args.wandb_run_name} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers( - "network_train" if args.log_tracker_name is None else args.log_tracker_name, - config=train_util.get_sanitized_config_or_none(args), - init_kwargs=init_kwargs, - ) + train_util.init_trackers(accelerator, args, "network_train") loss_recorder = train_util.LossRecorder() val_step_loss_recorder = train_util.LossRecorder() @@ -1386,15 +1376,14 @@ def remove_model(old_ckpt_name): mean_norm, maximum_norm ) - # accelerator.log(logs, step=global_step) - accelerator.log(logs) + accelerator.log(logs, step=global_step) # VALIDATION PER STEP - should_validate_epoch = ( + should_validate_step = ( args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0 ) - if validation_steps > 0 and should_validate_epoch: + if validation_steps > 0 and 
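The define_metric calls in init_trackers are what keep the epoch- and validation-indexed series from distorting wandb's default global-step axis. A standalone sketch (the project name is a placeholder; wandb's define_metric also accepts a step_metric to plot a series against its own axis):

import wandb

run = wandb.init(project="my-validation-demo")  # placeholder project name
run.define_metric("val_step", hidden=True)      # hide the helper axis itself
run.define_metric("loss/validation/step_current", step_metric="val_step")

for val_step in range(3):
    run.log({"loss/validation/step_current": 0.1, "val_step": val_step})
run.finish()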
should_validate_step: accelerator.print("Validating バリデーション処理...") val_progress_bar = tqdm( @@ -1406,6 +1395,9 @@ def remove_model(old_ckpt_name): if val_step >= validation_steps: break + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + loss = self.process_batch( batch, text_encoders, @@ -1428,18 +1420,22 @@ def remove_model(old_ckpt_name): val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average }) if is_tracking: - logs = {"loss/step_validation_current": current_loss} - # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) - accelerator.log(logs) + logs = { + "loss/validation/step/current": current_loss, + "val_step": (epoch * validation_steps) + val_step, + } + accelerator.log(logs, step=global_step) - logs = {"loss/step_validation_average": val_step_loss_recorder.moving_average} - # accelerator.log(logs, step=global_step) - accelerator.log(logs) + if is_tracking: + logs = { + "loss/validation/step/average": val_step_loss_recorder.moving_average, + } + accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - # VALIDATION EPOCH + # EPOCH VALIDATION should_validate_epoch = ( (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None @@ -1458,6 +1454,9 @@ def remove_model(old_ckpt_name): if val_step >= validation_steps: break + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + loss = self.process_batch( batch, text_encoders, @@ -1480,21 +1479,22 @@ def remove_model(old_ckpt_name): val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average }) if is_tracking: - logs = {"loss/epoch_validation_current": current_loss} - # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) - accelerator.log(logs) + logs = { + "loss/validation/epoch_current": current_loss, + "epoch": epoch + 1, + "val_step": (epoch * validation_steps) + val_step + } + accelerator.log(logs, step=global_step) if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - logs = {"loss/epoch_validation_average": avr_loss} - # accelerator.log(logs, step=epoch + 1) - accelerator.log(logs) + logs = {"loss/validation/epoch_average": avr_loss, "epoch": epoch + 1} + accelerator.log(logs, step=global_step) # END OF EPOCH if is_tracking: - logs = {"loss/epoch_average": loss_recorder.moving_average} - # accelerator.log(logs, step=epoch + 1) - accelerator.log(logs) + logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} + accelerator.log(logs, step=global_step) accelerator.wait_for_everyone() From 556f3f1696eadcc16ee77425243b732a84c7a2aa Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 8 Jan 2025 13:41:15 -0500 Subject: [PATCH 63/76] Fix documentation, remove unused function, fix bucket reso for sd1.5, fix multiple datasets --- library/config_util.py | 6 +++--- library/train_util.py | 25 ++++--------------------- train_network.py | 5 +---- 3 files changed, 8 insertions(+), 28 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 63d28c969..de1e154a1 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -481,9 +481,9 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu subset_klass = FineTuningSubset dataset_klass = FineTuningDataset - subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in 
dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params)) - datasets.append(dataset) + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params)) + datasets.append(dataset) val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: diff --git a/library/train_util.py b/library/train_util.py index b8894752e..62aae37ef 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -152,11 +152,11 @@ def split_train_val(paths: List[str], is_training_dataset: bool, validation_spli Shuffle the dataset based on the validation_seed or the current random seed. For example if the split of 0.2 of 100 images. - [0:79] = 80 training images + [0:80] = 80 training images [80:] = 20 validation images """ if validation_seed is not None: - print(f"Using validation seed: {validation_seed}") + logging.info(f"Using validation seed: {validation_seed}") prevstate = random.getstate() random.seed(validation_seed) random.shuffle(paths) @@ -5900,8 +5900,8 @@ def save_sd_model_on_train_end_common( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.Tensor: - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) +def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor: + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") timesteps = timesteps.long().to(device) return timesteps @@ -5964,23 +5964,6 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler return result -def get_noisy_latents(args, noise: torch.FloatTensor, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor: - """ - Add noise to the latents according to the noise magnitude at each timestep - (this is the forward diffusion process) - """ - if args.ip_noise_gamma: - if args.ip_noise_gamma_random_strength: - strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma - else: - strength = args.ip_noise_gamma - noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps) - else: - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - return noisy_latents - - def conditional_loss( model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None ): diff --git a/train_network.py b/train_network.py index 199f589b0..7dbd12e88 100644 --- a/train_network.py +++ b/train_network.py @@ -125,10 +125,7 @@ def generate_step_logs( return logs def assert_extra_args(self, args, train_dataset_group): - # train_dataset_group.verify_bucket_reso_steps(64) - # TODO: Number of bucket reso steps may differ for each model, so a static number won't work - # and prevents models like SD1.5 with 64 - pass + train_dataset_group.verify_bucket_reso_steps(32) def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) From 9fde0d797282c0cb9fcea01682e2e6e2eece47bc Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 8 Jan 
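The seeded shuffle in split_train_val is what makes the train/val partition reproducible without disturbing the training RNG stream. A condensed standalone version of the logic (the slicing follows the docstring's 80/20 example; the repository's exact rounding may differ):

import math
import random

def split_train_val(paths, is_training_dataset, validation_split, validation_seed=None):
    paths = list(paths)
    if validation_seed is not None:
        state = random.getstate()   # save the training RNG stream
        random.seed(validation_seed)
        random.shuffle(paths)
        random.setstate(state)      # restore it afterwards
    else:
        random.shuffle(paths)
    n_train = math.ceil(len(paths) * (1 - validation_split))
    return paths[:n_train] if is_training_dataset else paths[n_train:]

paths = [f"img_{i:03}.png" for i in range(100)]
train = split_train_val(paths, True, 0.2, validation_seed=42)
val = split_train_val(paths, False, 0.2, validation_seed=42)
assert len(train) == 80 and len(val) == 20 and not set(train) & set(val)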
2025 18:38:20 -0500 Subject: [PATCH 64/76] Handle tuple return from generate_dataset_group_by_blueprint --- fine_tune.py | 4 ++-- flux_train.py | 3 ++- flux_train_control_net.py | 4 ++-- library/config_util.py | 2 +- sd3_train.py | 3 ++- sdxl_train.py | 3 ++- sdxl_train_control_net.py | 2 +- sdxl_train_control_net_lllite.py | 2 +- sdxl_train_control_net_lllite_old.py | 2 +- tools/cache_latents.py | 3 ++- tools/cache_text_encoder_outputs.py | 3 ++- train_control_net.py | 2 +- train_db.py | 3 ++- train_textual_inversion.py | 3 ++- train_textual_inversion_XTI.py | 2 +- 15 files changed, 24 insertions(+), 17 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 176087065..6be2f98ca 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -91,9 +91,9 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args) + train_dataset_group, val_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/flux_train.py b/flux_train.py index fced3bef9..6f98adea8 100644 --- a/flux_train.py +++ b/flux_train.py @@ -138,9 +138,10 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 9d36a41d3..54dec2a77 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -126,9 +126,9 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: - train_dataset_group = train_util.load_arbitrary_dataset(args) + train_dataset_group, val_dataset_group = train_util.load_arbitrary_dataset(args) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/library/config_util.py b/library/config_util.py index de1e154a1..834d6bfaf 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -467,7 +467,7 @@ def search_value(key: str, fallbacks: Sequence[dict], default_value=None): return default_value -def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): +def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint) -> Tuple[DatasetGroup, Optional[DatasetGroup]]: datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: diff --git a/sd3_train.py b/sd3_train.py index 120455e7b..3bff6a50f 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -149,9 +149,10 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = 
config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/sdxl_train.py b/sdxl_train.py index b9d529243..a60f6df63 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -176,9 +176,10 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index ffbf03cab..c6e8136f7 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -114,7 +114,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 365059b75..00e51a673 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -123,7 +123,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 5b372befc..63457cc61 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -103,7 +103,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/tools/cache_latents.py b/tools/cache_latents.py index c034f949a..515ece98d 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -116,10 +116,11 @@ def cache_to_disk(args: argparse.Namespace) -> None: } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None # acceleratorを準備する logger.info("prepare accelerator") diff --git a/tools/cache_text_encoder_outputs.py b/tools/cache_text_encoder_outputs.py index 5888b8e3d..00459658e 100644 --- a/tools/cache_text_encoder_outputs.py +++ b/tools/cache_text_encoder_outputs.py @@ -103,10 +103,11 @@ def cache_to_disk(args: argparse.Namespace) -> None: } blueprint = blueprint_generator.generate(user_config, args) - 
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None # acceleratorを準備する logger.info("prepare accelerator") diff --git a/train_control_net.py b/train_control_net.py index 177d2b11f..ba016ac5d 100644 --- a/train_control_net.py +++ b/train_control_net.py @@ -100,7 +100,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/train_db.py b/train_db.py index ad21f8d1b..edd674034 100644 --- a/train_db.py +++ b/train_db.py @@ -89,9 +89,10 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 65da4859b..113f35997 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -320,9 +320,10 @@ def train(self, args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) + val_dataset_group = None self.assert_extra_args(args, train_dataset_group) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 2a2b42310..6ff97d03f 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -239,7 +239,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings) current_epoch = Value("i", 0) current_step = Value("i", 0) From 1e61392cf2f601e1c66aaede6846ef70f599c34f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 8 Jan 2025 18:43:26 -0500 Subject: [PATCH 65/76] Revert bucket_reso_steps to correct 64 --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 7dbd12e88..7e9f12659 100644 --- a/train_network.py +++ b/train_network.py @@ -125,7 +125,7 @@ def generate_step_logs( return logs def assert_extra_args(self, args, train_dataset_group): - train_dataset_group.verify_bucket_reso_steps(32) + train_dataset_group.verify_bucket_reso_steps(64) def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) From 
d6f158ddf6a3631df7db10ac97453b12de8eadbe Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Wed, 8 Jan 2025 18:48:05 -0500
Subject: [PATCH 66/76] Fix incorrect destructuring for load_arbitrary_dataset

---
 fine_tune.py              | 3 ++-
 flux_train_control_net.py | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index 6be2f98ca..e1ed47496 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -93,7 +93,8 @@ def train(args):
         blueprint = blueprint_generator.generate(user_config, args)
         train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
     else:
-        train_dataset_group, val_dataset_group = train_util.load_arbitrary_dataset(args)
+        train_dataset_group = train_util.load_arbitrary_dataset(args)
+        val_dataset_group = None

     current_epoch = Value("i", 0)
     current_step = Value("i", 0)
diff --git a/flux_train_control_net.py b/flux_train_control_net.py
index 54dec2a77..cecd00019 100644
--- a/flux_train_control_net.py
+++ b/flux_train_control_net.py
@@ -128,7 +128,8 @@ def train(args):
         blueprint = blueprint_generator.generate(user_config, args)
         train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
     else:
-        train_dataset_group, val_dataset_group = train_util.load_arbitrary_dataset(args)
+        train_dataset_group = train_util.load_arbitrary_dataset(args)
+        val_dataset_group = None

     current_epoch = Value("i", 0)
     current_step = Value("i", 0)

From 264167fa1636c79f106c63c3cdb67b6bee80aceb Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Thu, 9 Jan 2025 12:43:58 -0500
Subject: [PATCH 67/76] Apply is_training_dataset only to DreamBoothDataset. Add validation_split check and warning

---
 library/config_util.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/library/config_util.py b/library/config_util.py
index 834d6bfaf..a2e07dc6c 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -471,36 +471,49 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
   datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []

   for dataset_blueprint in dataset_group_blueprint.datasets:
+    extra_dataset_params = {}
+
     if dataset_blueprint.is_controlnet:
       subset_klass = ControlNetSubset
       dataset_klass = ControlNetDataset
     elif dataset_blueprint.is_dreambooth:
       subset_klass = DreamBoothSubset
       dataset_klass = DreamBoothDataset
+      # DreamBooth datasets support splitting training and validation datasets
+      extra_dataset_params = {"is_training_dataset": True}
     else:
       subset_klass = FineTuningSubset
       dataset_klass = FineTuningDataset

     subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
-    dataset = dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params))
+    dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params)
     datasets.append(dataset)

   val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
   for dataset_blueprint in dataset_group_blueprint.datasets:
-    if dataset_blueprint.params.validation_split <= 0.0:
+    if dataset_blueprint.params.validation_split < 0.0 or dataset_blueprint.params.validation_split > 1.0:
+      logging.warning(f"Dataset param `validation_split` ({dataset_blueprint.params.validation_split}) is not a valid number between 0.0 and 1.0, skipping validation split...")
+      continue
+
+    # if the dataset isn't setting a validation split, there
is no current validation dataset
+    if dataset_blueprint.params.validation_split == 0.0:
       continue
+
+    extra_dataset_params = {}

     if dataset_blueprint.is_controlnet:
       subset_klass = ControlNetSubset
       dataset_klass = ControlNetDataset
     elif dataset_blueprint.is_dreambooth:
       subset_klass = DreamBoothSubset
       dataset_klass = DreamBoothDataset
+      # DreamBooth datasets support splitting training and validation datasets
+      extra_dataset_params = {"is_training_dataset": False}
     else:
       subset_klass = FineTuningSubset
       dataset_klass = FineTuningDataset

     subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
-    dataset = dataset_klass(subsets=subsets, is_training_dataset=False, **asdict(dataset_blueprint.params))
+    dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params)
     val_datasets.append(dataset)

   def print_info(_datasets, dataset_type: str):

From 4c61adc9965df6861ae3705c96143f4299074744 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Sun, 12 Jan 2025 13:18:26 -0500
Subject: [PATCH 68/76] Add divergence to logs

Divergence is the difference between the training and validation loss;
logging it as a single value makes the gap between the two easy to read
in the logs.
---
 train_network.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/train_network.py b/train_network.py
index 7e9f12659..5ed92b7e2 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1418,14 +1418,16 @@ def remove_model(old_ckpt_name):

                     if is_tracking:
                         logs = {
-                            "loss/validation/step/current": current_loss,
+                            "loss/validation/step_current": current_loss,
                             "val_step": (epoch * validation_steps) + val_step,
                         }
                         accelerator.log(logs, step=global_step)

                 if is_tracking:
+                    loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
                     logs = {
-                        "loss/validation/step/average": val_step_loss_recorder.moving_average,
+                        "loss/validation/step_average": val_step_loss_recorder.moving_average,
+                        "loss/validation/step_divergence": loss_validation_divergence,
                     }
                     accelerator.log(logs, step=global_step)

@@ -1485,7 +1487,12 @@ def remove_model(old_ckpt_name):

                 if is_tracking:
                     avr_loss: float = val_epoch_loss_recorder.moving_average
-                    logs = {"loss/validation/epoch_average": avr_loss, "epoch": epoch + 1}
+                    loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss
+                    logs = {
+                        "loss/validation/epoch_average": avr_loss,
+                        "loss/validation/epoch_divergence": loss_validation_divergence,
+                        "epoch": epoch + 1
+                    }
                     accelerator.log(logs, step=global_step)

         # END OF EPOCH

From 2bbb40ce51d5be3ce8c3e1990d30455201f9e852 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Sun, 12 Jan 2025 14:29:50 -0500
Subject: [PATCH 69/76] Fix regularization images with validation

Add metadata recording for validation arguments
Add comments about the validation split for clarity of intention
---
 library/train_util.py | 33 +++++++++++++++++++++++++++++++--
 train_network.py      |  7 +++++++
 2 files changed, 38 insertions(+), 2 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 62aae37ef..6d3a772bb 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -146,7 +146,12 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
 TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"

-def split_train_val(paths: List[str], is_training_dataset: bool, validation_split: float, validation_seed: int) -> List[str]:
+def split_train_val(
+    paths: List[str],
+    is_training_dataset: bool,
+    validation_split: float,
+
validation_seed: int | None +) -> List[str]: """ Split the dataset into train and validation @@ -1830,6 +1835,9 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index): class DreamBoothDataset(BaseDataset): IMAGE_INFO_CACHE_FILE = "metadata_cache.json" + # The is_training_dataset defines the type of dataset, training or validation + # if is_training_dataset is True -> training dataset + # if is_training_dataset is False -> validation dataset def __init__( self, subsets: Sequence[DreamBoothSubset], @@ -1965,8 +1973,29 @@ def load_dreambooth_dir(subset: DreamBoothSubset): size_set_count += 1 logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") + # We want to create a training and validation split. This should be improved in the future + # to allow a clearer distinction between training and validation. This can be seen as a + # short-term solution to limit what is necessary to implement validation datasets + # + # We split the dataset for the subset based on if we are doing a validation split + # The self.is_training_dataset defines the type of dataset, training or validation + # if self.is_training_dataset is True -> training dataset + # if self.is_training_dataset is False -> validation dataset if self.validation_split > 0.0: - img_paths = split_train_val(img_paths, self.is_training_dataset, self.validation_split, self.validation_seed) + # For regularization images we do not want to split this dataset. + if subset.is_reg is True: + # Skip any validation dataset for regularization images + if self.is_training_dataset is False: + img_paths = [] + # Otherwise the img_paths remain as original img_paths and no split + # required for training images dataset of regularization images + else: + img_paths = split_train_val( + img_paths, + self.is_training_dataset, + self.validation_split, + self.validation_seed + ) logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") diff --git a/train_network.py b/train_network.py index 5ed92b7e2..605dbc60c 100644 --- a/train_network.py +++ b/train_network.py @@ -898,6 +898,7 @@ def load_model_hook(models, input_dir): accelerator.print("running training / 学習開始") accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}") accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") @@ -917,6 +918,7 @@ def load_model_hook(models, input_dir): "ss_text_encoder_lr": text_encoder_lr, "ss_unet_lr": args.unet_lr, "ss_num_train_images": train_dataset_group.num_train_images, + "ss_num_validation_images": val_dataset_group.num_train_images if val_dataset_group is not None else 0, "ss_num_reg_images": train_dataset_group.num_reg_images, "ss_num_batches_per_epoch": len(train_dataloader), "ss_num_epochs": num_train_epochs, @@ -964,6 +966,11 @@ def load_model_hook(models, input_dir): "ss_huber_c": args.huber_c, "ss_fp8_base": bool(args.fp8_base), "ss_fp8_base_unet": bool(args.fp8_base_unet), + "ss_validation_seed": args.validation_seed, + "ss_validation_split": args.validation_split, + "ss_max_validation_steps": args.max_validation_steps, + "ss_validate_every_n_epochs": args.validate_every_n_epochs, + "ss_validate_every_n_steps": 
args.validate_every_n_steps, } self.update_metadata(metadata, args) # architecture specific metadata From 0456858992909ca0b821ec1b2ca40fa633113224 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 12 Jan 2025 14:47:49 -0500 Subject: [PATCH 70/76] Fix validate_every_n_steps always running first step --- train_network.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train_network.py b/train_network.py index 605dbc60c..75e36dca9 100644 --- a/train_network.py +++ b/train_network.py @@ -1385,6 +1385,7 @@ def remove_model(old_ckpt_name): # VALIDATION PER STEP should_validate_step = ( args.validate_every_n_steps is not None + and global_step != 0 # Skip first step and global_step % args.validate_every_n_steps == 0 ) if validation_steps > 0 and should_validate_step: From ee9265cf2678df5c9dfa6c1148d20fb738a9e6ce Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 12 Jan 2025 14:56:35 -0500 Subject: [PATCH 71/76] Fix validate_every_n_steps for gradient accumulation --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 75e36dca9..2f3203c94 100644 --- a/train_network.py +++ b/train_network.py @@ -1388,7 +1388,7 @@ def remove_model(old_ckpt_name): and global_step != 0 # Skip first step and global_step % args.validate_every_n_steps == 0 ) - if validation_steps > 0 and should_validate_step: + if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: accelerator.print("Validating バリデーション処理...") val_progress_bar = tqdm( From 25929dd0d733144859008479c374968102e5d3a3 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 12 Jan 2025 15:38:57 -0500 Subject: [PATCH 72/76] Remove Validating... print to fix output layout --- train_network.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/train_network.py b/train_network.py index 2f3203c94..e7d93a108 100644 --- a/train_network.py +++ b/train_network.py @@ -1389,8 +1389,6 @@ def remove_model(old_ckpt_name): and global_step % args.validate_every_n_steps == 0 ) if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: - accelerator.print("Validating バリデーション処理...") - val_progress_bar = tqdm( range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, @@ -1450,7 +1448,6 @@ def remove_model(old_ckpt_name): ) if should_validate_epoch and len(val_dataloader) > 0: - accelerator.print("Validating バリデーション処理...") val_progress_bar = tqdm( range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, From b489082495ba6779385f282797227799413715f5 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 12 Jan 2025 16:42:04 -0500 Subject: [PATCH 73/76] Disable repeats for validation datasets --- library/train_util.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 6d3a772bb..4d143c373 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2055,9 +2055,10 @@ def load_dreambooth_dir(subset: DreamBoothSubset): num_reg_images = 0 reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = [] for subset in subsets: - if subset.num_repeats < 1: + num_repeats = subset.num_repeats if self.is_training_dataset else 1 + if num_repeats < 1: logger.warning( - f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" + f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {num_repeats}" ) 
continue @@ -2075,12 +2076,12 @@ def load_dreambooth_dir(subset: DreamBoothSubset): continue if subset.is_reg: - num_reg_images += subset.num_repeats * len(img_paths) + num_reg_images += num_repeats * len(img_paths) else: - num_train_images += subset.num_repeats * len(img_paths) + num_train_images += num_repeats * len(img_paths) for img_path, caption, size in zip(img_paths, captions, sizes): - info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) + info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path) if size is not None: info.image_size = size if subset.is_reg: From c04e5dfe92250a4790dc5f6e092cd85809a4e81d Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 23 Jan 2025 09:57:24 -0500 Subject: [PATCH 74/76] Fix loss recorder on 0. Fix validation for cached runs. Assert on validation dataset --- flux_train_network.py | 8 +++++--- library/train_util.py | 8 +++++++- requirements.txt | 1 + sd3_train_network.py | 11 ++++++++--- sdxl_train_network.py | 8 +++++--- sdxl_train_textual_inversion.py | 5 +++-- train_network.py | 16 +++++++++++----- train_textual_inversion.py | 9 ++++++--- 8 files changed, 46 insertions(+), 20 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index b3aebecc7..5cd1b9d51 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -2,7 +2,7 @@ import copy import math import random -from typing import Any, Optional +from typing import Any, Optional, Union import torch from accelerate import Accelerator @@ -36,8 +36,8 @@ def __init__(self): self.is_schnell: Optional[bool] = None self.is_swapping_blocks: bool = False - def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args, train_dataset_group) + def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + super().assert_extra_args(args, train_dataset_group, val_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) if args.fp8_base_unet: @@ -80,6 +80,8 @@ def assert_extra_args(self, args, train_dataset_group): args.blocks_to_swap = 18 # 18 is safe for most cases train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(32) # TODO check this def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models diff --git a/library/train_util.py b/library/train_util.py index 4d143c373..56fea4a8c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2893,6 +2893,9 @@ def __getitem__(self, idx): """ raise NotImplementedError + def get_resolutions(self) -> List[Tuple[int, int]]: + return [] + def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset: module = ".".join(args.dataset_class.split(".")[:-1]) @@ -6520,4 +6523,7 @@ def add(self, *, epoch: int, step: int, loss: float) -> None: @property def moving_average(self) -> float: - return self.loss_total / len(self.loss_list) + losses = len(self.loss_list) + if losses == 0: + return 0 + return self.loss_total / losses diff --git a/requirements.txt b/requirements.txt index e0091749a..de39f5887 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,6 +20,7 @@ voluptuous==0.13.1 huggingface-hub==0.24.5 # for Image utils imagesize==1.4.1 +numpy<=2.0 # for BLIP captioning # requests==2.28.2 # timm==0.6.12 diff --git a/sd3_train_network.py b/sd3_train_network.py index c7417802d..dcf497f53 100644 --- 
a/sd3_train_network.py +++ b/sd3_train_network.py @@ -2,7 +2,7 @@ import copy import math import random -from typing import Any, Optional +from typing import Any, Optional, Union import torch from accelerate import Accelerator @@ -26,7 +26,7 @@ def __init__(self): super().__init__() self.sample_prompts_te_outputs = None - def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup): + def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): # super().assert_extra_args(args, train_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) @@ -56,9 +56,14 @@ def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup): ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(32) # TODO check this # enumerate resolutions from dataset for positional embeddings - self.resolutions = train_dataset_group.get_resolutions() + resolutions = train_dataset_group.get_resolutions() + if val_dataset_group is not None: + resolutions = resolutions + val_dataset_group.get_resolutions() + self.resolutions = resolutions def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models diff --git a/sdxl_train_network.py b/sdxl_train_network.py index d45df6e05..eb09831ec 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,5 +1,5 @@ import argparse -from typing import List, Optional +from typing import List, Optional, Union import torch from accelerate import Accelerator @@ -23,8 +23,8 @@ def __init__(self): self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR self.is_sdxl = True - def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args, train_dataset_group) + def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + super().assert_extra_args(args, train_dataset_group, val_dataset_group) sdxl_train_util.verify_sdxl_training_args(args) if args.cache_text_encoder_outputs: @@ -37,6 +37,8 @@ def assert_extra_args(self, args, train_dataset_group): ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" train_dataset_group.verify_bucket_reso_steps(32) + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(32) def load_target_model(self, args, weight_dtype, accelerator): ( diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index 821a69558..bf56faf34 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -18,11 +18,12 @@ def __init__(self): self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR self.is_sdxl = True - def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args, train_dataset_group) + def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + super().assert_extra_args(args, train_dataset_group, val_dataset_group) sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False) 
train_dataset_group.verify_bucket_reso_steps(32) + val_dataset_group.verify_bucket_reso_steps(32) def load_target_model(self, args, weight_dtype, accelerator): ( diff --git a/train_network.py b/train_network.py index e7d93a108..2c3bb2aae 100644 --- a/train_network.py +++ b/train_network.py @@ -3,7 +3,7 @@ import math import os import typing -from typing import Any, List +from typing import Any, List, Union, Optional import sys import random import time @@ -124,8 +124,10 @@ def generate_step_logs( return logs - def assert_extra_args(self, args, train_dataset_group): + def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): train_dataset_group.verify_bucket_reso_steps(64) + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(64) def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) @@ -512,7 +514,7 @@ def train(self, args): val_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - self.assert_extra_args(args, train_dataset_group) # may change some args + self.assert_extra_args(args, train_dataset_group, val_dataset_group) # may change some args # acceleratorを準備する logger.info("preparing accelerator") @@ -1414,7 +1416,9 @@ def remove_model(old_ckpt_name): args, text_encoding_strategy, tokenize_strategy, - is_train=False + is_train=False, + train_text_encoder=False, + train_unet=False ) current_loss = loss.detach().item() @@ -1474,7 +1478,9 @@ def remove_model(old_ckpt_name): args, text_encoding_strategy, tokenize_strategy, - is_train=False + is_train=False, + train_text_encoder=False, + train_unet=False ) current_loss = loss.detach().item() diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 113f35997..0c6568b08 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -2,7 +2,7 @@ import math import os from multiprocessing import Value -from typing import Any, List +from typing import Any, List, Optional, Union import toml from tqdm import tqdm @@ -99,9 +99,12 @@ def __init__(self): self.vae_scale_factor = 0.18215 self.is_sdxl = False - def assert_extra_args(self, args, train_dataset_group): + def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): train_dataset_group.verify_bucket_reso_steps(64) + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(64) + def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), [text_encoder], vae, unet @@ -325,7 +328,7 @@ def train(self, args): train_dataset_group = train_util.load_arbitrary_dataset(args) val_dataset_group = None - self.assert_extra_args(args, train_dataset_group) + self.assert_extra_args(args, train_dataset_group, val_dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) From 58b82a576e32c2157e476840339ddafa98222dfc Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 26 Jan 2025 21:21:21 +0900 Subject: [PATCH 75/76] Fix to work with validation dataset --- library/train_util.py | 1 + sdxl_train_textual_inversion.py | 4 +++- 2 
files changed, 4 insertions(+), 1 deletion(-)

diff --git a/library/train_util.py b/library/train_util.py
index 56fea4a8c..37ed0a994 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2403,6 +2403,7 @@ def __init__(

         self.dreambooth_dataset_delegate = DreamBoothDataset(
             db_subsets,
+            True,
             batch_size,
             resolution,
             network_multiplier,
diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py
index bf56faf34..982007601 100644
--- a/sdxl_train_textual_inversion.py
+++ b/sdxl_train_textual_inversion.py
@@ -1,5 +1,6 @@
 import argparse
 import os
+from typing import Optional, Union

 import regex

@@ -23,7 +24,8 @@ def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetG
         sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)

         train_dataset_group.verify_bucket_reso_steps(32)
-        val_dataset_group.verify_bucket_reso_steps(32)
+        if val_dataset_group is not None:
+            val_dataset_group.verify_bucket_reso_steps(32)

     def load_target_model(self, args, weight_dtype, accelerator):
         (

From e8529613d8a06ce91d3b304bccf85a172b1b4b31 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sun, 26 Jan 2025 21:27:22 +0900
Subject: [PATCH 76/76] README.md: Update recent updates section to include
 validation loss support for training scripts

---
 README.md | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/README.md b/README.md
index 4dff15440..053354103 100644
--- a/README.md
+++ b/README.md
@@ -14,6 +14,13 @@ The command to install PyTorch is as follows:

 ### Recent Updates

+Jan 25, 2025:
+
+- `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py` now support validation loss. PR [#1864](https://github.com/kohya-ss/sd-scripts/pull/1864) Thanks to rockerBOO!
+  - For details on how to set it up, please refer to the PR. The documentation will be updated as needed.
+  - It will be added to other scripts as well.
+  - As a current limitation, validation loss is not supported when `--blocks_to_swap` is specified.
+
 Dec 15, 2024:

 - RAdamScheduleFree optimizer is supported. PR [#1830](https://github.com/kohya-ss/sd-scripts/pull/1830) Thanks to nhamanasu!
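
To make the train/validation split from PATCH 69 concrete, here is a minimal, self-contained sketch of what a deterministic split can look like. The signature follows the split_train_val diff in library/train_util.py above; the body and the name split_train_val_sketch are illustrative assumptions, not the PR's exact implementation.

import random
from typing import List, Optional


def split_train_val_sketch(
    paths: List[str],
    is_training_dataset: bool,
    validation_split: float,
    validation_seed: Optional[int],
) -> List[str]:
    # Shuffle a copy with a dedicated RNG so the split is reproducible
    # when a validation_seed is given and the global RNG is untouched.
    rng = random.Random(validation_seed)
    shuffled = list(paths)
    rng.shuffle(shuffled)
    split_index = int(len(shuffled) * (1.0 - validation_split))
    # The training dataset takes the first portion, validation takes the
    # remainder, so the two subsets never overlap for the same seed.
    return shuffled[:split_index] if is_training_dataset else shuffled[split_index:]


# Example: a 0.2 split of 10 paths yields 8 training and 2 validation images.
paths = [f"img_{i}.png" for i in range(10)]
train_paths = split_train_val_sketch(paths, True, 0.2, validation_seed=42)
val_paths = split_train_val_sketch(paths, False, 0.2, validation_seed=42)
assert len(train_paths) == 8 and len(val_paths) == 2
assert not set(train_paths) & set(val_paths)

Note that this also shows why PATCH 69 skips the split for regularization image subsets: calling the helper on them would silently drop part of the regularization set from training.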
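
The divergence metric from PATCH 68 and the empty-recorder guard from PATCH 74 can be illustrated together. LossRecorderSketch below is a simplified, hypothetical stand-in for train_util.LossRecorder (the real class also tracks per-step losses within an epoch); the divergence formula mirrors the logging code in the diffs.

from typing import List


class LossRecorderSketch:
    """Simplified stand-in for train_util.LossRecorder."""

    def __init__(self) -> None:
        self.loss_list: List[float] = []
        self.loss_total: float = 0.0

    def add(self, *, epoch: int, step: int, loss: float) -> None:
        self.loss_list.append(loss)
        self.loss_total += loss

    @property
    def moving_average(self) -> float:
        # PATCH 74's guard: an empty recorder (e.g. validation that has
        # not run yet) must not divide by zero.
        if len(self.loss_list) == 0:
            return 0.0
        return self.loss_total / len(self.loss_list)


train_recorder = LossRecorderSketch()
val_recorder = LossRecorderSketch()
train_recorder.add(epoch=1, step=1, loss=0.12)
val_recorder.add(epoch=1, step=1, loss=0.15)

# PATCH 68's metric: a positive divergence means validation loss is running
# above training loss, an early hint of overfitting.
divergence = val_recorder.moving_average - train_recorder.moving_average
print(f"loss/validation/step_divergence: {divergence:.4f}")  # 0.0300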
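
PATCHes 70 and 71 together define when per-step validation may fire: never at step 0, only on the configured interval, and only on a true optimizer step when gradient accumulation is active. The small predicate below, with hypothetical argument names lifted from the diffs, is one way to capture the combined gate; the scripts above express it inline rather than as a function.

from typing import Optional


def should_run_step_validation(
    validate_every_n_steps: Optional[int],
    global_step: int,
    validation_steps: int,
    sync_gradients: bool,
) -> bool:
    # PATCH 70: never validate at global_step 0, otherwise the very first
    # training step always triggered a validation pass.
    if validate_every_n_steps is None or global_step == 0:
        return False
    if global_step % validate_every_n_steps != 0:
        return False
    # PATCH 71: with gradient accumulation, global_step only advances on
    # optimizer steps, so the gate also checks accelerator.sync_gradients.
    return sync_gradients and validation_steps > 0


assert not should_run_step_validation(100, 0, 10, True)     # skip first step
assert should_run_step_validation(100, 200, 10, True)       # every 100 steps
assert not should_run_step_validation(100, 200, 10, False)  # mid-accumulation
assert not should_run_step_validation(None, 200, 10, True)  # feature disabled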