From febb8655718ca22b1f6b3a29a60c32c07ce672ae Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 11 Dec 2023 22:41:23 -0600 Subject: [PATCH 01/22] WIP: Multi-dataset sampler --- helpers/data_backend/factory.py | 128 ++++++++++++++++++++++++++++ helpers/multiaspect/dataset.py | 74 ++++------------ helpers/multiaspect/multisampler.py | 47 ++++++++++ helpers/multiaspect/sampler.py | 26 ++++-- helpers/training/state_tracker.py | 16 +++- 5 files changed, 224 insertions(+), 67 deletions(-) create mode 100644 helpers/data_backend/factory.py create mode 100644 helpers/multiaspect/multisampler.py diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py new file mode 100644 index 00000000..9adf8902 --- /dev/null +++ b/helpers/data_backend/factory.py @@ -0,0 +1,128 @@ +from helpers.data_backend.local import LocalDataBackend +from helpers.data_backend.aws import S3DataBackend + +import json, os + + +def configure_multi_databackend(args: dict, accelerator): + """ + Configure a multiple dataloaders based on the provided commandline args. + """ + if args.data_backend_config is None: + raise ValueError( + "Must provide a data backend config file via --data_backend_config" + ) + if not os.path.exists(args.data_backend_config): + raise FileNotFoundError( + f"Data backend config file {args.data_backend_config} not found." + ) + with open(args.data_backend_config, "r") as f: + data_backend_config = json.load(f) + if len(data_backend_config) == 0: + raise ValueError( + "Must provide at least one data backend in the data backend config file." + ) + data_backends = [] + for backend in data_backend_config: + if backend["type"] == "local": + data_backends.append(get_local_backend(accelerator)) + elif backend["type"] == "aws": + check_aws_config(backend) + data_backends.append( + get_aws_backend( + aws_bucket_name=backend["aws_bucket_name"], + aws_region_name=backend["aws_region_name"], + aws_endpoint_url=backend["aws_endpoint_url"], + aws_access_key_id=backend["aws_access_key_id"], + aws_secret_access_key=backend["aws_secret_access_key"], + accelerator=accelerator, + ) + ) + else: + raise ValueError(f"Unknown data backend type: {backend['type']}") + if len(data_backends) == 0: + raise ValueError( + "Must provide at least one data backend in the data backend config file." + ) + return data_backends + + +def get_local_backend(accelerator) -> LocalDataBackend: + """ + Get a local disk backend. + + Args: + accelerator (Accelerator): A Huggingface Accelerate object. + Returns: + LocalDataBackend: A LocalDataBackend object. + """ + return LocalDataBackend(accelerator=accelerator) + + +def check_aws_config(backend: dict) -> None: + """ + Check the configuration for an AWS backend. + + Args: + backend (dict): A dictionary of the backend configuration. 
+ Returns: + None + """ + required_keys = [ + "aws_bucket_name", + "aws_region_name", + "aws_endpoint_url", + "aws_access_key_id", + "aws_secret_access_key", + ] + for key in required_keys: + if key not in backend: + raise ValueError(f"Missing required key {key} in AWS backend config.") + + +def get_aws_backend( + aws_bucket_name: str, + aws_region_name: str, + aws_endpoint_url: str, + aws_access_key_id: str, + aws_secret_access_key: str, + accelerator, +) -> S3DataBackend: + return S3DataBackend( + bucket_name=aws_bucket_name, + accelerator=accelerator, + region_name=aws_region_name, + endpoint_url=aws_endpoint_url, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + ) + + +def get_dataset(args: dict, accelerator) -> list: + """Retrieve a dataset based on the provided commandline args. + + Args: + args (dict): A dictionary from parseargs. + accelerator (Accelerator): A Huggingface Accelerate object. + Returns: + list: A list of DataBackend objects. + """ + if args.data_backend == "multi": + return configure_multi_databackend(args) + elif args.data_backend == "local": + if not os.path.exists(args.instance_data_dir): + raise FileNotFoundError( + f"Instance {args.instance_data_root} images root doesn't exist. Cannot continue." + ) + return [get_local_backend(args, accelerator)] + elif args.data_backend == "aws": + return [ + get_aws_backend( + aws_bucket_name=args.aws_bucket_name, + aws_region_name=args.aws_region_name, + aws_endpoint_url=args.aws_endpoint_url, + aws_access_key_id=args.aws_access_key_id, + aws_secret_access_key=args.aws_secret_access_key, + accelerator=accelerator, + ) + ] diff --git a/helpers/multiaspect/dataset.py b/helpers/multiaspect/dataset.py index c5a5eeb2..92050019 100644 --- a/helpers/multiaspect/dataset.py +++ b/helpers/multiaspect/dataset.py @@ -2,11 +2,6 @@ from pathlib import Path import logging, os -from helpers.multiaspect.image import MultiaspectImage -from helpers.data_backend.base import BaseDataBackend -from helpers.multiaspect.bucket import BucketManager -from helpers.prompts import PromptHandler - logger = logging.getLogger("MultiAspectDataset") logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "WARNING")) @@ -20,78 +15,41 @@ class MultiAspectDataset(Dataset): def __init__( self, - instance_data_root, - accelerator, - bucket_manager: BucketManager, - data_backend: BaseDataBackend, - instance_prompt: str = None, - tokenizer=None, - aspect_ratio_buckets=[1.0, 1.5, 0.67, 0.75, 1.78], - size=1024, + datasets: list, print_names=False, - use_captions=True, - prepend_instance_prompt=False, - use_original_images=False, - caption_dropout_interval: int = 0, - use_precomputed_token_ids: bool = True, - debug_dataset_loader: bool = False, - caption_strategy: str = "filename", - return_tensor: bool = False, - size_type: str = "pixel", ): - self.prepend_instance_prompt = prepend_instance_prompt - self.bucket_manager = bucket_manager - self.data_backend = data_backend - self.use_captions = use_captions - self.size = size - self.size_type = size_type - self.tokenizer = tokenizer + self.datasets = datasets self.print_names = print_names - self.debug_dataset_loader = debug_dataset_loader - self.instance_data_root = Path(instance_data_root) - self.instance_prompt = instance_prompt - self.aspect_ratio_buckets = aspect_ratio_buckets - self.use_original_images = use_original_images - self.accelerator = accelerator - self.caption_dropout_interval = caption_dropout_interval - self.caption_loop_count = 0 - self.caption_strategy = 
caption_strategy - self.use_precomputed_token_ids = use_precomputed_token_ids - logger.debug(f"Building transformations.") - self.image_transforms = MultiaspectImage.get_image_transforms() - self.return_tensor = return_tensor def __len__(self): - return len(self.bucket_manager) + # Sum the length of all data backends: + return sum([len(dataset) for dataset in self.datasets]) def __getitem__(self, image_tuple): output_data = [] for sample in image_tuple: - image_path = sample["image_path"] - logger.debug(f"Running __getitem__ for {image_path} inside Dataloader.") - image_metadata = self.bucket_manager.get_metadata_by_filepath(image_path) - image_metadata["image_path"] = image_path + logger.debug( + f"Running __getitem__ for {image_metadata['image_path']} inside Dataloader." + ) + image_metadata = sample if ( image_metadata["original_size"] is None or image_metadata["target_size"] is None ): raise Exception( - f"Metadata was unavailable for image: {image_path}. Ensure --skip_file_discovery=metadata is not set." - f" Metadata: {self.bucket_manager.get_metadata_by_filepath(image_path)}" + f"Metadata was unavailable for image: {image_metadata['image_path']}. Ensure --skip_file_discovery=metadata is not set." ) if self.print_names: - logger.info(f"Dataset is now using image: {image_path}") + logger.info( + f"Dataset is now using image: {image_metadata['image_path']}" + ) - # Use the magic prompt handler to retrieve the captions. - image_metadata["instance_prompt_text"] = PromptHandler.magic_prompt( - data_backend=self.data_backend, - image_path=image_path, - caption_strategy=self.caption_strategy, - use_captions=self.use_captions, - prepend_instance_prompt=self.prepend_instance_prompt, - ) + if "instance_prompt_text" not in image_metadata: + raise ValueError( + f"Instance prompt text must be provided in image metadata. Image metadata: {image_metadata}" + ) output_data.append(image_metadata) return output_data diff --git a/helpers/multiaspect/multisampler.py b/helpers/multiaspect/multisampler.py new file mode 100644 index 00000000..d99df1a3 --- /dev/null +++ b/helpers/multiaspect/multisampler.py @@ -0,0 +1,47 @@ +# A class to act as a wrapper for multiple MultiAspectSampler objects, feeding samples from them in proportion. +from helpers.multiaspect.bucket import BucketManager +from helpers.data_backend.base import BaseDataBackend +from helpers.multiaspect.sampler import MultiAspectSampler + + +class MultiSampler: + def __init__( + self, + bucket_manager: BucketManager, + data_backend: BaseDataBackend, + accelerator, + args: dict, + ): + self.batch_size = args.train_batch_size + self.seen_images_path = args.seen_state_path + self.state_path = args.state_path + self.debug_aspect_buckets = args.debug_aspect_buckets + self.delete_unwanted_images = args.delete_unwanted_images + self.resolution = args.resolution + self.resolution_type = args.resolution_type + self.args = args + + def configure(self): + if self.args.data_backend is None: + raise ValueError("Must provide a data backend via --data_backend") + if self.args.data_backend != "multi": + # Return a basic MultiAspectSampler for the single data backend: + self.sampler = self.get_single_sampler() + return + # Configure a multi-aspect sampler: + + def get_single_sampler(self) -> list: + """ + Get a single MultiAspectSampler object. 
+ """ + return [ + MultiAspectSampler( + batch_size=self.batch_size, + seen_images_path=self.seen_images_path, + state_path=self.state_path, + debug_aspect_buckets=self.debug_aspect_buckets, + delete_unwanted_images=self.delete_unwanted_images, + resolution=self.resolution, + resolution_type=self.resolution_type, + ) + ] diff --git a/helpers/multiaspect/sampler.py b/helpers/multiaspect/sampler.py index 1c8da371..93d9c51d 100644 --- a/helpers/multiaspect/sampler.py +++ b/helpers/multiaspect/sampler.py @@ -7,6 +7,7 @@ from helpers.multiaspect.state import BucketStateManager from helpers.data_backend.base import BaseDataBackend from helpers.training.state_tracker import StateTracker +from helpers.prompts import PromptHandler from accelerate.logging import get_logger logger = get_logger( @@ -38,6 +39,9 @@ def __init__( minimum_image_size: int = None, resolution: int = 1024, resolution_type: str = "pixel", + caption_strategy: str = "filename", + use_captions=True, + prepend_instance_prompt=False, ): """ Initializes the sampler with provided settings. @@ -64,6 +68,9 @@ def __init__( self.minimum_image_size = minimum_image_size self.resolution = resolution self.resolution_type = resolution_type + self.use_captions = use_captions + self.caption_strategy = caption_strategy + self.prepend_instance_prompt = prepend_instance_prompt self.load_states( state_path=state_path, ) @@ -308,17 +315,26 @@ def _validate_and_yield_images_from_samples(self, samples, bucket): self.debug_log( f"Begin analysing sample. We have {len(to_yield)} images to yield." ) - crop_coordinates = self.bucket_manager.get_metadata_attribute_by_filepath( - image_path, "crop_coordinates" - ) - if crop_coordinates is None: + image_metadata = self.bucket_manager.get_metadata_by_filepath(image_path) + if "crop_coordinates" not in image_metadata: raise Exception( f"An image was discovered ({image_path}) that did not have its metadata: {self.bucket_manager.get_metadata_by_filepath(image_path)}" ) self.debug_log( f"Image {image_path} is considered valid. Adding to yield list." ) - to_yield.append({"image_path": image_path}) + image_metadata["image_path"] = image_path + + # Use the magic prompt handler to retrieve the captions. + image_metadata["instance_prompt_text"] = PromptHandler.magic_prompt( + data_backend=self.data_backend, + image_path=image_metadata["image_path"], + caption_strategy=self.caption_strategy, + use_captions=self.use_captions, + prepend_instance_prompt=self.prepend_instance_prompt, + ) + + to_yield.append(image_metadata) self.debug_log( f"Completed analysing sample. We have {len(to_yield)} images to yield." 
) diff --git a/helpers/training/state_tracker.py b/helpers/training/state_tracker.py index 1c97d22c..13839fec 100644 --- a/helpers/training/state_tracker.py +++ b/helpers/training/state_tracker.py @@ -20,7 +20,7 @@ class StateTracker: vaecache = None embedcache = None accelerator = None - bucket_manager = None + bucket_managers = [] vae = None vae_dtype = None weight_dtype = None @@ -148,11 +148,11 @@ def get_vae_dtype(cls): @classmethod def set_bucket_manager(cls, bucket_manager): - cls.bucket_manager = bucket_manager + cls.bucket_managers.append(bucket_manager) @classmethod - def get_bucket_manager(cls): - return cls.bucket_manager + def get_bucket_managers(cls): + return cls.bucket_managers @classmethod def set_weight_dtype(cls, weight_dtype): @@ -185,3 +185,11 @@ def set_embedcache(cls, embedcache): @classmethod def get_embedcache(cls): return cls.embedcache + + @classmethod + def get_metadata_by_filepath(cls, filepath): + for bucket_manager in cls.get_bucket_managers(): + metadata = bucket_manager.get_metadata_by_filepath(filepath) + if metadata is not None: + return metadata + return None From 64f10c631b0671a7fbffcab384aee859725efc55 Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 24 Dec 2023 13:48:47 -0600 Subject: [PATCH 02/22] Revert "remove multi-dataset work" This reverts commit 6c2f8b03c0ad27b886718ff327ee7f59f410ebb4. --- train_sdxl.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/train_sdxl.py b/train_sdxl.py index 23e1c688..ab348bab 100644 --- a/train_sdxl.py +++ b/train_sdxl.py @@ -26,6 +26,8 @@ from helpers.multiaspect.dataset import MultiAspectDataset from helpers.multiaspect.bucket import BucketManager from helpers.multiaspect.sampler import MultiAspectSampler + +# from helpers.multiaspect.factory import configure_multi_dataset from helpers.training.state_tracker import StateTracker from helpers.training.collate import collate_fn from helpers.training.deepspeed import deepspeed_zero_init_disabled_context_manager @@ -260,7 +262,7 @@ def main(): data_backend = LocalDataBackend(accelerator=accelerator) if not os.path.exists(args.instance_data_dir): raise FileNotFoundError( - f"Instance {args.instance_data_root} images root doesn't exist. Cannot continue." + f"Instance {args.instance_data_dir} images root doesn't exist. Cannot continue." ) elif args.data_backend == "aws": from helpers.data_backend.aws import S3DataBackend @@ -276,8 +278,6 @@ def main(): else: raise ValueError(f"Unsupported data backend: {args.data_backend}") - # Get the datasets: you can either provide your own training and evaluation files (see below) - # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). # Bucket manager. We keep the aspect config in the dataset so that switching datasets is simpler. 
bucket_manager = BucketManager( instance_data_root=args.instance_data_dir, @@ -420,18 +420,10 @@ def print_bucket_info(bucket_manager): # Data loader train_dataset = MultiAspectDataset( - bucket_manager=bucket_manager, - data_backend=data_backend, - instance_data_root=args.instance_data_dir, - accelerator=accelerator, - size=args.resolution, - size_type=args.resolution_type, print_names=args.print_filenames or False, - prepend_instance_prompt=args.prepend_instance_prompt or False, - use_captions=not args.only_instance_prompt or False, - use_precomputed_token_ids=True, - debug_dataset_loader=args.debug_dataset_loader, - caption_strategy=args.caption_strategy, + datasets=configure_multi_dataset( + args, accelerator + ), # We need to store the list of datasets inside the MAD so that it knows their lengths. ) logger.info("Creating aspect bucket sampler") custom_balanced_sampler = MultiAspectSampler( From ed45b64d86178dea781391727af16ac5c9c4b1a3 Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 24 Dec 2023 13:49:16 -0600 Subject: [PATCH 03/22] Revert "remove import from future work" This reverts commit 8abbb5c732430ab970566dfe318073e5c5fc4b8e. --- train_sdxl.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/train_sdxl.py b/train_sdxl.py index ab348bab..ac9f24cc 100644 --- a/train_sdxl.py +++ b/train_sdxl.py @@ -26,8 +26,7 @@ from helpers.multiaspect.dataset import MultiAspectDataset from helpers.multiaspect.bucket import BucketManager from helpers.multiaspect.sampler import MultiAspectSampler - -# from helpers.multiaspect.factory import configure_multi_dataset +from helpers.multiaspect.factory import configure_multi_dataset from helpers.training.state_tracker import StateTracker from helpers.training.collate import collate_fn from helpers.training.deepspeed import deepspeed_zero_init_disabled_context_manager From 11bc9022230a58b9a048825de957da2fb0fbf98e Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 24 Dec 2023 15:57:06 -0600 Subject: [PATCH 04/22] WIP: multi-backend --- helpers/caching/vae.py | 11 +- helpers/data_backend/aws.py | 2 + helpers/data_backend/factory.py | 219 ++++++++++++++++++++++++++++-- helpers/data_backend/local.py | 3 +- helpers/multiaspect/bucket.py | 14 +- helpers/multiaspect/dataset.py | 4 +- helpers/multiaspect/image.py | 2 +- helpers/multiaspect/sampler.py | 7 + helpers/training/state_tracker.py | 16 ++- multidatabackend.example.json | 25 ++++ train_sdxl.py | 160 +--------------------- 11 files changed, 279 insertions(+), 184 deletions(-) create mode 100644 multidatabackend.example.json diff --git a/helpers/caching/vae.py b/helpers/caching/vae.py index f108ed33..29583fd8 100644 --- a/helpers/caching/vae.py +++ b/helpers/caching/vae.py @@ -26,6 +26,7 @@ class VAECache: def __init__( self, + id: str, vae, accelerator, bucket_manager: BucketManager, @@ -41,6 +42,11 @@ def __init__( resolution_type: str = "pixel", minimum_image_size: int = None, ): + self.id = id + if data_backend.id != id: + raise ValueError( + f"VAECache received incorrect data_backend: {data_backend}" + ) self.data_backend = data_backend self.vae = vae self.accelerator = accelerator @@ -322,7 +328,7 @@ def _process_images_in_batch(self) -> None: filepaths.append(filepath) self.debug_log(f"Processing {filepath}") if self.minimum_image_size is not None: - if not BucketManager.meets_resolution_requirements( + if not self.bucket_manager.meets_resolution_requirements( image_path=filepath, minimum_image_size=self.minimum_image_size, resolution_type=self.resolution_type, @@ -339,8 +345,7 @@ 
def _process_images_in_batch(self) -> None: ) self.vae_input_queue.put((pixel_values, filepath)) # Update the crop_coordinates in the metadata document - bucket_manager = StateTracker.get_bucket_manager() - bucket_manager.set_metadata_attribute_by_filepath( + self.bucket_manager.set_metadata_attribute_by_filepath( filepath=filepath, attribute="crop_coordinates", value=crop_coordinates, diff --git a/helpers/data_backend/aws.py b/helpers/data_backend/aws.py index 4c4d9527..fdf72854 100644 --- a/helpers/data_backend/aws.py +++ b/helpers/data_backend/aws.py @@ -42,6 +42,7 @@ class S3DataBackend(BaseDataBackend): def __init__( self, + id: str, bucket_name, accelerator, region_name="us-east-1", @@ -53,6 +54,7 @@ def __init__( read_retry_interval: int = 5, write_retry_interval: int = 5, ): + self.id = id self.accelerator = accelerator self.bucket_name = bucket_name self.read_retry_limit = read_retry_limit diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py index 9adf8902..967653b1 100644 --- a/helpers/data_backend/factory.py +++ b/helpers/data_backend/factory.py @@ -1,7 +1,29 @@ from helpers.data_backend.local import LocalDataBackend from helpers.data_backend.aws import S3DataBackend -import json, os +from helpers.multiaspect.bucket import BucketManager +from helpers.multiaspect.dataset import MultiAspectDataset +from helpers.multiaspect.sampler import MultiAspectSampler +from helpers.prompts import PromptHandler +from helpers.caching.vae import VAECache +from helpers.training.multi_process import rank_info +from helpers.training.collate import collate_fn +from helpers.training.state_tracker import StateTracker + +import json, os, torch + + +def print_bucket_info(bucket_manager): + # Print table header + print(f"{rank_info()} | {'Bucket':<10} | {'Image Count':<12}") + + # Print separator + print("-" * 30) + + # Print each bucket's information + for bucket in bucket_manager.aspect_ratio_bucket_indices: + image_count = len(bucket_manager.aspect_ratio_bucket_indices[bucket]) + print(f"{rank_info()} | {bucket:<10} | {image_count:<12}") def configure_multi_databackend(args: dict, accelerator): @@ -22,24 +44,190 @@ def configure_multi_databackend(args: dict, accelerator): raise ValueError( "Must provide at least one data backend in the data backend config file." ) - data_backends = [] + data_backends = {} + all_captions = [] for backend in data_backend_config: + # For each backend, we will create a dict to store all of its components in. + if "id" not in backend: + raise ValueError( + "No identifier was given for one more of your data backends. Add a unique 'id' field to each one." 
+ ) + init_backend = {"id": backend["id"]} if backend["type"] == "local": - data_backends.append(get_local_backend(accelerator)) + init_backend["data_backend"] = get_local_backend( + accelerator, init_backend["id"] + ) + init_backend["instance_data_root"] = backend["instance_data_dir"] elif backend["type"] == "aws": check_aws_config(backend) - data_backends.append( - get_aws_backend( - aws_bucket_name=backend["aws_bucket_name"], - aws_region_name=backend["aws_region_name"], - aws_endpoint_url=backend["aws_endpoint_url"], - aws_access_key_id=backend["aws_access_key_id"], - aws_secret_access_key=backend["aws_secret_access_key"], - accelerator=accelerator, - ) + init_backend["data_backend"] = get_aws_backend( + identifier=init_backend["id"], + aws_bucket_name=backend["aws_bucket_name"], + aws_region_name=backend["aws_region_name"], + aws_endpoint_url=backend["aws_endpoint_url"], + aws_access_key_id=backend["aws_access_key_id"], + aws_secret_access_key=backend["aws_secret_access_key"], + accelerator=accelerator, ) + # S3 buckets use the aws_data_prefix as their prefix/ for all data. + init_backend["instance_data_root"] = backend["aws_data_prefix"] else: raise ValueError(f"Unknown data backend type: {backend['type']}") + + init_backend["bucket_manager"] = BucketManager( + id=init_backend["id"], + instance_data_root=init_backend["instance_data_root"], + data_backend=init_backend["data_backend"], + accelerator=accelerator, + resolution=backend["resolution"] or args.resolution, + minimum_image_size=backend["minimum_image_size"] or args.minimum_image_size, + resolution_type=backend["resolution_type"] or args.resolution_type, + batch_size=args.train_batch_size, + metadata_update_interval=backend["metadata_update_interval"] + or args.metadata_update_interval, + cache_file=os.path.join( + init_backend["instance_data_root"], "aspect_ratio_bucket_indices.json" + ), + metadata_file=os.path.join( + init_backend["instance_data_root"], "aspect_ratio_bucket_metadata.json" + ), + delete_problematic_images=args.delete_problematic_images or False, + ) + if init_backend["bucket_manager"].has_single_underfilled_bucket(): + raise Exception( + f"Cannot train using a dataset that has a single bucket with fewer than {args.train_batch_size} images." + f" You have to reduce your batch size, or increase your dataset size (id={init_backend['id']})." + ) + if "aspect" not in args.skip_file_discovery: + if accelerator.is_local_main_process: + init_backend["bucket_manager"].refresh_buckets(rank_info()) + accelerator.wait_for_everyone() + init_backend["bucket_manager"].reload_cache() + # Now split the contents of these buckets between all processes + init_backend["bucket_manager"].split_buckets_between_processes( + gradient_accumulation_steps=args.gradient_accumulation_steps, + ) + + print_bucket_info(init_backend["bucket_manager"]) + if len(init_backend["bucket_manager"]) == 0: + raise Exception( + "No images were discovered by the bucket manager in the dataset." 
+ ) + + use_captions = True + if "only_instance_prompt" in backend and backend["only_instance_prompt"]: + use_captions = False + elif args.only_instance_prompt: + use_captions = False + caption_strategy = args.caption_strategy + if "caption_strategy" in backend: + caption_strategy = backend["caption_strategy"] + init_backend["train_dataset"] = MultiAspectDataset( + id=init_backend["id"], + bucket_manager=init_backend["bucket_manager"], + data_backend=init_backend["data_backend"], + instance_data_root=init_backend["instance_data_root"], + accelerator=accelerator, + size=backend["resolution"] or args.resolution, + size_type=backend["resolution_type"] or args.resolution_type, + print_names=args.print_filenames or False, + prepend_instance_prompt=backend["prepend_instance_prompt"] + or args.prepend_instance_prompt + or False, + use_captions=use_captions, + use_precomputed_token_ids=True, + debug_dataset_loader=args.debug_dataset_loader, + caption_strategy=caption_strategy, + ) + + # full filename path: + seen_state_path = args.seen_state_path + # split the filename by extension, append init_backend["id"] to the end of the filename, reassemble with extension: + seen_state_path = ".".join( + seen_state_path.split(".")[:-1] + + [init_backend["id"], seen_state_path.split(".")[-1]] + ) + state_path = args.state_path + state_path = ".".join( + state_path.split(".")[:-1] + [init_backend["id"], state_path.split(".")[-1]] + ) + + init_backend["sampler"] = MultiAspectSampler( + id=init_backend["id"], + bucket_manager=init_backend["bucket_manager"], + data_backend=init_backend["data_backend"], + accelerator=accelerator, + batch_size=args.train_batch_size, + seen_images_path=backend["seen_state_path"] or seen_state_path, + state_path=backend["state_path"] or state_path, + debug_aspect_buckets=args.debug_aspect_buckets, + delete_unwanted_images=backend["delete_unwanted_images"] + or args.delete_unwanted_images, + resolution=backend["resolution"] or args.resolution, + resolution_type=backend["resolution_type"] or args.resolution_type, + ) + + init_backend["train_dataloader"] = torch.utils.data.DataLoader( + init_backend["train_dataset"], + batch_size=1, # The sampler handles batching + shuffle=False, # The sampler handles shuffling + sampler=init_backend["sampler"], + collate_fn=lambda examples: collate_fn(examples), + num_workers=0, + persistent_workers=False, + ) + + with accelerator.main_process_first(): + all_captions.append( + PromptHandler.get_all_captions( + data_backend=init_backend["data_backend"], + instance_data_root=init_backend["instance_data_root"], + prepend_instance_prompt=backend["prepend_instance_prompt"] + or args.prepend_instance_prompt + or False, + use_captions=use_captions, + ) + ) + + logger.info(f"Pre-computing VAE latent space.") + init_backend["vaecache"] = VAECache( + id=init_backend["id"], + vae=StateTracker.get_vae(), + accelerator=accelerator, + bucket_manager=init_backend["bucket_manager"], + data_backend=init_backend["data_backend"], + instance_data_root=init_backend["instance_data_root"], + delete_problematic_images=backend["delete_problematic_images"] + or args.delete_problematic_images, + resolution=backend["resolution"] or args.resolution, + resolution_type=backend["resolution_type"] or args.resolution_type, + minimum_image_size=backend["minimum_image_size"] or args.minimum_image_size, + vae_batch_size=args.vae_batch_size, + write_batch_size=args.write_batch_size, + cache_dir=backend["cache_dir_vae"] or args.cache_dir_vae, + ) + + if accelerator.is_local_main_process: + 
init_backend["vaecache"].discover_all_files() + accelerator.wait_for_everyone() + + if "metadata" not in args.skip_file_discovery and accelerator.is_main_process: + init_backend["bucket_manager"].scan_for_metadata() + accelerator.wait_for_everyone() + if not accelerator.is_main_process: + init_backend["bucket_manager"].load_image_metadata() + accelerator.wait_for_everyone() + + if "vae" not in args.skip_file_discovery: + init_backend["vaecache"].split_cache_between_processes() + init_backend["vaecache"].process_buckets() + accelerator.wait_for_everyone() + + StateTracker.register_backend(init_backend) + + # After configuring all backends, register their captions. + StateTracker.set_caption_files(all_captions) + if len(data_backends) == 0: raise ValueError( "Must provide at least one data backend in the data backend config file." @@ -47,16 +235,17 @@ def configure_multi_databackend(args: dict, accelerator): return data_backends -def get_local_backend(accelerator) -> LocalDataBackend: +def get_local_backend(accelerator, identifier: str) -> LocalDataBackend: """ Get a local disk backend. Args: accelerator (Accelerator): A Huggingface Accelerate object. + identifier (str): An identifier that links this data backend to its other components. Returns: LocalDataBackend: A LocalDataBackend object. """ - return LocalDataBackend(accelerator=accelerator) + return LocalDataBackend(accelerator=accelerator, id=identifier) def check_aws_config(backend: dict) -> None: @@ -87,8 +276,10 @@ def get_aws_backend( aws_access_key_id: str, aws_secret_access_key: str, accelerator, + identifier: str, ) -> S3DataBackend: return S3DataBackend( + id=identifier, bucket_name=aws_bucket_name, accelerator=accelerator, region_name=aws_region_name, diff --git a/helpers/data_backend/local.py b/helpers/data_backend/local.py index abf23fa2..2116e5ad 100644 --- a/helpers/data_backend/local.py +++ b/helpers/data_backend/local.py @@ -10,8 +10,9 @@ class LocalDataBackend(BaseDataBackend): - def __init__(self, accelerator): + def __init__(self, accelerator, id: str): self.accelerator = accelerator + self.id = id def read(self, filepath, as_byteIO: bool = False): """Read and return the content of the file.""" diff --git a/helpers/multiaspect/bucket.py b/helpers/multiaspect/bucket.py index 31fd61b7..5046afb1 100644 --- a/helpers/multiaspect/bucket.py +++ b/helpers/multiaspect/bucket.py @@ -18,6 +18,7 @@ class BucketManager: def __init__( self, + id: str, instance_data_root: str, cache_file: str, metadata_file: str, @@ -30,6 +31,11 @@ def __init__( metadata_update_interval: int = 3600, minimum_image_size: int = None, ): + self.id = id + if self.id != data_backend.id: + raise ValueError( + f"BucketManager ID ({self.id}) must match the DataBackend ID ({data_backend.id})." + ) self.accelerator = accelerator self.data_backend = data_backend self.batch_size = batch_size @@ -434,7 +440,7 @@ def _enforce_resolution_constraints(self, bucket): self.aspect_ratio_bucket_indices[bucket] = [ img for img in images - if BucketManager.meets_resolution_requirements( + if self.meets_resolution_requirements( image_path=img, minimum_image_size=self.minimum_image_size, resolution_type=self.resolution_type, @@ -442,8 +448,8 @@ def _enforce_resolution_constraints(self, bucket): ) ] - @staticmethod def meets_resolution_requirements( + self, image_path: str = None, image: Image = None, minimum_image_size: int = None, @@ -453,9 +459,7 @@ def meets_resolution_requirements( Check if an image meets the resolution requirements. 
""" if image is None and image_path is not None: - metadata = StateTracker.get_bucket_manager().get_metadata_by_filepath( - image_path - ) + metadata = self.get_metadata_by_filepath(image_path) if metadata is None: logger.warning(f"Metadata not found for image {image_path}.") return False diff --git a/helpers/multiaspect/dataset.py b/helpers/multiaspect/dataset.py index 92050019..54450a1a 100644 --- a/helpers/multiaspect/dataset.py +++ b/helpers/multiaspect/dataset.py @@ -15,9 +15,11 @@ class MultiAspectDataset(Dataset): def __init__( self, + id: str, datasets: list, print_names=False, ): + self.id = id self.datasets = datasets self.print_names = print_names @@ -28,10 +30,10 @@ def __len__(self): def __getitem__(self, image_tuple): output_data = [] for sample in image_tuple: + image_metadata = sample logger.debug( f"Running __getitem__ for {image_metadata['image_path']} inside Dataloader." ) - image_metadata = sample if ( image_metadata["original_size"] is None diff --git a/helpers/multiaspect/image.py b/helpers/multiaspect/image.py index 21993fe5..00cb2aa2 100644 --- a/helpers/multiaspect/image.py +++ b/helpers/multiaspect/image.py @@ -56,7 +56,7 @@ def process_for_bucket( logger.debug( f"Image {image_path_str} has aspect ratio {aspect_ratio} and size {image.size}." ) - if not StateTracker.get_bucket_manager().meets_resolution_requirements( + if not bucket_manager.meets_resolution_requirements( image=image, minimum_image_size=minimum_image_size, resolution_type=resolution_type, diff --git a/helpers/multiaspect/sampler.py b/helpers/multiaspect/sampler.py index 93d9c51d..52ce1de4 100644 --- a/helpers/multiaspect/sampler.py +++ b/helpers/multiaspect/sampler.py @@ -28,6 +28,7 @@ class MultiAspectSampler(torch.utils.data.Sampler): def __init__( self, + id: str, bucket_manager: BucketManager, data_backend: BaseDataBackend, accelerator, @@ -46,6 +47,7 @@ def __init__( """ Initializes the sampler with provided settings. Parameters: + - id: An identifier to link this with its VAECache and DataBackend objects. - bucket_manager: An initialised instance of BucketManager. - batch_size: Number of samples to draw per batch. - seen_images_path: Path to store the seen images. @@ -54,6 +56,11 @@ def __init__( - delete_unwanted_images: Flag to decide whether to delete unwanted (small) images or just remove from the bucket. - minimum_image_size: The minimum pixel length of the smallest side of an image. """ + self.id = id + if self.id != data_backend.id or self.id != bucket_manager.id: + raise ValueError( + f"Sampler ID ({self.id}) must match DataBackend ID ({data_backend.id}) and BucketManager ID ({bucket_manager.id})." 
+ ) self.rank_info = rank_info() self.accelerator = accelerator self.bucket_manager = bucket_manager diff --git a/helpers/training/state_tracker.py b/helpers/training/state_tracker.py index 13839fec..031150d3 100644 --- a/helpers/training/state_tracker.py +++ b/helpers/training/state_tracker.py @@ -21,6 +21,7 @@ class StateTracker: embedcache = None accelerator = None bucket_managers = [] + backends = {} vae = None vae_dtype = None weight_dtype = None @@ -115,12 +116,19 @@ def has_caption_files_loaded(cls): return len(list(cls.all_caption_files.keys())) > 0 @classmethod - def set_data_backend(cls, data_backend): - cls.data_backend = data_backend + def register_backend(cls, data_backend): + cls.backends[data_backend["id"]] = data_backend @classmethod - def get_data_backend(cls): - return cls.data_backend + def get_data_backend(cls, id: str): + return cls.backends[id] + + @classmethod + def get_bucket_manager(cls, id: str): + for bucket_manager in cls.bucket_managers: + if bucket_manager.id == id: + return bucket_manager + return None @classmethod def set_accelerator(cls, accelerator): diff --git a/multidatabackend.example.json b/multidatabackend.example.json new file mode 100644 index 00000000..c4551c93 --- /dev/null +++ b/multidatabackend.example.json @@ -0,0 +1,25 @@ +[ + { + "id": "something-special-to-remember-by", + "type": "local", + "instance_data_dir": "/path/to/data/tree", + "crop": false, + "resolution": 1.0, + "resolution_type": "area", + "minimum_image_size": 1.0, + "prepend_instance_prompt": false, + "instance_prompt": "cat girls", + "only_instance_prompt": false, + "caption_strategy": "filename" + }, + { + "id": "another-special-name-for-another-backend", + "type": "aws", + "aws_bucket_name": "something-yummy", + "aws_region_name": null, + "aws_endpoint_url": "https://foo.bar/", + "aws_access_key_id": "wpz-764e9734523434", + "aws_secret_access_key": "xyz-sdajkhfhakhfjd", + "aws_data_prefix": "" + } +] \ No newline at end of file diff --git a/train_sdxl.py b/train_sdxl.py index ac9f24cc..6f6458c4 100644 --- a/train_sdxl.py +++ b/train_sdxl.py @@ -23,25 +23,16 @@ from pathlib import Path from helpers.arguments import parse_args from helpers.legacy.validation import prepare_validation_prompt_list, log_validations -from helpers.multiaspect.dataset import MultiAspectDataset -from helpers.multiaspect.bucket import BucketManager -from helpers.multiaspect.sampler import MultiAspectSampler -from helpers.multiaspect.factory import configure_multi_dataset from helpers.training.state_tracker import StateTracker -from helpers.training.collate import collate_fn from helpers.training.deepspeed import deepspeed_zero_init_disabled_context_manager - +from helpers.data_backend.factory import configure_multi_databackend from helpers.caching.vae import VAECache from helpers.caching.sdxl_embeds import TextEmbeddingCache -from helpers.image_manipulation.brightness import ( - calculate_luminance, -) from helpers.training.custom_schedule import ( get_polynomial_decay_schedule_with_warmup, generate_timestep_weights, ) from helpers.training.min_snr_gamma import compute_snr -from helpers.training.multi_process import rank_info from helpers.prompts import PromptHandler from accelerate.logging import get_logger @@ -255,81 +246,7 @@ def main(): ) # Create a DataBackend, so that we can access our dataset. 
- if args.data_backend == "local": - from helpers.data_backend.local import LocalDataBackend - - data_backend = LocalDataBackend(accelerator=accelerator) - if not os.path.exists(args.instance_data_dir): - raise FileNotFoundError( - f"Instance {args.instance_data_dir} images root doesn't exist. Cannot continue." - ) - elif args.data_backend == "aws": - from helpers.data_backend.aws import S3DataBackend - - data_backend = S3DataBackend( - bucket_name=args.aws_bucket_name, - accelerator=accelerator, - region_name=args.aws_region_name, - endpoint_url=args.aws_endpoint_url, - aws_access_key_id=args.aws_access_key_id, - aws_secret_access_key=args.aws_secret_access_key, - ) - else: - raise ValueError(f"Unsupported data backend: {args.data_backend}") - - # Bucket manager. We keep the aspect config in the dataset so that switching datasets is simpler. - bucket_manager = BucketManager( - instance_data_root=args.instance_data_dir, - data_backend=data_backend, - accelerator=accelerator, - resolution=args.resolution, - minimum_image_size=args.minimum_image_size, - resolution_type=args.resolution_type, - batch_size=args.train_batch_size, - metadata_update_interval=args.metadata_update_interval, - cache_file=os.path.join( - args.instance_data_dir, "aspect_ratio_bucket_indices.json" - ), - metadata_file=os.path.join( - args.instance_data_dir, "aspect_ratio_bucket_metadata.json" - ), - delete_problematic_images=args.delete_problematic_images or False, - ) - StateTracker.set_bucket_manager(bucket_manager) - if bucket_manager.has_single_underfilled_bucket(): - raise Exception( - f"Cannot train using a dataset that has a single bucket with fewer than {args.train_batch_size} images." - " You have to reduce your batch size, or increase your dataset size." - ) - if "aspect" not in args.skip_file_discovery: - if accelerator.is_local_main_process: - bucket_manager.refresh_buckets(rank_info()) - accelerator.wait_for_everyone() - bucket_manager.reload_cache() - # Now split the contents of these buckets between all processes - bucket_manager.split_buckets_between_processes( - gradient_accumulation_steps=args.gradient_accumulation_steps, - ) - - # Now, let's print the total of each bucket, along with the current rank, so that we might catch debug info: - def print_bucket_info(bucket_manager): - # Print table header - print(f"{rank_info()} | {'Bucket':<10} | {'Image Count':<12}") - - # Print separator - print("-" * 30) - - # Print each bucket's information - for bucket in bucket_manager.aspect_ratio_bucket_indices: - image_count = len(bucket_manager.aspect_ratio_bucket_indices[bucket]) - print(f"{rank_info()} | {bucket:<10} | {image_count:<12}") - - print_bucket_info(bucket_manager) - - if len(bucket_manager) == 0: - raise Exception( - "No images were discovered by the bucket manager in the dataset." - ) + configure_multi_databackend(args, accelerator) # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. @@ -417,35 +334,6 @@ def print_bucket_info(bucket_manager): text_encoder_1.requires_grad_(False) text_encoder_2.requires_grad_(False) - # Data loader - train_dataset = MultiAspectDataset( - print_names=args.print_filenames or False, - datasets=configure_multi_dataset( - args, accelerator - ), # We need to store the list of datasets inside the MAD so that it knows their lengths. 
- ) - logger.info("Creating aspect bucket sampler") - custom_balanced_sampler = MultiAspectSampler( - bucket_manager=bucket_manager, - data_backend=data_backend, - accelerator=accelerator, - batch_size=args.train_batch_size, - seen_images_path=args.seen_state_path, - state_path=args.state_path, - debug_aspect_buckets=args.debug_aspect_buckets, - delete_unwanted_images=args.delete_unwanted_images, - resolution=args.resolution, - resolution_type=args.resolution_type, - ) - train_dataloader = torch.utils.data.DataLoader( - train_dataset, - batch_size=1, # The sample handles batching - shuffle=False, # The sampler handles shuffling - sampler=custom_balanced_sampler, - collate_fn=lambda examples: collate_fn(examples), - num_workers=0, - persistent_workers=False, - ) prompt_handler = None if not args.disable_compel: prompt_handler = PromptHandler( @@ -479,14 +367,8 @@ def print_bucket_info(bucket_manager): if "text" not in args.skip_file_discovery: logger.info(f"Pre-computing text embeds / updating cache.") - with accelerator.main_process_first(): - all_captions = PromptHandler.get_all_captions( - data_backend=data_backend, - instance_data_root=args.instance_data_dir, - prepend_instance_prompt=args.prepend_instance_prompt or False, - use_captions=not args.only_instance_prompt, - ) - StateTracker.set_caption_files(all_captions) + # Captions are extracted from datasets during `configure_multi_databackend(...)` + all_captions = StateTracker.get_caption_files() if accelerator.is_main_process: embed_cache.compute_embeddings_for_sdxl_prompts( all_captions, return_concat=False @@ -814,41 +696,9 @@ def print_bucket_info(bucket_manager): else: logger.debug(f"Initialising VAE with custom dtype {vae_dtype}") vae.to(accelerator.device, dtype=vae_dtype) - logger.info(f"Loaded VAE into VRAM.") - logger.info(f"Pre-computing VAE latent space.") - vaecache = VAECache( - vae=vae, - accelerator=accelerator, - bucket_manager=bucket_manager, - data_backend=data_backend, - instance_data_root=args.instance_data_dir, - delete_problematic_images=args.delete_problematic_images, - resolution=args.resolution, - resolution_type=args.resolution_type, - minimum_image_size=args.minimum_image_size, - vae_batch_size=args.vae_batch_size, - write_batch_size=args.write_batch_size, - cache_dir=args.cache_dir_vae, - ) - StateTracker.set_vaecache(vaecache) StateTracker.set_vae_dtype(vae_dtype) StateTracker.set_vae(vae) - - if accelerator.is_local_main_process: - vaecache.discover_all_files() - accelerator.wait_for_everyone() - - if "metadata" not in args.skip_file_discovery and accelerator.is_main_process: - bucket_manager.scan_for_metadata() - accelerator.wait_for_everyone() - if not accelerator.is_main_process: - bucket_manager.load_image_metadata() - accelerator.wait_for_everyone() - - if "vae" not in args.skip_file_discovery: - vaecache.split_cache_between_processes() - vaecache.process_buckets() - accelerator.wait_for_everyone() + logger.info(f"Loaded VAE into VRAM.") # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil( From da38cc21d3f476d4888ba1a5f61b7db1acfdacea Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 24 Dec 2023 16:00:03 -0600 Subject: [PATCH 05/22] WIP: remove unnecessary code --- helpers/multiaspect/multisampler.py | 47 ----------------------------- 1 file changed, 47 deletions(-) delete mode 100644 helpers/multiaspect/multisampler.py diff --git a/helpers/multiaspect/multisampler.py b/helpers/multiaspect/multisampler.py deleted file mode 100644 index d99df1a3..00000000 --- a/helpers/multiaspect/multisampler.py +++ /dev/null @@ -1,47 +0,0 @@ -# A class to act as a wrapper for multiple MultiAspectSampler objects, feeding samples from them in proportion. -from helpers.multiaspect.bucket import BucketManager -from helpers.data_backend.base import BaseDataBackend -from helpers.multiaspect.sampler import MultiAspectSampler - - -class MultiSampler: - def __init__( - self, - bucket_manager: BucketManager, - data_backend: BaseDataBackend, - accelerator, - args: dict, - ): - self.batch_size = args.train_batch_size - self.seen_images_path = args.seen_state_path - self.state_path = args.state_path - self.debug_aspect_buckets = args.debug_aspect_buckets - self.delete_unwanted_images = args.delete_unwanted_images - self.resolution = args.resolution - self.resolution_type = args.resolution_type - self.args = args - - def configure(self): - if self.args.data_backend is None: - raise ValueError("Must provide a data backend via --data_backend") - if self.args.data_backend != "multi": - # Return a basic MultiAspectSampler for the single data backend: - self.sampler = self.get_single_sampler() - return - # Configure a multi-aspect sampler: - - def get_single_sampler(self) -> list: - """ - Get a single MultiAspectSampler object. - """ - return [ - MultiAspectSampler( - batch_size=self.batch_size, - seen_images_path=self.seen_images_path, - state_path=self.state_path, - debug_aspect_buckets=self.debug_aspect_buckets, - delete_unwanted_images=self.delete_unwanted_images, - resolution=self.resolution, - resolution_type=self.resolution_type, - ) - ] From 26777607ed5e1f4f5f3be60655c6cbd5025ada70 Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 24 Dec 2023 16:30:07 -0600 Subject: [PATCH 06/22] ignore config --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 9f151550..eb7fc069 100644 --- a/.gitignore +++ b/.gitignore @@ -45,6 +45,7 @@ Thumbs.db files.tbz */config/auth.json work +multidatabackend.json sd21-env.sh sdxl-env.sh From 31443cab5566f992286b1bf9d6ab4a525885efbb Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 24 Dec 2023 17:28:39 -0600 Subject: [PATCH 07/22] WIP refactoring commandline args and fixing issues --- helpers/arguments.py | 119 +++++------------------------ helpers/data_backend/factory.py | 84 ++++++++++---------- helpers/multiaspect/sampler.py | 4 +- helpers/multiaspect/state.py | 15 +++- helpers/training/state_tracker.py | 18 +++-- multidatabackend.example.json | 6 +- tests/test_bucket.py | 2 + tests/test_dataset.py | 34 +++++---- tests/test_sampler.py | 3 + train_sdxl.py | 122 ++++++++++++++++++------------ train_sdxl.sh | 2 +- 11 files changed, 186 insertions(+), 223 deletions(-) diff --git a/helpers/arguments.py b/helpers/arguments.py index c73ee142..41afdce2 100644 --- a/helpers/arguments.py +++ b/helpers/arguments.py @@ -214,17 +214,17 @@ def parse_args(input_args=None): help=( "This is the path to a local directory that will contain your VAE outputs." 
" Unlike the text embed cache, your VAE latents will be stored in the AWS data backend." - " If the AWS backend is in use, this will be a prefix for the bucket's VAE cache entries." + " Each backend can have its own value, but if that is not provided, this will be the default value." ), ) parser.add_argument( - "--data_backend", + "--data_backend_config", type=str, - default="local", - choices=["local", "aws"], + default=None, + required=True, help=( - "The data backend to use. Choose between ['local', 'aws']. Default: local." - " If using AWS, you must set the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables." + "The relative or fully-qualified path for your data backend config." + " See multidatabackend.json.example for an example." ), ) parser.add_argument( @@ -237,60 +237,6 @@ def parse_args(input_args=None): " This mostly applies to S3, but some shared server filesystems may benefit as well, eg. Ceph. Default: 64." ), ) - parser.add_argument( - "--aws_config_file", - type=str, - default=None, - help=( - "Path to the AWS configuration file in JSON format." - " Config key names are the same as SimpleTuner option counterparts." - ), - ) - parser.add_argument( - "--aws_bucket_name", - type=str, - default=None, - help="The AWS bucket name to use.", - ) - parser.add_argument( - "--aws_bucket_image_prefix", - type=str, - default="", - help=( - "Instead of using --instance_data_dir, AWS S3 relies on aws_bucket_*_prefix parameters." - " When provided, this parameter will be prepended to the image path." - ), - ) - parser.add_argument( - "--aws_endpoint_url", - type=str, - default=None, - help=( - "The AWS server to use. If not specified, will use the default server for the region specified." - " For Wasabi, use https://s3.wasabisys.com." - ), - ) - parser.add_argument( - "--aws_region_name", - type=str, - default="us-east-1", - help=( - "The AWS region to use. If not specified, will use the default region for the server specified." - " For example, if you specify 's3.amazonaws.com', the default region will be 'us-east-1'." 
- ), - ) - parser.add_argument( - "--aws_access_key_id", - type=str, - default=None, - help="The AWS access key ID.", - ) - parser.add_argument( - "--aws_secret_access_key", - type=str, - default=None, - help="The AWS secret access key.", - ) parser.add_argument( "--cache_dir", type=str, @@ -1000,50 +946,19 @@ def parse_args(input_args=None): if args.non_ema_revision is None: args.non_ema_revision = args.revision - if args.aws_config_file is not None: - try: - with open(args.aws_config_file, "r") as f: - aws_config = json.load(f) - except Exception as e: - raise ValueError(f"Could not load AWS config file: {e}") - if not isinstance(aws_config, dict): - raise ValueError("AWS config file must be a JSON object.") - args.aws_bucket_name = aws_config.get("aws_bucket_name", args.aws_bucket_name) - args.aws_bucket_image_prefix = aws_config.get("aws_bucket_image_prefix", "") - args.aws_endpoint_url = aws_config.get( - "aws_endpoint_url", args.aws_endpoint_url - ) - args.aws_region_name = aws_config.get("aws_region_name", args.aws_region_name) - args.aws_access_key_id = aws_config.get( - "aws_access_key_id", args.aws_access_key_id - ) - args.aws_secret_access_key = aws_config.get( - "aws_secret_access_key", args.aws_secret_access_key - ) if args.cache_dir is None or args.cache_dir == "": args.cache_dir = os.path.join(args.output_dir, "cache") - if args.data_backend == "aws": - if args.aws_bucket_name is None: - raise ValueError("Must specify an AWS bucket name.") - if args.aws_endpoint_url is None and args.aws_region_name is None: - raise ValueError("Must specify an AWS endpoint URL or region name.") - if args.aws_access_key_id is None: - raise ValueError("Must specify an AWS access key ID.") - if args.aws_secret_access_key is None: - raise ValueError("Must specify an AWS secret access key.") - # Override the instance data dir with the bucket image prefix. 
- args.instance_data_dir = args.aws_bucket_image_prefix - else: - if args.cache_dir_vae is None or args.cache_dir_vae == "": - args.cache_dir_vae = os.path.join(args.output_dir, "cache_vae") - if args.cache_dir_text is None or args.cache_dir_text == "": - args.cache_dir_text = os.path.join(args.output_dir, "cache_text") - for target_dir in [ - Path(args.cache_dir), - Path(args.cache_dir_vae), - Path(args.cache_dir_text), - ]: - os.makedirs(target_dir, exist_ok=True) + + if args.cache_dir_vae is None or args.cache_dir_vae == "": + args.cache_dir_vae = os.path.join(args.output_dir, "cache_vae") + if args.cache_dir_text is None or args.cache_dir_text == "": + args.cache_dir_text = os.path.join(args.output_dir, "cache_text") + for target_dir in [ + Path(args.cache_dir), + Path(args.cache_dir_vae), + Path(args.cache_dir_text), + ]: + os.makedirs(target_dir, exist_ok=True) logger.info(f"VAE Cache location: {args.cache_dir_vae}") logger.info(f"Text Cache location: {args.cache_dir_text}") diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py index 967653b1..fda9d8ea 100644 --- a/helpers/data_backend/factory.py +++ b/helpers/data_backend/factory.py @@ -10,7 +10,10 @@ from helpers.training.collate import collate_fn from helpers.training.state_tracker import StateTracker -import json, os, torch +import json, os, torch, logging + +logger = logging.getLogger("DataBackendFactory") +logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) def print_bucket_info(bucket_manager): @@ -44,7 +47,6 @@ def configure_multi_databackend(args: dict, accelerator): raise ValueError( "Must provide at least one data backend in the data backend config file." ) - data_backends = {} all_captions = [] for backend in data_backend_config: # For each backend, we will create a dict to store all of its components in. 
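
For orientation, the following is a rough sketch of the per-backend dict that this loop assembles. The key names are taken from the factory code in this patch series; the placeholder values and the example id (borrowed from multidatabackend.example.json) are illustrative only.

    # Illustrative sketch only -- the values are placeholders for the objects
    # constructed further down in configure_multi_databackend().
    init_backend = {
        "id": "something-special-to-remember-by",  # unique id from the JSON config
        "data_backend": ...,        # LocalDataBackend or S3DataBackend
        "instance_data_root": ...,  # local dir, or aws_data_prefix for S3
        "bucket_manager": ...,      # BucketManager(id=...)
        "train_dataset": ...,       # MultiAspectDataset(id=..., datasets=[bucket_manager])
        "sampler": ...,             # MultiAspectSampler(id=...)
        "train_dataloader": ...,    # torch DataLoader driven by the sampler
        "vaecache": ...,            # VAECache(id=...)
    }
    # Each fully assembled dict is handed to StateTracker.register_data_backend().
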
@@ -79,12 +81,15 @@ def configure_multi_databackend(args: dict, accelerator): instance_data_root=init_backend["instance_data_root"], data_backend=init_backend["data_backend"], accelerator=accelerator, - resolution=backend["resolution"] or args.resolution, - minimum_image_size=backend["minimum_image_size"] or args.minimum_image_size, - resolution_type=backend["resolution_type"] or args.resolution_type, + resolution=backend.get("resolution", args.resolution), + minimum_image_size=backend.get( + "minimum_image_size", args.minimum_image_size + ), + resolution_type=backend.get("resolution_type", args.resolution_type), batch_size=args.train_batch_size, - metadata_update_interval=backend["metadata_update_interval"] - or args.metadata_update_interval, + metadata_update_interval=backend.get( + "metadata_update_interval", args.metadata_update_interval + ), cache_file=os.path.join( init_backend["instance_data_root"], "aspect_ratio_bucket_indices.json" ), @@ -119,25 +124,9 @@ def configure_multi_databackend(args: dict, accelerator): use_captions = False elif args.only_instance_prompt: use_captions = False - caption_strategy = args.caption_strategy - if "caption_strategy" in backend: - caption_strategy = backend["caption_strategy"] init_backend["train_dataset"] = MultiAspectDataset( id=init_backend["id"], - bucket_manager=init_backend["bucket_manager"], - data_backend=init_backend["data_backend"], - instance_data_root=init_backend["instance_data_root"], - accelerator=accelerator, - size=backend["resolution"] or args.resolution, - size_type=backend["resolution_type"] or args.resolution_type, - print_names=args.print_filenames or False, - prepend_instance_prompt=backend["prepend_instance_prompt"] - or args.prepend_instance_prompt - or False, - use_captions=use_captions, - use_precomputed_token_ids=True, - debug_dataset_loader=args.debug_dataset_loader, - caption_strategy=caption_strategy, + datasets=[init_backend["bucket_manager"]], ) # full filename path: @@ -158,13 +147,19 @@ def configure_multi_databackend(args: dict, accelerator): data_backend=init_backend["data_backend"], accelerator=accelerator, batch_size=args.train_batch_size, - seen_images_path=backend["seen_state_path"] or seen_state_path, - state_path=backend["state_path"] or state_path, + seen_images_path=backend.get("seen_state_path", seen_state_path), + state_path=backend.get("state_path", state_path), debug_aspect_buckets=args.debug_aspect_buckets, - delete_unwanted_images=backend["delete_unwanted_images"] - or args.delete_unwanted_images, - resolution=backend["resolution"] or args.resolution, - resolution_type=backend["resolution_type"] or args.resolution_type, + delete_unwanted_images=backend.get( + "delete_unwanted_images", args.delete_unwanted_images + ), + resolution=backend.get("resolution", args.resolution), + resolution_type=backend.get("resolution_type", args.resolution_type), + caption_strategy=backend.get("caption_strategy", args.caption_strategy), + use_captions=use_captions, + prepend_instance_prompt=backend.get( + "prepend_instance_prompt", args.prepend_instance_prompt + ), ) init_backend["train_dataloader"] = torch.utils.data.DataLoader( @@ -178,13 +173,13 @@ def configure_multi_databackend(args: dict, accelerator): ) with accelerator.main_process_first(): - all_captions.append( + all_captions.extend( PromptHandler.get_all_captions( data_backend=init_backend["data_backend"], instance_data_root=init_backend["instance_data_root"], - prepend_instance_prompt=backend["prepend_instance_prompt"] - or args.prepend_instance_prompt - 
or False, + prepend_instance_prompt=backend.get( + "prepend_instance_prompt", args.prepend_instance_prompt + ), use_captions=use_captions, ) ) @@ -197,14 +192,17 @@ def configure_multi_databackend(args: dict, accelerator): bucket_manager=init_backend["bucket_manager"], data_backend=init_backend["data_backend"], instance_data_root=init_backend["instance_data_root"], - delete_problematic_images=backend["delete_problematic_images"] - or args.delete_problematic_images, - resolution=backend["resolution"] or args.resolution, - resolution_type=backend["resolution_type"] or args.resolution_type, - minimum_image_size=backend["minimum_image_size"] or args.minimum_image_size, + delete_problematic_images=backend.get( + "delete_problematic_images", args.delete_problematic_images + ), + resolution=backend.get("resolution", args.resolution), + resolution_type=backend.get("resolution_type", args.resolution_type), + minimum_image_size=backend.get( + "minimum_image_size", args.minimum_image_size + ), vae_batch_size=args.vae_batch_size, write_batch_size=args.write_batch_size, - cache_dir=backend["cache_dir_vae"] or args.cache_dir_vae, + cache_dir=backend.get("cache_dir_vae", args.cache_dir_vae), ) if accelerator.is_local_main_process: @@ -223,16 +221,16 @@ def configure_multi_databackend(args: dict, accelerator): init_backend["vaecache"].process_buckets() accelerator.wait_for_everyone() - StateTracker.register_backend(init_backend) + StateTracker.register_data_backend(init_backend) # After configuring all backends, register their captions. StateTracker.set_caption_files(all_captions) - if len(data_backends) == 0: + if len(StateTracker.get_data_backends()) == 0: raise ValueError( "Must provide at least one data backend in the data backend config file." ) - return data_backends + return StateTracker.get_data_backends() def get_local_backend(accelerator, identifier: str) -> LocalDataBackend: diff --git a/helpers/multiaspect/sampler.py b/helpers/multiaspect/sampler.py index 52ce1de4..ad3a7d03 100644 --- a/helpers/multiaspect/sampler.py +++ b/helpers/multiaspect/sampler.py @@ -101,7 +101,9 @@ def save_state(self, state_path: str = None): def load_states(self, state_path: str): try: - self.state_manager = BucketStateManager(state_path, self.seen_images_path) + self.state_manager = BucketStateManager( + self.id, state_path, self.seen_images_path + ) self.buckets = self.load_buckets() previous_state = self.state_manager.load_state() except Exception as e: diff --git a/helpers/multiaspect/state.py b/helpers/multiaspect/state.py index 1b5c61c1..a78dc5d1 100644 --- a/helpers/multiaspect/state.py +++ b/helpers/multiaspect/state.py @@ -6,10 +6,19 @@ class BucketStateManager: - def __init__(self, state_path, seen_images_path): - self.state_path = state_path + def __init__(self, id: str, state_path, seen_images_path): + self.id = id + self.state_path = self.mangle_state_path(state_path) + # seen_images_path is pre-mangled by the dataset factory self.seen_images_path = seen_images_path + def mangle_state_path(self, state_path): + # When saving the state, it goes into the checkpoint dir. + # However, we need to save a single state for each data backend. 
+ # Thus, we split the state_path from its extension, add self.id to the end of the name, and rejoin: + filename, ext = os.path.splitext(state_path) + return f"{filename}-{self.id}{ext}" + def load_seen_images(self): if os.path.exists(self.seen_images_path): with open(self.seen_images_path, "r") as f: @@ -35,6 +44,8 @@ def save_state(self, state: dict, state_path: str = None): final_state = state if state_path is None: state_path = self.state_path + else: + state_path = self.mangle_state_path(state_path) logger.debug(f"Type of state: {type(state)}") final_state = self.deep_convert_dict(state) logger.info(f"Saving trainer state to {state_path}") diff --git a/helpers/training/state_tracker.py b/helpers/training/state_tracker.py index 031150d3..219ad4f2 100644 --- a/helpers/training/state_tracker.py +++ b/helpers/training/state_tracker.py @@ -21,7 +21,7 @@ class StateTracker: embedcache = None accelerator = None bucket_managers = [] - backends = {} + data_backends = {} vae = None vae_dtype = None weight_dtype = None @@ -116,18 +116,22 @@ def has_caption_files_loaded(cls): return len(list(cls.all_caption_files.keys())) > 0 @classmethod - def register_backend(cls, data_backend): - cls.backends[data_backend["id"]] = data_backend + def register_data_backend(cls, data_backend): + cls.data_backends[data_backend["id"]] = data_backend @classmethod def get_data_backend(cls, id: str): - return cls.backends[id] + return cls.data_backends[id] + + @classmethod + def get_data_backends(cls): + return cls.data_backends @classmethod def get_bucket_manager(cls, id: str): - for bucket_manager in cls.bucket_managers: - if bucket_manager.id == id: - return bucket_manager + for data_backend in cls.data_backends: + if data_backend["id"] == id: + return data_backend["bucket_manager"] return None @classmethod diff --git a/multidatabackend.example.json b/multidatabackend.example.json index c4551c93..fab50000 100644 --- a/multidatabackend.example.json +++ b/multidatabackend.example.json @@ -10,7 +10,8 @@ "prepend_instance_prompt": false, "instance_prompt": "cat girls", "only_instance_prompt": false, - "caption_strategy": "filename" + "caption_strategy": "filename", + "cache_dir_vae": "cache_prefix" }, { "id": "another-special-name-for-another-backend", @@ -20,6 +21,7 @@ "aws_endpoint_url": "https://foo.bar/", "aws_access_key_id": "wpz-764e9734523434", "aws_secret_access_key": "xyz-sdajkhfhakhfjd", - "aws_data_prefix": "" + "aws_data_prefix": "", + "cache_dir_vae": "/path/to/cache/dir" } ] \ No newline at end of file diff --git a/tests/test_bucket.py b/tests/test_bucket.py index a24753ed..2c68b212 100644 --- a/tests/test_bucket.py +++ b/tests/test_bucket.py @@ -8,6 +8,7 @@ class TestBucketManager(unittest.TestCase): def setUp(self): self.data_backend = MockDataBackend() + self.data_backend.id = "foo" self.accelerator = Mock() self.data_backend.exists = Mock(return_value=True) self.data_backend.write = Mock(return_value=True) @@ -25,6 +26,7 @@ def setUp(self): ), patch("pathlib.Path.exists", return_value=True): with self.assertLogs("BucketManager", level="WARNING"): self.bucket_manager = BucketManager( + id="foo", instance_data_root=self.instance_data_root, cache_file=self.cache_file, metadata_file=self.metadata_file, diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 5482dc1d..5fd20cd8 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -14,6 +14,7 @@ def setUp(self): self.bucket_manager = Mock(spec=BucketManager) self.bucket_manager.__len__ = Mock(return_value=10) self.image_metadata = { + 
"image_path": "fake_image_path", "original_size": (16, 8), "crop_coordinates": (0, 0), "target_size": (16, 8), @@ -28,18 +29,14 @@ def setUp(self): # Mock the Path.exists method to return True with patch("pathlib.Path.exists", return_value=True): self.dataset = MultiAspectDataset( - instance_data_root=self.instance_data_root, - accelerator=self.accelerator, - bucket_manager=self.bucket_manager, - data_backend=self.data_backend, + id="foo", + datasets=[range(10)], ) def test_init_invalid_instance_data_root(self): MultiAspectDataset( - instance_data_root="/invalid/path", - accelerator=self.accelerator, - bucket_manager=self.bucket_manager, - data_backend=self.data_backend, + id="foo", + datasets=[range(10)], ) def test_len(self): @@ -54,7 +51,19 @@ def test_getitem_valid_image(self): # Create a blank canvas: mock_image = Image.new(mode="RGB", size=(16, 8)) mock_image_open.return_value = mock_image - target = tuple([{"image_path": self.image_path, "image_data": mock_image}]) + target = tuple( + [ + { + "image_path": self.image_path, + "image_data": mock_image, + "instance_prompt_text": "fake_prompt_text", + "original_size": (16, 8), + "target_size": (16, 8), + "aspect_ratio": 1.0, + "luminance": 0.5, + } + ] + ) examples = self.dataset.__getitem__(target) # Grab the size of the first image: example = examples[0] @@ -70,12 +79,7 @@ def test_getitem_invalid_image(self): with self.assertRaises(Exception): with self.assertLogs("MultiAspectDataset", level="ERROR") as cm: - self.dataset.__getitem__(self.image_path) - - def test_getitem_not_in_training_state(self): - input_data = tuple([{"image_path": self.image_path}]) - example = self.dataset.__getitem__(input_data) - self.assertIsNotNone(example) + self.dataset.__getitem__(self.image_metadata) if __name__ == "__main__": diff --git a/tests/test_sampler.py b/tests/test_sampler.py index bc29cffa..d0dbc2a2 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -17,14 +17,17 @@ def setUp(self): self.accelerator = MagicMock() self.accelerator.log = MagicMock() self.bucket_manager = Mock(spec=BucketManager) + self.bucket_manager.id = "foo" self.bucket_manager.aspect_ratio_bucket_indices = {"1.0": ["image1", "image2"]} self.bucket_manager.seen_images = {} self.data_backend = MockDataBackend() + self.data_backend.id = "foo" self.batch_size = 2 self.seen_images_path = "/some/fake/seen_images.json" self.state_path = "/some/fake/state.json" self.sampler = MultiAspectSampler( + id="foo", bucket_manager=self.bucket_manager, data_backend=self.data_backend, accelerator=self.accelerator, diff --git a/train_sdxl.py b/train_sdxl.py index 6f6458c4..93c8171a 100644 --- a/train_sdxl.py +++ b/train_sdxl.py @@ -245,9 +245,6 @@ def main(): * accelerator.num_processes ) - # Create a DataBackend, so that we can access our dataset. - configure_multi_databackend(args, accelerator) - # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 @@ -333,6 +330,34 @@ def main(): vae.requires_grad_(False) text_encoder_1.requires_grad_(False) text_encoder_2.requires_grad_(False) + # Move vae, unet and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. 
+ vae_dtype = torch.float32 + if hasattr(args, "vae_dtype"): + logger.info( + f"Initialising VAE in {args.vae_dtype} precision, you may specify a different value if preferred: bf16, fp16, fp32, default" + ) + # Let's use a case-switch for convenience: bf16, fp16, fp32, none/default + if args.vae_dtype == "bf16": + vae_dtype = torch.bfloat16 + elif args.vae_dtype == "fp16": + vae_dtype = torch.float16 + elif args.vae_dtype == "fp32": + vae_dtype = torch.float32 + elif args.vae_dtype == "none" or args.vae_dtype == "default": + vae_dtype = torch.float32 + if args.pretrained_vae_model_name_or_path is not None: + logger.debug(f"Initialising VAE with weight dtype {vae_dtype}") + vae.to(accelerator.device, dtype=vae_dtype) + else: + logger.debug(f"Initialising VAE with custom dtype {vae_dtype}") + vae.to(accelerator.device, dtype=vae_dtype) + StateTracker.set_vae_dtype(vae_dtype) + StateTracker.set_vae(vae) + logger.info(f"Loaded VAE into VRAM.") + + # Create a DataBackend, so that we can access our dataset. + configure_multi_databackend(args, accelerator) prompt_handler = None if not args.disable_compel: @@ -421,7 +446,13 @@ def main(): # We calculate the number of steps per epoch by dividing the number of images by the effective batch divisor. # Gradient accumulation steps mean that we only update the model weights every /n/ steps. num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps + sum( + [ + len(backend["bucket_manager"]) + for _, backend in StateTracker.get_data_backends().items() + ] + ) + / args.gradient_accumulation_steps ) if args.max_train_steps is None or args.max_train_steps == 0: if args.num_train_epochs is None or args.num_train_epochs == 0: @@ -665,44 +696,25 @@ def main(): # Prepare everything with our `accelerator`. disable_accelerator = os.environ.get("SIMPLETUNER_DISABLE_ACCELERATOR", False) + train_dataloaders = [] + for _, backend in StateTracker.get_data_backends().items(): + train_dataloaders.append(backend["train_dataloader"]) if not disable_accelerator: logger.info(f"Loading our accelerator...") - unet, train_dataloader, lr_scheduler, optimizer = accelerator.prepare( - unet, train_dataloader, lr_scheduler, optimizer - ) + results = accelerator.prepare(unet, lr_scheduler, optimizer, *train_dataloaders) + unet = results[0] + lr_scheduler = results[1] + optimizer = results[2] + # The rest of the entries are dataloaders: + train_dataloaders = results[3:] if args.use_ema: logger.info("Moving EMA model weights to accelerator...") ema_unet.to(accelerator.device, dtype=weight_dtype) - # Move vae, unet and text_encoder to device and cast to weight_dtype - # The VAE is in float32 to avoid NaN losses. 
- vae_dtype = torch.float32 - if hasattr(args, "vae_dtype"): - logger.info( - f"Initialising VAE in {args.vae_dtype} precision, you may specify a different value if preferred: bf16, fp16, fp32, default" - ) - # Let's use a case-switch for convenience: bf16, fp16, fp32, none/default - if args.vae_dtype == "bf16": - vae_dtype = torch.bfloat16 - elif args.vae_dtype == "fp16": - vae_dtype = torch.float16 - elif args.vae_dtype == "fp32": - vae_dtype = torch.float32 - elif args.vae_dtype == "none" or args.vae_dtype == "default": - vae_dtype = torch.float32 - if args.pretrained_vae_model_name_or_path is not None: - logger.debug(f"Initialising VAE with weight dtype {vae_dtype}") - vae.to(accelerator.device, dtype=vae_dtype) - else: - logger.debug(f"Initialising VAE with custom dtype {vae_dtype}") - vae.to(accelerator.device, dtype=vae_dtype) - StateTracker.set_vae_dtype(vae_dtype) - StateTracker.set_vae(vae) - logger.info(f"Loaded VAE into VRAM.") - # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps + sum([len(dataloader) for dataloader in train_dataloaders]) + / args.gradient_accumulation_steps ) if hasattr(lr_scheduler, "num_update_steps_per_epoch"): lr_scheduler.num_update_steps_per_epoch = num_update_steps_per_epoch @@ -720,12 +732,6 @@ def main(): if accelerator.is_main_process: # Copy args into public_args: public_args = copy.deepcopy(args) - # Remove the args that we don't want to track: - del public_args.aws_access_key_id - del public_args.aws_secret_access_key - del public_args.aws_bucket_name - del public_args.aws_region_name - del public_args.aws_endpoint_url # Hash the contents of public_args to reflect a deterministic ID for a single set of params: public_args_hash = hashlib.md5( json.dumps(vars(public_args), sort_keys=True).encode("utf-8") @@ -751,7 +757,8 @@ def main(): del vae vae = None - vaecache.vae = None + for _, backend in StateTracker.get_data_backends().items(): + backend["vaecache"].vae = None gc.collect() torch.cuda.empty_cache() memory_after_unload = torch.cuda.memory_allocated() / 1024**3 @@ -793,9 +800,12 @@ def main(): else: logger.info(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) - custom_balanced_sampler.load_states( - state_path=os.path.join(args.output_dir, path, "training_state.json"), - ) + for _, backend in StateTracker.get_data_backends().items(): + backend["sampler"].load_states( + state_path=os.path.join( + args.output_dir, path, "training_state.json" + ), + ) resume_global_step = global_step = int(path.split("-")[1]) # If we use a constant LR, we can update that now. @@ -812,10 +822,16 @@ def main(): f" {num_update_steps_per_epoch} steps per epoch and" f" {args.gradient_accumulation_steps} gradient_accumulation_steps" ) - custom_balanced_sampler.log_state() + for _, backend in StateTracker.get_data_backends().items(): + backend["sampler"].log_state() total_steps_remaining_at_start = args.max_train_steps # We store the number of dataset resets that have occurred inside the checkpoint. 
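Because the step bookkeeping above is now spread across several backends, a toy calculation may help; the dataloader lengths below are made up.

```python
import math

# Hypothetical lengths for three data backends' dataloaders:
dataloader_lengths = [120, 48, 300]
gradient_accumulation_steps = 4

# Steps per epoch are derived from the combined length of every dataloader,
# since one epoch now draws batches from all backends.
num_update_steps_per_epoch = math.ceil(
    sum(dataloader_lengths) / gradient_accumulation_steps
)
print(num_update_steps_per_epoch)  # 117
```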
- first_epoch = custom_balanced_sampler.current_epoch + first_epoch = max( + [ + backend["sampler"].current_epoch + for backend in StateTracker.get_data_backends() + ] + ) if first_epoch > 1: steps_to_remove = first_epoch * num_update_steps_per_epoch total_steps_remaining_at_start -= steps_to_remove @@ -829,8 +845,11 @@ def main(): ) logger.info("***** Running training *****") + total_num_batches = len( + [backend["train_dataset"] for backend in StateTracker.get_data_backends()] + ) logger.info( - f" Num batches = {len(train_dataset)} ({len(train_dataset) * args.train_batch_size} samples)" + f" Num batches = {total_num_batches} ({total_num_batches * args.train_batch_size} samples)" ) logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Current Epoch = {first_epoch}") @@ -1138,9 +1157,12 @@ def main(): args.output_dir, f"checkpoint-{global_step}" ) accelerator.save_state(save_path) - custom_balanced_sampler.save_state( - state_path=os.path.join(save_path, "training_state.json"), - ) + for backend in StateTracker.get_data_backends(): + backend["sampler"].save_state( + state_path=os.path.join( + save_path, "training_state.json" + ), + ) logger.info(f"Saved state to {save_path}") logs = { diff --git a/train_sdxl.sh b/train_sdxl.sh index 7172f648..09579c57 100644 --- a/train_sdxl.sh +++ b/train_sdxl.sh @@ -219,7 +219,7 @@ fi # Run the training script. accelerate launch ${ACCELERATE_EXTRA_ARGS} --mixed_precision="${MIXED_PRECISION}" --num_processes="${TRAINING_NUM_PROCESSES}" --num_machines="${TRAINING_NUM_MACHINES}" --dynamo_backend="${TRAINING_DYNAMO_BACKEND}" train_sdxl.py \ ---pretrained_model_name_or_path="${MODEL_NAME}" "${XFORMERS_ARG}" "${GRADIENT_ARG}" --set_grads_to_none --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \ +--pretrained_model_name_or_path="${MODEL_NAME}" ${XFORMERS_ARG} ${GRADIENT_ARG} --set_grads_to_none --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \ --resume_from_checkpoint="${RESUME_CHECKPOINT}" ${DELETE_ARGS} ${SNR_GAMMA_ARG} \ --num_train_epochs=${NUM_EPOCHS} --max_train_steps=${MAX_NUM_STEPS} --metadata_update_interval=${METADATA_UPDATE_INTERVAL} \ --learning_rate="${LEARNING_RATE}" --lr_scheduler="${LR_SCHEDULE}" --seed "${TRAINING_SEED}" --lr_warmup_steps="${LR_WARMUP_STEPS}" \ From d230abf1e237d1f5f271b491844435f7d29d3c94 Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 24 Dec 2023 19:19:01 -0600 Subject: [PATCH 08/22] WIP fixes for aspect bucketing / disk cache --- helpers/caching/vae.py | 27 +++++----- helpers/data_backend/factory.py | 28 +++++++++- helpers/multiaspect/bucket.py | 18 +++---- helpers/multiaspect/sampler.py | 8 +-- helpers/prompts.py | 4 +- helpers/training/collate.py | 16 +++--- helpers/training/state_tracker.py | 89 ++++++++++++++----------------- train_sdxl.py | 10 ++-- 8 files changed, 110 insertions(+), 90 deletions(-) diff --git a/helpers/caching/vae.py b/helpers/caching/vae.py index 29583fd8..f9d800c9 100644 --- a/helpers/caching/vae.py +++ b/helpers/caching/vae.py @@ -118,23 +118,24 @@ def retrieve_from_cache(self, filepath: str): def discover_all_files(self, directory: str = None): """Identify all files in a directory.""" - all_image_files = ( - StateTracker.get_image_files() - or StateTracker.set_image_files( - self.data_backend.list_files( - instance_data_root=self.instance_data_root, - str_pattern="*.[jJpP][pPnN][gG]", - ) - ) + all_image_files = StateTracker.get_image_files( + data_backend_id=self.id + ) or StateTracker.set_image_files( + self.data_backend.list_files( + 
instance_data_root=self.instance_data_root, + str_pattern="*.[jJpP][pPnN][gG]", + ), + data_backend_id=self.id, ) # This isn't returned, because we merely check if it's stored, or, store it. ( - StateTracker.get_vae_cache_files() + StateTracker.get_vae_cache_files(data_backend_id=self.id) or StateTracker.set_vae_cache_files( self.data_backend.list_files( instance_data_root=self.cache_dir, str_pattern="*.pt", - ) + ), + data_backend_id=self.id, ) ) self.debug_log( @@ -147,7 +148,7 @@ def _list_cached_images(self): Return a set of filenames (without the .pt extension) that have been processed. """ # Extract array of tuple into just, an array of files: - pt_files = StateTracker.get_vae_cache_files() + pt_files = StateTracker.get_vae_cache_files(data_backend_id=self.id) # Extract just the base filename without the extension results = {os.path.splitext(f)[0] for f in pt_files} logging.debug( @@ -157,8 +158,8 @@ def _list_cached_images(self): def discover_unprocessed_files(self, directory: str = None): """Identify files that haven't been processed yet.""" - all_image_files = StateTracker.get_image_files() - existing_cache_files = StateTracker.get_vae_cache_files() + all_image_files = StateTracker.get_image_files(data_backend_id=self.id) + existing_cache_files = StateTracker.get_vae_cache_files(data_backend_id=self.id) self.debug_log( f"discover_unprocessed_files found {len(all_image_files)} images from StateTracker (truncated): {list(all_image_files)[:5]}" ) diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py index fda9d8ea..e867a995 100644 --- a/helpers/data_backend/factory.py +++ b/helpers/data_backend/factory.py @@ -10,7 +10,7 @@ from helpers.training.collate import collate_fn from helpers.training.state_tracker import StateTracker -import json, os, torch, logging +import json, os, torch, logging, random logger = logging.getLogger("DataBackendFactory") logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) @@ -315,3 +315,29 @@ def get_dataset(args: dict, accelerator) -> list: accelerator=accelerator, ) ] + + +def random_dataloader_iterator(dataloaders): + """ + Create an iterator that yields batches from multiple dataloaders randomly. + + Args: + dataloaders (list): A list of DataLoader objects. + + Yields: + A batch from one of the dataloaders chosen randomly. + """ + # Create iterators for each dataloader + iterators = [iter(dataloader) for dataloader in dataloaders] + step = 0 + while iterators: + # Randomly select a dataloader iterator + step += 1 + chosen_iter = random.choice(iterators) + + try: + # Yield a batch from the chosen dataloader + yield (step, next(chosen_iter)) + except StopIteration: + # If the chosen iterator is exhausted, remove it from the list + iterators.remove(chosen_iter) diff --git a/helpers/multiaspect/bucket.py b/helpers/multiaspect/bucket.py index 5046afb1..8ee734f4 100644 --- a/helpers/multiaspect/bucket.py +++ b/helpers/multiaspect/bucket.py @@ -72,14 +72,14 @@ def _discover_new_files(self, for_metadata: bool = False): Returns: list: A list of new files. 
""" - all_image_files = ( - StateTracker.get_image_files() - or StateTracker.set_image_files( - self.data_backend.list_files( - instance_data_root=self.instance_data_root, - str_pattern="*.[jJpP][pPnN][gG]", - ) - ) + all_image_files = StateTracker.get_image_files( + data_backend_id=self.data_backend.id + ) or StateTracker.set_image_files( + self.data_backend.list_files( + instance_data_root=self.instance_data_root, + str_pattern="*.[jJpP][pPnN][gG]", + ), + data_backend_id=self.data_backend.id, ) # Log an excerpt of the all_image_files: logger.debug( @@ -392,7 +392,7 @@ def refresh_buckets(self, rank: int = None): self.compute_aspect_ratio_bucket_indices() # Get the list of existing files - existing_files = StateTracker.get_image_files() + existing_files = StateTracker.get_image_files(data_backend_id=self.id) # Update bucket indices to remove entries that no longer exist self.update_buckets_with_existing_files(existing_files) diff --git a/helpers/multiaspect/sampler.py b/helpers/multiaspect/sampler.py index ad3a7d03..9c1d0fea 100644 --- a/helpers/multiaspect/sampler.py +++ b/helpers/multiaspect/sampler.py @@ -123,11 +123,6 @@ def load_buckets(self): self.bucket_manager.aspect_ratio_bucket_indices.keys() ) # These keys are a float value, eg. 1.78. - def retrieve_vae_cache(self): - if self.vae_cache is None: - self.vae_cache = StateTracker.get_vaecache() - return self.vae_cache - def _yield_random_image(self): bucket = random.choice(self.buckets) image_path = random.choice( @@ -332,6 +327,7 @@ def _validate_and_yield_images_from_samples(self, samples, bucket): self.debug_log( f"Image {image_path} is considered valid. Adding to yield list." ) + image_metadata["data_backend_id"] = self.id image_metadata["image_path"] = image_path # Use the magic prompt handler to retrieve the captions. 
@@ -475,4 +471,4 @@ def convert_to_human_readable( return f"{ratio_width}:{ratio_height}" def debug_log(self, msg: str): - logger.debug(f"{self.rank_info}{msg}", main_process_only=False) + logger.debug(f"{self.rank_info} (id: {self.id}) {msg}", main_process_only=False) diff --git a/helpers/prompts.py b/helpers/prompts.py index fa38848c..8844163f 100644 --- a/helpers/prompts.py +++ b/helpers/prompts.py @@ -301,7 +301,9 @@ def get_all_captions( data_backend: BaseDataBackend, ) -> list: captions = [] - all_image_files = StateTracker.get_image_files() or data_backend.list_files( + all_image_files = StateTracker.get_image_files( + data_backend_id=data_backend.id + ) or data_backend.list_files( instance_data_root=instance_data_root, str_pattern="*.[jJpP][pPnN][gG]" ) if type(all_image_files) == list and type(all_image_files[0]) == tuple: diff --git a/helpers/training/collate.py b/helpers/training/collate.py index be38dbd6..21cae78c 100644 --- a/helpers/training/collate.py +++ b/helpers/training/collate.py @@ -81,10 +81,10 @@ def extract_filepaths(examples): return filepaths -def fetch_latent(fp): +def fetch_latent(fp, data_backend_id: str): """Worker method to fetch latent for a single image.""" - debug_log(" -> pull latents from cache") - latent = StateTracker.get_vaecache().retrieve_from_cache(fp) + debug_log(f" -> pull latents from cache via data backend {data_backend_id}") + latent = StateTracker.get_vaecache(id=data_backend_id).retrieve_from_cache(fp) # Move to CPU and pin memory if it's not on the GPU debug_log(" -> push latents to GPU via pinned memory") @@ -92,10 +92,12 @@ def fetch_latent(fp): return latent -def compute_latents(filepaths): +def compute_latents(filepaths, data_backend_id: str): # Use a thread pool to fetch latents concurrently with concurrent.futures.ThreadPoolExecutor() as executor: - latents = list(executor.map(fetch_latent, filepaths)) + latents = list( + executor.map(fetch_latent, filepaths, [data_backend_id] * len(filepaths)) + ) # Validate shapes test_shape = latents[0].shape @@ -155,7 +157,6 @@ def collate_fn(batch): "This trainer is not designed to handle multiple batches in a single collate." 
) debug_log("Begin collate_fn on batch") - examples = batch[0] # SDXL Dropout dropout_probability = StateTracker.get_args().caption_dropout_probability @@ -163,6 +164,7 @@ def collate_fn(batch): # Randomly drop captions/conditioning based on dropout_probability for example in examples: + data_backend_id = example["data_backend_id"] if ( dropout_probability > 0 and dropout_probability is not None @@ -180,7 +182,7 @@ def collate_fn(batch): debug_log("Extract filepaths") filepaths = extract_filepaths(examples) debug_log("Compute latents") - latent_batch = compute_latents(filepaths) + latent_batch = compute_latents(filepaths, data_backend_id) debug_log("Check latents") check_latent_shapes(latent_batch, filepaths) diff --git a/helpers/training/state_tracker.py b/helpers/training/state_tracker.py index 219ad4f2..9788b12e 100644 --- a/helpers/training/state_tracker.py +++ b/helpers/training/state_tracker.py @@ -11,8 +11,8 @@ class StateTracker: # Class variables has_training_started = False calculate_luminance = False - all_image_files = None - all_vae_cache_files = None + all_image_files = {} + all_vae_cache_files = {} all_caption_files = None # Backend entities for retrieval @@ -52,45 +52,57 @@ def _save_to_disk(cls, cache_name, data): json.dump(data, f) @classmethod - def set_image_files(cls, raw_file_list): - if cls.all_image_files is not None: - cls.all_image_files.clear() + def set_image_files(cls, raw_file_list: list, data_backend_id: str): + if cls.all_image_files[data_backend_id] is not None: + cls.all_image_files[data_backend_id].clear() else: - cls.all_image_files = {} + cls.all_image_files[data_backend_id] = {} for subdirectory_list in raw_file_list: _, _, files = subdirectory_list for image in files: - cls.all_image_files[image] = False - cls._save_to_disk("all_image_files", cls.all_image_files) - logger.debug(f"set_image_files found {len(cls.all_image_files)} images.") - return cls.all_image_files + cls.all_image_files[data_backend_id][image] = False + cls._save_to_disk( + "all_image_files_{}".format(data_backend_id), + cls.all_image_files[data_backend_id], + ) + logger.debug( + f"set_image_files found {len(cls.all_image_files[data_backend_id])} images." 
+ ) + return cls.all_image_files[data_backend_id] @classmethod - def get_image_files(cls): - if not cls.all_image_files: - cls.all_image_files = cls._load_from_disk("all_image_files") - return cls.all_image_files + def get_image_files(cls, data_backend_id: str): + if data_backend_id not in cls.all_image_files: + cls.all_image_files[data_backend_id] = cls._load_from_disk( + "all_image_files_{}".format(data_backend_id) + ) + return cls.all_image_files[data_backend_id] @classmethod - def set_vae_cache_files(cls, raw_file_list): - if cls.all_vae_cache_files is not None: - cls.all_vae_cache_files.clear() + def set_vae_cache_files(cls, raw_file_list: list, data_backend_id: str): + if cls.all_vae_cache_files[data_backend_id] is not None: + cls.all_vae_cache_files[data_backend_id].clear() else: - cls.all_vae_cache_files = {} + cls.all_vae_cache_files[data_backend_id] = {} for subdirectory_list in raw_file_list: _, _, files = subdirectory_list for image in files: - cls.all_vae_cache_files[path.basename(image)] = False - cls._save_to_disk("all_vae_cache_files", cls.all_vae_cache_files) + cls.all_vae_cache_files[data_backend_id][path.basename(image)] = False + cls._save_to_disk( + "all_vae_cache_files_{}".format(data_backend_id), + cls.all_vae_cache_files[data_backend_id], + ) logger.debug( - f"set_vae_cache_files found {len(cls.all_vae_cache_files)} images." + f"set_vae_cache_files found {len(cls.all_vae_cache_files[data_backend_id])} images." ) @classmethod - def get_vae_cache_files(cls): - if not cls.all_vae_cache_files: - cls.all_vae_cache_files = cls._load_from_disk("all_vae_cache_files") - return cls.all_vae_cache_files + def get_vae_cache_files(cls: list, data_backend_id: str): + if data_backend_id not in cls.all_vae_cache_files: + cls.all_vae_cache_files[data_backend_id] = cls._load_from_disk( + "all_vae_cache_files_{}".format(data_backend_id) + ) + return cls.all_vae_cache_files[data_backend_id] @classmethod def set_caption_files(cls, caption_files): @@ -103,18 +115,6 @@ def get_caption_files(cls): cls.all_caption_files = cls._load_from_disk("all_caption_files") return cls.all_caption_files - @classmethod - def has_image_files_loaded(cls): - return len(list(cls.all_image_files.keys())) > 0 - - @classmethod - def has_vae_cache_files_loaded(cls): - return len(list(cls.all_vae_cache_files.keys())) > 0 - - @classmethod - def has_caption_files_loaded(cls): - return len(list(cls.all_caption_files.keys())) > 0 - @classmethod def register_data_backend(cls, data_backend): cls.data_backends[data_backend["id"]] = data_backend @@ -127,13 +127,6 @@ def get_data_backend(cls, id: str): def get_data_backends(cls): return cls.data_backends - @classmethod - def get_bucket_manager(cls, id: str): - for data_backend in cls.data_backends: - if data_backend["id"] == id: - return data_backend["bucket_manager"] - return None - @classmethod def set_accelerator(cls, accelerator): cls.accelerator = accelerator @@ -158,10 +151,6 @@ def set_vae_dtype(cls, vae_dtype): def get_vae_dtype(cls): return cls.vae_dtype - @classmethod - def set_bucket_manager(cls, bucket_manager): - cls.bucket_managers.append(bucket_manager) - @classmethod def get_bucket_managers(cls): return cls.bucket_managers @@ -187,8 +176,8 @@ def set_vaecache(cls, vaecache): cls.vaecache = vaecache @classmethod - def get_vaecache(cls): - return cls.vaecache + def get_vaecache(cls, id: str): + return cls.data_backends[id]["vaecache"] @classmethod def set_embedcache(cls, embedcache): diff --git a/train_sdxl.py b/train_sdxl.py index 93c8171a..66974f43 100644 
--- a/train_sdxl.py +++ b/train_sdxl.py @@ -26,6 +26,7 @@ from helpers.training.state_tracker import StateTracker from helpers.training.deepspeed import deepspeed_zero_init_disabled_context_manager from helpers.data_backend.factory import configure_multi_databackend +from helpers.data_backend.factory import random_dataloader_iterator from helpers.caching.vae import VAECache from helpers.caching.sdxl_embeds import TextEmbeddingCache from helpers.training.custom_schedule import ( @@ -829,7 +830,7 @@ def main(): first_epoch = max( [ backend["sampler"].current_epoch - for backend in StateTracker.get_data_backends() + for _, backend in StateTracker.get_data_backends().items() ] ) if first_epoch > 1: @@ -846,7 +847,10 @@ def main(): logger.info("***** Running training *****") total_num_batches = len( - [backend["train_dataset"] for backend in StateTracker.get_data_backends()] + [ + backend["train_dataset"] + for _, backend in StateTracker.get_data_backends().items() + ] ) logger.info( f" Num batches = {total_num_batches} ({total_num_batches * args.train_batch_size} samples)" @@ -889,7 +893,7 @@ def main(): current_epoch = epoch unet.train() current_epoch_step = 0 - for step, batch in enumerate(train_dataloader): + for step, batch in random_dataloader_iterator(train_dataloaders): if args.lr_scheduler == "cosine_with_restarts": scheduler_kwargs["step"] = global_step if accelerator.is_main_process: From 78061929b3f58daa97a0607844e01fc107aac592 Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 24 Dec 2023 19:19:17 -0600 Subject: [PATCH 09/22] tests --- tests/test_collate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_collate.py b/tests/test_collate.py index f5264089..ac2ee43b 100644 --- a/tests/test_collate.py +++ b/tests/test_collate.py @@ -20,6 +20,7 @@ def setUp(self): "original_size": (100, 100), "image_data": MagicMock(), "crop_coordinates": [0, 0, 100, 100], + "data_backend_id": "foo", }, # Add more examples as needed ] From 8e772326db14b129d0712e38945af8419c79c2f7 Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 24 Dec 2023 20:10:04 -0600 Subject: [PATCH 10/22] update documentation for --data_backend_config --- OPTIONS.md | 54 ++++++++++++++--------------------------------------- TUTORIAL.md | 5 +++++ 2 files changed, 19 insertions(+), 40 deletions(-) diff --git a/OPTIONS.md b/OPTIONS.md index b34c3c13..4800085b 100644 --- a/OPTIONS.md +++ b/OPTIONS.md @@ -8,6 +8,12 @@ This guide provides a user-friendly breakdown of the command-line options availa ## 🌟 Core Model Configuration +### `--data_backend_config` + +- **What**: Path to your SimpleTuner dataset configuration. +- **Why**: Multiple datasets on different storage medium may be combined into a single training session. +- **Example**: See (multidatabackend.json.example)[/multidatabackend.json.example] for an example configuration. + ### `--pretrained_model_name_or_path` - **What**: Path to the pretrained model or its identifier from huggingface.co/models. 
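The training loop above now pulls batches via `random_dataloader_iterator`. A toy run over plain lists standing in for dataloaders shows how it interleaves backends; the helper body mirrors the one added to `helpers/data_backend/factory.py`, while the loader contents here are made up.

```python
import random

def random_dataloader_iterator(dataloaders):
    # Draw from a randomly chosen loader until every loader is exhausted.
    iterators = [iter(dataloader) for dataloader in dataloaders]
    step = 0
    while iterators:
        step += 1
        chosen_iter = random.choice(iterators)
        try:
            yield (step, next(chosen_iter))
        except StopIteration:
            # Exhausted loaders are dropped; note the step counter still advanced,
            # so the yielded step numbers can skip a value here.
            iterators.remove(chosen_iter)

loader_a = ["a1", "a2"]
loader_b = ["b1", "b2", "b3"]
for step, batch in random_dataloader_iterator([loader_a, loader_b]):
    print(step, batch)  # order varies from run to run
```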
@@ -155,16 +161,8 @@ usage: train_sdxl.py [-h] [--snr_gamma SNR_GAMMA] [--revision REVISION] --instance_data_dir INSTANCE_DATA_DIR [--preserve_data_backend_cache] [--cache_dir_text CACHE_DIR_TEXT] - [--cache_dir_vae CACHE_DIR_VAE] - [--data_backend {local,aws}] - [--write_batch_size WRITE_BATCH_SIZE] - [--aws_config_file AWS_CONFIG_FILE] - [--aws_bucket_name AWS_BUCKET_NAME] - [--aws_bucket_image_prefix AWS_BUCKET_IMAGE_PREFIX] - [--aws_endpoint_url AWS_ENDPOINT_URL] - [--aws_region_name AWS_REGION_NAME] - [--aws_access_key_id AWS_ACCESS_KEY_ID] - [--aws_secret_access_key AWS_SECRET_ACCESS_KEY] + [--cache_dir_vae CACHE_DIR_VAE] --data_backend_config + DATA_BACKEND_CONFIG [--write_batch_size WRITE_BATCH_SIZE] [--cache_dir CACHE_DIR] [--cache_clear_validation_prompts] [--seen_state_path SEEN_STATE_PATH] @@ -366,13 +364,12 @@ options: This is the path to a local directory that will contain your VAE outputs. Unlike the text embed cache, your VAE latents will be stored in the AWS data - backend. If the AWS backend is in use, this will be a - prefix for the bucket's VAE cache entries. - --data_backend {local,aws} - The data backend to use. Choose between ['local', - 'aws']. Default: local. If using AWS, you must set the - AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY - environment variables. + backend. Each backend can have its own value, but if + that is not provided, this will be the default value. + --data_backend_config DATA_BACKEND_CONFIG + The relative or fully-qualified path for your data + backend config. See multidatabackend.json.example for + an example. --write_batch_size WRITE_BATCH_SIZE When using certain storage backends, it is better to batch smaller writes rather than continuous @@ -381,29 +378,6 @@ options: objects are written. This mostly applies to S3, but some shared server filesystems may benefit as well, eg. Ceph. Default: 64. - --aws_config_file AWS_CONFIG_FILE - Path to the AWS configuration file in JSON format. - Config key names are the same as SimpleTuner option - counterparts. - --aws_bucket_name AWS_BUCKET_NAME - The AWS bucket name to use. - --aws_bucket_image_prefix AWS_BUCKET_IMAGE_PREFIX - Instead of using --instance_data_dir, AWS S3 relies on - aws_bucket_*_prefix parameters. When provided, this - parameter will be prepended to the image path. - --aws_endpoint_url AWS_ENDPOINT_URL - The AWS server to use. If not specified, will use the - default server for the region specified. For Wasabi, - use https://s3.wasabisys.com. - --aws_region_name AWS_REGION_NAME - The AWS region to use. If not specified, will use the - default region for the server specified. For example, - if you specify 's3.amazonaws.com', the default region - will be 'us-east-1'. - --aws_access_key_id AWS_ACCESS_KEY_ID - The AWS access key ID. - --aws_secret_access_key AWS_SECRET_ACCESS_KEY - The AWS secret access key. --cache_dir CACHE_DIR The directory where the downloaded models and datasets will be stored. diff --git a/TUTORIAL.md b/TUTORIAL.md index 1e7667f5..0067623a 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -158,6 +158,11 @@ Here's a breakdown of what each environment variable does: #### General Settings +- `DATALOADER_CONFIG`: This file is mandatory, and an example copy can be found in `multidatabackend.json.example` which contains an example for a multi-dataset configuration split between S3 and local data storage. + - One or more datasets can be configured, but it's not necessary to use multiple. 
+  - Config options that have an equivalent commandline option may be omitted, in which case the global value applies.
+  - Some config options are mandatory, but errors are emitted on startup for any that are missing. Feel free to experiment.
+  - Each dataset can have its own crop and resolution config.
 - `TRAINING_SEED`: You may set a numeric value here and it will make your training reproducible to that seed across all other given settings.
   - You may wish to set this to -1 so that your training is absolutely random, which prevents overfitting to a given seed.
 - `RESUME_CHECKPOINT`: Specifies which checkpoint to resume from. "latest" will pick the most recent one.

From abb76f597e5cecbd6c81d90681585c08ac94207a Mon Sep 17 00:00:00 2001
From: bghira
Date: Sun, 24 Dec 2023 20:18:21 -0600
Subject: [PATCH 11/22] Update example configuration and OPTIONS file

---
 OPTIONS.md          | 30 ++++++++++++++++++++----------
 sdxl-env.sh.example |  1 +
 train_sdxl.sh       |  9 ++++++++-
 3 files changed, 29 insertions(+), 11 deletions(-)

diff --git a/OPTIONS.md b/OPTIONS.md
index 4800085b..743b4882 100644
--- a/OPTIONS.md
+++ b/OPTIONS.md
@@ -8,12 +8,6 @@ This guide provides a user-friendly breakdown of the command-line options availa
 
 ## 🌟 Core Model Configuration
 
-### `--data_backend_config`
-
-- **What**: Path to your SimpleTuner dataset configuration.
-- **Why**: Multiple datasets on different storage medium may be combined into a single training session.
-- **Example**: See (multidatabackend.json.example)[/multidatabackend.json.example] for an example configuration.
-
 ### `--pretrained_model_name_or_path`
 
 - **What**: Path to the pretrained model or its identifier from huggingface.co/models.
@@ -36,10 +30,11 @@ This guide provides a user-friendly breakdown of the command-line options availa
 - **What**: Folder containing the training data.
 - **Why**: Designates where your training images and other data are stored.
 
-### `--data_backend`
+### `--data_backend_config`
 
-- **What**: Specifies the data storage backend, either 'local' or 'aws'.
-- **Why**: Allows for seamless switching between local and cloud storage.
+- **What**: Path to your SimpleTuner dataset configuration.
+- **Why**: Multiple datasets on different storage media may be combined into a single training session.
+- **Example**: See [multidatabackend.json.example](/multidatabackend.json.example) for an example configuration.
 
 ---
 
@@ -48,7 +43,7 @@ This guide provides a user-friendly breakdown of the command-line options availa
 ### `--resolution`
 
 - **What**: Input image resolution. Can be expressed as pixels, or megapixels.
-- **Why**: All images in the dataset will have their smaller edge resized to this resolution for training. If you use 1024px, the images may become very large and use an excessive amount of VRAM. The best mileage tends to be a 768 or 800 pixel base resolution, although 512px resolution training can really pay off with SDXL in particular.
+- **Why**: All images in the dataset will have their smaller edge resized to this resolution for training. It is recommended to use a value of 1.0 if also using `--resolution_type=area`. When using `--resolution_type=pixel` and `--resolution=1024px`, the images may become very large and use an excessive amount of VRAM. The recommended configuration is to combine `--resolution_type=area` with `--resolution=1` (or lower; 0.25 would be a 512px model with data bucketing).
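As a rough illustration of the area-based sizing described above, assuming the value 1.0 corresponds to a 1024x1024 pixel budget (as the 512px example implies; the helper name below is ours, not SimpleTuner's):

```python
import math

def approx_square_edge(megapixels: float, base_edge: int = 1024) -> int:
    # The edge length scales with the square root of the area budget:
    # 1.0 -> ~1024px, 0.25 -> ~512px.
    return round(base_edge * math.sqrt(megapixels))

print(approx_square_edge(1.0))   # 1024
print(approx_square_edge(0.25))  # 512
```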
### `--resolution_type` @@ -65,6 +60,21 @@ This guide provides a user-friendly breakdown of the command-line options availa - **What**: Strategy for deriving image captions. __Choices__: `textfile`, `filename` - **Why**: Determines how captions are generated for training images. `textfile` will use the contents of a `.txt` file with the same filename as the image, and `filename` will apply some cleanup to the filename before using it as the caption. +### `--crop` + +- **What**: When `--crop=true` is supplied, SimpleTuner will crop all (new) images in the training dataset. It will not re-process old images. +- **Why**: Training on cropped images seems to result in better fine detail learning, especially on SDXL models. + +### `--crop_style` + +- **What**: When `--crop=true`, the trainer may be instructed to crop in different ways. +- **Why**: The `crop_style` option can be set to `center` (or `centre`) for a classic centre-crop, `corner` to elect for the lowest-right corner, and `random` for a random image slice. Default: random. + +### `--crop_aspect` + +- **What**: When using `--crop=true`, the `--crop_aspect` option may be supplied with a value of `square` or `preserve`. +- **Why**: The default crop behaviour is to crop all images to a square aspect ratio, but when `--crop_aspect=preserve` is supplied, the trainer will crop images to a size matching their original aspect ratio. This may help to keep multi-resolution support, but it may also harm training quality. Your mileage may vary. + --- ## 🎛 Training Parameters diff --git a/sdxl-env.sh.example b/sdxl-env.sh.example index 8564acff..2cdebec6 100644 --- a/sdxl-env.sh.example +++ b/sdxl-env.sh.example @@ -30,6 +30,7 @@ export NUM_EPOCHS=25 # Location of training data. export BASE_DIR="/notebooks/datasets" +export DATALOADER_CONFIG="multidatabackend.json" export INSTANCE_DIR="${BASE_DIR}/training_data" export OUTPUT_DIR="${BASE_DIR}/models" # By default, images will be resized so their SMALLER EDGE is 1024 pixels, maintaining aspect ratio. diff --git a/train_sdxl.sh b/train_sdxl.sh index 09579c57..ba006190 100644 --- a/train_sdxl.sh +++ b/train_sdxl.sh @@ -180,6 +180,13 @@ if ! [ -z "$USE_XFORMERS" ] && [[ "$USE_XFORMERS" == "false" ]]; then export XFORMERS_ARG="" fi +if [ -z "$DATALOADER_CONFIG" ]; then + printf "DATALOADER_CONFIG not set, cannot continue. See multidatabackend.json.example.\n" +fi +if ! [ -f "$DATALOADER_CONFIG" ]; then + printf "DATALOADER_CONFIG file %s not found, cannot continue.\n" "${DATALOADER_CONFIG}" +fi + export SNR_GAMMA_ARG="" if ! 
[ -z "$MIN_SNR_GAMMA" ]; then export SNR_GAMMA_ARG="--snr_gamma=${MIN_SNR_GAMMA}" @@ -220,7 +227,7 @@ fi accelerate launch ${ACCELERATE_EXTRA_ARGS} --mixed_precision="${MIXED_PRECISION}" --num_processes="${TRAINING_NUM_PROCESSES}" --num_machines="${TRAINING_NUM_MACHINES}" --dynamo_backend="${TRAINING_DYNAMO_BACKEND}" train_sdxl.py \ --pretrained_model_name_or_path="${MODEL_NAME}" ${XFORMERS_ARG} ${GRADIENT_ARG} --set_grads_to_none --gradient_accumulation_steps=${GRADIENT_ACCUMULATION_STEPS} \ ---resume_from_checkpoint="${RESUME_CHECKPOINT}" ${DELETE_ARGS} ${SNR_GAMMA_ARG} \ +--resume_from_checkpoint="${RESUME_CHECKPOINT}" ${DELETE_ARGS} ${SNR_GAMMA_ARG} --data_backend_config="${DATALOADER_CONFIG}" \ --num_train_epochs=${NUM_EPOCHS} --max_train_steps=${MAX_NUM_STEPS} --metadata_update_interval=${METADATA_UPDATE_INTERVAL} \ --learning_rate="${LEARNING_RATE}" --lr_scheduler="${LR_SCHEDULE}" --seed "${TRAINING_SEED}" --lr_warmup_steps="${LR_WARMUP_STEPS}" \ --instance_data_dir="${INSTANCE_DIR}" --seen_state_path="${SEEN_STATE_PATH}" --state_path="${STATE_PATH}" --output_dir="${OUTPUT_DIR}" \ From 207ae9704dd7b0a38d8fca6d8c345952b2b61aeb Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 24 Dec 2023 20:29:24 -0600 Subject: [PATCH 12/22] allow crop configs to be specified in data backends --- helpers/data_backend/factory.py | 12 +++++++++++- helpers/training/state_tracker.py | 4 ++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py index e867a995..e28749fd 100644 --- a/helpers/data_backend/factory.py +++ b/helpers/data_backend/factory.py @@ -16,6 +16,16 @@ logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) +def init_backend_config(backend: dict, args: dict, accelerator) -> dict: + output = {"id": backend["id"], "config": {}} + if "crop" in backend: + output["config"]["crop"] = backend["crop"] + if "crop_aspect" in args: + output["config"]["crop_aspect"] = args.crop_aspect + if "crop_style" in args: + output["config"]["crop_style"] = args.crop_style + + def print_bucket_info(bucket_manager): # Print table header print(f"{rank_info()} | {'Bucket':<10} | {'Image Count':<12}") @@ -54,7 +64,7 @@ def configure_multi_databackend(args: dict, accelerator): raise ValueError( "No identifier was given for one more of your data backends. Add a unique 'id' field to each one." 
) - init_backend = {"id": backend["id"]} + init_backend = init_backend_config(backend, args, accelerator) if backend["type"] == "local": init_backend["data_backend"] = get_local_backend( accelerator, init_backend["id"] diff --git a/helpers/training/state_tracker.py b/helpers/training/state_tracker.py index 9788b12e..47309175 100644 --- a/helpers/training/state_tracker.py +++ b/helpers/training/state_tracker.py @@ -123,6 +123,10 @@ def register_data_backend(cls, data_backend): def get_data_backend(cls, id: str): return cls.data_backends[id] + @classmethod + def get_data_backend_config(cls, id: str): + return cls.data_backends[id]["config"] + @classmethod def get_data_backends(cls): return cls.data_backends From d936b90ecdc25f507fdb648cb60ff7bfb356271c Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 24 Dec 2023 20:46:16 -0600 Subject: [PATCH 13/22] clear backend cache on startup as before --- helpers/training/state_tracker.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/helpers/training/state_tracker.py b/helpers/training/state_tracker.py index 47309175..826e9f65 100644 --- a/helpers/training/state_tracker.py +++ b/helpers/training/state_tracker.py @@ -37,6 +37,22 @@ def delete_cache_files(cls): except: pass + # Glob the directory for "all_image_files.*.json" and "all_vae_cache_files.*.json", and delete those too + # This is a workaround for the fact that the cache files are named with the data_backend_id + filelist = Path(cls.args.output_dir).glob("all_image_files.*.json") + for file in filelist: + try: + file.unlink() + except: + pass + + filelist = Path(cls.args.output_dir).glob("all_vae_cache_files.*.json") + for file in filelist: + try: + file.unlink() + except: + pass + @classmethod def _load_from_disk(cls, cache_name): cache_path = Path(cls.args.output_dir) / f"{cache_name}.json" From 1e1aef10baa141d8f619dd3100d6bd2ca7a3949d Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 24 Dec 2023 20:48:27 -0600 Subject: [PATCH 14/22] fix init of backend config --- helpers/data_backend/factory.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py index e28749fd..5163228a 100644 --- a/helpers/data_backend/factory.py +++ b/helpers/data_backend/factory.py @@ -20,10 +20,12 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict: output = {"id": backend["id"], "config": {}} if "crop" in backend: output["config"]["crop"] = backend["crop"] - if "crop_aspect" in args: - output["config"]["crop_aspect"] = args.crop_aspect - if "crop_style" in args: - output["config"]["crop_style"] = args.crop_style + if "crop_aspect" in backend: + output["config"]["crop_aspect"] = backend["crop_aspect"] + if "crop_style" in backend: + output["config"]["crop_style"] = backend["crop_style"] + + return output def print_bucket_info(bucket_manager): @@ -64,6 +66,7 @@ def configure_multi_databackend(args: dict, accelerator): raise ValueError( "No identifier was given for one more of your data backends. Add a unique 'id' field to each one." ) + # Retrieve some config file overrides for commandline arguments, eg. 
cropping init_backend = init_backend_config(backend, args, accelerator) if backend["type"] == "local": init_backend["data_backend"] = get_local_backend( From 98c68d8208dacb105033a5a5ba2e69f2a01fb548 Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 24 Dec 2023 20:57:01 -0600 Subject: [PATCH 15/22] MultiaspectImage: use the backend config for image prep --- helpers/caching/vae.py | 2 +- helpers/multiaspect/image.py | 22 +++++++++++++++++----- helpers/training/state_tracker.py | 4 ++-- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/helpers/caching/vae.py b/helpers/caching/vae.py index f9d800c9..0b87ad81 100644 --- a/helpers/caching/vae.py +++ b/helpers/caching/vae.py @@ -339,7 +339,7 @@ def _process_images_in_batch(self) -> None: ) continue image, crop_coordinates = MultiaspectImage.prepare_image( - image, self.resolution, self.resolution_type + image, self.resolution, self.resolution_type, self.id ) pixel_values = self.transform(image).to( self.accelerator.device, dtype=self.vae.dtype diff --git a/helpers/multiaspect/image.py b/helpers/multiaspect/image.py index 00cb2aa2..f31ce590 100644 --- a/helpers/multiaspect/image.py +++ b/helpers/multiaspect/image.py @@ -45,7 +45,10 @@ def process_for_bucket( # Apply EXIF transforms image_metadata["original_size"] = image.size image, crop_coordinates = MultiaspectImage.prepare_image( - image, bucket_manager.resolution, bucket_manager.resolution_type + image, + bucket_manager.resolution, + bucket_manager.resolution_type, + data_backend.id, ) image_metadata["crop_coordinates"] = crop_coordinates image_metadata["target_size"] = image.size @@ -85,7 +88,9 @@ def process_for_bucket( return aspect_ratio_bucket_indices @staticmethod - def prepare_image(image: Image, resolution: float, resolution_type: str = "pixel"): + def prepare_image( + image: Image, resolution: float, resolution_type: str = "pixel", id: str = "foo" + ): if not hasattr(image, "convert"): raise Exception( f"Unknown data received instead of PIL.Image object: {type(image)}" @@ -127,10 +132,17 @@ def prepare_image(image: Image, resolution: float, resolution_type: str = "pixel else: raise ValueError(f"Unknown resolution type: {resolution_type}") - crop_style = StateTracker.get_args().crop_style - crop_aspect = StateTracker.get_args().crop_aspect + crop = StateTracker.get_data_backend_config(data_backend_id=id).get( + "crop", StateTracker.get_args().crop + ) + crop_style = StateTracker.get_data_backend_config(data_backend_id=id).get( + "crop_style", StateTracker.get_args().crop_style + ) + crop_aspect = StateTracker.get_data_backend_config(data_backend_id=id).get( + "crop_aspect", StateTracker.get_args().crop_aspect + ) - if StateTracker.get_args().crop: + if crop: crop_width, crop_height = ( (resolution, resolution) if crop_aspect == "square" diff --git a/helpers/training/state_tracker.py b/helpers/training/state_tracker.py index 826e9f65..562bf60c 100644 --- a/helpers/training/state_tracker.py +++ b/helpers/training/state_tracker.py @@ -140,8 +140,8 @@ def get_data_backend(cls, id: str): return cls.data_backends[id] @classmethod - def get_data_backend_config(cls, id: str): - return cls.data_backends[id]["config"] + def get_data_backend_config(cls, data_backend_id: str): + return cls.data_backends.get(data_backend_id, {}).get("config", {}) @classmethod def get_data_backends(cls): From 735d01938a3cf23b8016aea76e870da0b788621e Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 24 Dec 2023 21:36:27 -0600 Subject: [PATCH 16/22] SD 2.x changes for new data backend config style --- 
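The per-backend crop handling introduced above resolves each setting from the backend's config first and only falls back to the global argparse value when the key is absent. A small sketch of that precedence, with made-up backend names and a toy lookup in place of `StateTracker`:

```python
from types import SimpleNamespace

# Global commandline defaults (illustrative values):
args = SimpleNamespace(crop=True, crop_style="random", crop_aspect="square")

# Per-backend overrides registered at dataloader-configuration time (illustrative):
backend_configs = {"local-cats": {"crop_style": "center"}}

def get_backend_config(backend_id: str) -> dict:
    return backend_configs.get(backend_id, {})

config = get_backend_config("local-cats")
crop = config.get("crop", args.crop)                       # True (global fallback)
crop_style = config.get("crop_style", args.crop_style)     # "center" (backend override)
crop_aspect = config.get("crop_aspect", args.crop_aspect)  # "square" (global fallback)
print(crop, crop_style, crop_aspect)
```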
sd21-env.sh.example | 1 + train_sd21.py | 357 +++++++++++++++----------------------------- train_sd2x.sh | 2 +- train_sdxl.py | 5 +- 4 files changed, 126 insertions(+), 239 deletions(-) diff --git a/sd21-env.sh.example b/sd21-env.sh.example index 9159dda7..0bb8c876 100644 --- a/sd21-env.sh.example +++ b/sd21-env.sh.example @@ -39,6 +39,7 @@ export MODEL_NAME="stabilityai/stable-diffusion-2-1" export BASE_DIR="/notebooks/datasets" export INSTANCE_DIR="${BASE_DIR}/training_data" export OUTPUT_DIR="${BASE_DIR}/models" +export DATALOADER_CONFIG="multidatabackend_sd2x.json" # Some data that we generate will be cached here. export STATE_PATH="${BASE_DIR}/training_state.json" diff --git a/train_sd21.py b/train_sd21.py index 35946689..d323ab99 100644 --- a/train_sd21.py +++ b/train_sd21.py @@ -24,7 +24,8 @@ from helpers.arguments import parse_args from helpers.training.state_tracker import StateTracker from helpers.training.deepspeed import deepspeed_zero_init_disabled_context_manager - +from helpers.data_backend.factory import configure_multi_databackend +from helpers.data_backend.factory import random_dataloader_iterator from helpers.caching.sdxl_embeds import TextEmbeddingCache from helpers.prompts import PromptHandler from helpers.training.multi_process import rank_info @@ -85,9 +86,9 @@ torch.autograd.set_detect_anomaly(True) # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.17.0.dev0") +check_min_version("0.25.0.dev0") -logger = get_logger("SimpleTuner") +logger = get_logger(__name__, log_level=os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) filelock_logger = get_logger("filelock") connection_logger = get_logger("urllib3.connectionpool") @@ -134,7 +135,8 @@ def compute_ids(prompt: str): def main(args): StateTracker.set_args(args) - StateTracker.delete_cache_files() + if not args.preserve_data_backend_cache: + StateTracker.delete_cache_files() logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration( @@ -169,10 +171,22 @@ def main(args): ) import wandb - # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate - # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. - # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. - # FIXED (bghira): https://github.com/huggingface/accelerate/pull/1708 + if ( + hasattr(accelerator.state, "deepspeed_plugin") + and accelerator.state.deepspeed_plugin is not None + ): + if ( + "gradient_accumulation_steps" + in accelerator.state.deepspeed_plugin.deepspeed_config + ): + args.gradient_accumulation_steps = ( + accelerator.state.deepspeed_plugin.deepspeed_config[ + "gradient_accumulation_steps" + ] + ) + logger.info( + f"Updated gradient_accumulation_steps to the value provided by DeepSpeed: {args.gradient_accumulation_steps}" + ) # If passed along, set the training seed now. if args.seed is not None and args.seed != 0: @@ -433,128 +447,48 @@ def main(args): params_to_optimize, **extra_optimizer_args, ) + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. 
+ weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + logging.info("Moving VAE to GPU..") + # Move vae and text_encoder to device and cast to weight_dtype + vae.to(accelerator.device, dtype=weight_dtype) + if not args.train_text_encoder: + logging.info("Moving text encoder to GPU..") + text_encoder.to(accelerator.device, dtype=weight_dtype) + if args.use_ema: + logger.info("Moving EMA model weights to accelerator...") + ema_unet.to(accelerator.device, dtype=weight_dtype) - # Create a DataBackend, so that we can access our dataset. - if args.data_backend == "local": - from helpers.data_backend.local import LocalDataBackend - - data_backend = LocalDataBackend(accelerator=accelerator) - if not os.path.exists(args.instance_data_dir): - raise FileNotFoundError( - f"Instance {args.instance_data_dir} images root doesn't exist. Cannot continue." - ) - - elif args.data_backend == "aws": - from helpers.data_backend.aws import S3DataBackend - - data_backend = S3DataBackend( - bucket_name=args.aws_bucket_name, - accelerator=accelerator, - region_name=args.aws_region_name, - endpoint_url=args.aws_endpoint_url, - aws_access_key_id=args.aws_access_key_id, - aws_secret_access_key=args.aws_secret_access_key, - ) - else: - raise ValueError(f"Unsupported data backend: {args.data_backend}") - logger.info( - f"{rank_info()} created {args.data_backend} data backend.", - main_process_only=False, - ) - - # Get the datasets: you can either provide your own training and evaluation files (see below) - # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). - # Bucket manager. We keep the aspect config in the dataset so that switching datasets is simpler. - - logger.info( - f"{rank_info()} is creating a bucket manager.", - main_process_only=False, - ) - bucket_manager = BucketManager( - instance_data_root=args.instance_data_dir, - data_backend=data_backend, - accelerator=accelerator, - batch_size=args.train_batch_size, - resolution=args.resolution, - resolution_type=args.resolution_type, - minimum_image_size=args.minimum_image_size, - cache_file=os.path.join( - args.instance_data_dir, "aspect_ratio_bucket_indices.json" - ), - metadata_file=os.path.join( - args.instance_data_dir, "aspect_ratio_bucket_metadata.json" - ), - apply_dataset_padding=args.apply_dataset_padding or False, - delete_problematic_images=args.delete_problematic_images or False, - ) - StateTracker.set_bucket_manager(bucket_manager) - if bucket_manager.has_single_underfilled_bucket(): - raise Exception( - f"Cannot train using a dataset that has a single bucket with fewer than {args.train_batch_size} images." - " You have to reduce your batch size, or increase your dataset size." - ) - if accelerator.is_main_process: - logger.info( - f"{rank_info()} is now refreshing the buckets..", - main_process_only=False, - ) - bucket_manager.refresh_buckets() - logger.info( - f"{rank_info()} has completed its bucket manager tasks.", - main_process_only=False, - ) + # Move vae, unet and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. 
+ vae_dtype = torch.float32 + if hasattr(args, "vae_dtype"): logger.info( - f"{rank_info()} is now splitting the data.", - main_process_only=False, - ) - accelerator.wait_for_everyone() - bucket_manager.reload_cache() - - # Now split the contents of these buckets between all processes - bucket_manager.split_buckets_between_processes( - gradient_accumulation_steps=args.gradient_accumulation_steps, - ) - # Now, let's print the total of each bucket, along with the current rank, so that we might catch debug info: - for bucket in bucket_manager.aspect_ratio_bucket_indices: - print( - f"{rank_info()}: {len(bucket_manager.aspect_ratio_bucket_indices[bucket])} images in bucket {bucket}" - ) - - if len(bucket_manager) == 0: - raise Exception( - "No images were discovered by the bucket manager in the dataset." + f"Initialising VAE in {args.vae_dtype} precision, you may specify a different value if preferred: bf16, fp16, fp32, default" ) - logger.info("Creating dataset iterator object") + # Let's use a case-switch for convenience: bf16, fp16, fp32, none/default + if args.vae_dtype == "bf16": + vae_dtype = torch.bfloat16 + elif args.vae_dtype == "fp16": + vae_dtype = torch.float16 + elif args.vae_dtype == "fp32": + vae_dtype = torch.float32 + elif args.vae_dtype == "none" or args.vae_dtype == "default": + vae_dtype = torch.float32 + logger.debug(f"Initialising VAE with custom dtype {vae_dtype}") + vae.to(accelerator.device, dtype=vae_dtype) + logger.info(f"Loaded VAE into VRAM.") + StateTracker.set_vae_dtype(vae_dtype) + StateTracker.set_vae(vae) - train_dataset = MultiAspectDataset( - bucket_manager=bucket_manager, - data_backend=data_backend, - instance_data_root=args.instance_data_dir, - accelerator=accelerator, - size=args.resolution, - size_type=args.resolution_type, - print_names=args.print_filenames or False, - prepend_instance_prompt=args.prepend_instance_prompt or False, - use_captions=not args.only_instance_prompt or False, - use_precomputed_token_ids=False, - debug_dataset_loader=args.debug_dataset_loader, - caption_strategy=args.caption_strategy, - return_tensor=True, - ) - logger.info("Creating aspect bucket sampler") + # Create a DataBackend, so that we can access our dataset. 
+ configure_multi_databackend(args, accelerator) - custom_balanced_sampler = MultiAspectSampler( - bucket_manager=bucket_manager, - data_backend=data_backend, - accelerator=accelerator, - batch_size=args.train_batch_size, - seen_images_path=args.seen_state_path, - state_path=args.state_path, - debug_aspect_buckets=args.debug_aspect_buckets, - delete_unwanted_images=args.delete_unwanted_images, - resolution=args.resolution, - resolution_type=args.resolution_type, - ) from helpers.training.collate import ( extract_filepaths, compute_latents, @@ -587,17 +521,6 @@ def collate_fn(batch): "batch_luminance": batch_luminance, } - logger.info("Plugging sampler into dataloader") - train_dataloader = torch.utils.data.DataLoader( - train_dataset, - batch_size=1, # The sampler handles batching - shuffle=False, # The sampler handles shuffling - sampler=custom_balanced_sampler, - collate_fn=lambda examples: collate_fn(examples), - num_workers=0, - persistent_workers=False, - ) - logger.info("Initialise text embedding cache") prompt_handler = PromptHandler( args=args, @@ -614,18 +537,10 @@ def collate_fn(batch): model_type="legacy", prompt_handler=prompt_handler, ) - - logger.info(f"Pre-computing text embeds / updating cache.") - with accelerator.local_main_process_first(): - all_captions = PromptHandler.get_all_captions( - data_backend=data_backend, - instance_data_root=args.instance_data_dir, - prepend_instance_prompt=args.prepend_instance_prompt or False, - use_captions=not args.only_instance_prompt, - ) - accelerator.wait_for_everyone() - embed_cache.split_cache_between_processes(all_captions) - embed_cache.compute_embeddings_for_legacy_prompts() + if "text" not in args.skip_file_discovery: + logger.info(f"Pre-computing text embeds / updating cache.") + all_captions = StateTracker.get_caption_files() + embed_cache.compute_embeddings_for_legacy_prompts() with accelerator.main_process_first(): ( validation_prompts, @@ -636,10 +551,27 @@ def collate_fn(batch): logger.info("Configuring runtime step count and epoch limit") # Scheduler and math around the number of training steps. overrode_max_train_steps = False + # Check if we have a valid gradient accumulation steps. + if args.gradient_accumulation_steps < 1: + raise ValueError( + f"Invalid gradient_accumulation_steps parameter: {args.gradient_accumulation_steps}, should be >= 1" + ) + # We calculate the number of steps per epoch by dividing the number of images by the effective batch divisor. + # Gradient accumulation steps mean that we only update the model weights every /n/ steps. 
num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps + sum( + [ + len(backend["bucket_manager"]) + for _, backend in StateTracker.get_data_backends().items() + ] + ) + / args.gradient_accumulation_steps ) - if args.max_train_steps is None: + if args.max_train_steps is None or args.max_train_steps == 0: + if args.num_train_epochs is None or args.num_train_epochs == 0: + raise ValueError( + "You must specify either --max_train_steps or --num_train_epochs with a value > 0" + ) args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch logger.debug( f"Overriding max_train_steps to {args.max_train_steps} = {args.num_train_epochs} * {num_update_steps_per_epoch}" @@ -710,12 +642,19 @@ def collate_fn(batch): ) logger.info("EMA model creation complete.") + train_dataloaders = [] + for _, backend in StateTracker.get_data_backends().items(): + train_dataloaders.append(backend["train_dataloader"]) + logger.info("Preparing accelerator..") # Base components to prepare - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, optimizer, train_dataloader, lr_scheduler - ) + results = accelerator.prepare(unet, lr_scheduler, optimizer, *train_dataloaders) + unet = results[0] + lr_scheduler = results[1] + optimizer = results[2] + # The rest of the entries are dataloaders: + train_dataloaders = results[3:] # Conditionally prepare the text_encoder if required if args.train_text_encoder: @@ -725,80 +664,11 @@ def collate_fn(batch): if args.use_ema: ema_model = accelerator.prepare(ema_model) - # For mixed precision training we cast the text_encoder and vae weights to half-precision - # as these models are only used for inference, keeping weights in full precision is not required. - weight_dtype = torch.float32 - if accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - logging.info("Moving VAE to GPU..") - # Move vae and text_encoder to device and cast to weight_dtype - vae.to(accelerator.device, dtype=weight_dtype) - if not args.train_text_encoder: - logging.info("Moving text encoder to GPU..") - text_encoder.to(accelerator.device, dtype=weight_dtype) - if args.use_ema: - logger.info("Moving EMA model weights to accelerator...") - ema_unet.to(accelerator.device, dtype=weight_dtype) - - # Move vae, unet and text_encoder to device and cast to weight_dtype - # The VAE is in float32 to avoid NaN losses. 
- vae_dtype = torch.float32 - if hasattr(args, "vae_dtype"): - logger.info( - f"Initialising VAE in {args.vae_dtype} precision, you may specify a different value if preferred: bf16, fp16, fp32, default" - ) - # Let's use a case-switch for convenience: bf16, fp16, fp32, none/default - if args.vae_dtype == "bf16": - vae_dtype = torch.bfloat16 - elif args.vae_dtype == "fp16": - vae_dtype = torch.float16 - elif args.vae_dtype == "fp32": - vae_dtype = torch.float32 - elif args.vae_dtype == "none" or args.vae_dtype == "default": - vae_dtype = torch.float32 - logger.debug(f"Initialising VAE with custom dtype {vae_dtype}") - vae.to(accelerator.device, dtype=vae_dtype) - logger.info(f"Loaded VAE into VRAM.") - logger.info(f"Pre-computing VAE latent space.") - vaecache = VAECache( - vae=vae, - accelerator=accelerator, - bucket_manager=bucket_manager, - instance_data_root=args.instance_data_dir, - data_backend=data_backend, - delete_problematic_images=args.delete_problematic_images, - resolution=args.resolution, - resolution_type=args.resolution_type, - vae_batch_size=args.vae_batch_size, - write_batch_size=args.write_batch_size, - minimum_image_size=args.minimum_image_size, - ) - StateTracker.set_vaecache(vaecache) - StateTracker.set_vae_dtype(vae_dtype) - StateTracker.set_vae(vae) - - if accelerator.is_local_main_process: - vaecache.discover_all_files() - accelerator.wait_for_everyone() - - if "vae" not in args.skip_file_discovery: - vaecache.split_cache_between_processes() - vaecache.process_buckets() - accelerator.wait_for_everyone() - - if "metadata" not in args.skip_file_discovery and accelerator.is_main_process: - bucket_manager.scan_for_metadata() - accelerator.wait_for_everyone() - if not accelerator.is_main_process: - bucket_manager.load_image_metadata() - accelerator.wait_for_everyone() - # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
logging.info("Recalculating max step count.") num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps + sum([len(dataloader) for dataloader in train_dataloaders]) + / args.gradient_accumulation_steps ) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch @@ -846,7 +716,8 @@ def collate_fn(batch): del vae vae = None - vaecache.vae = None + for _, backend in StateTracker.get_data_backends().items(): + backend["vaecache"].vae = None gc.collect() torch.cuda.empty_cache() memory_after_unload = torch.cuda.memory_allocated() / 1024**3 @@ -869,6 +740,7 @@ def collate_fn(batch): scheduler_kwargs = {} # Potentially load in the weights and states from a previous save + first_epoch = 1 if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) @@ -887,15 +759,21 @@ def collate_fn(batch): else: logging.info(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) - custom_balanced_sampler.load_states( - state_path=os.path.join(args.output_dir, path, "training_state.json"), + for _, backend in StateTracker.get_data_backends().items(): + backend["sampler"].load_states( + state_path=os.path.join( + args.output_dir, path, "training_state.json" + ), + ) + first_epoch = max( + [ + backend["sampler"].current_epoch + for _, backend in StateTracker.get_data_backends().items() + ] ) - first_epoch = custom_balanced_sampler.current_epoch resume_global_step = global_step = int(path.split("-")[1]) - custom_balanced_sampler.log_state() total_steps_remaining_at_start = args.max_train_steps # We store the number of dataset resets that have occurred inside the checkpoint. - first_epoch = custom_balanced_sampler.current_epoch if first_epoch > 1: steps_to_remove = first_epoch * num_update_steps_per_epoch total_steps_remaining_at_start -= steps_to_remove @@ -909,8 +787,14 @@ def collate_fn(batch): ) logger.info("***** Running training *****") + total_num_batches = len( + [ + backend["train_dataset"] + for _, backend in StateTracker.get_data_backends().items() + ] + ) logger.info( - f" Num batches = {len(train_dataset)} ({len(train_dataset) * args.train_batch_size} samples)" + f" Num batches = {total_num_batches} ({total_num_batches * args.train_batch_size} samples)" ) logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Current Epoch = {first_epoch}") @@ -956,7 +840,7 @@ def collate_fn(batch): logger.debug(f"Bumping text encoder.") text_encoder.train() training_models.append(text_encoder) - for step, batch in enumerate(train_dataloader): + for step, batch in random_dataloader_iterator(train_dataloaders): if accelerator.is_main_process: progress_bar.set_description( f"Epoch {current_epoch}/{args.num_train_epochs}, Steps" @@ -1207,9 +1091,12 @@ def collate_fn(batch): args.output_dir, f"checkpoint-{global_step}" ) accelerator.save_state(save_path) - custom_balanced_sampler.save_state( - state_path=os.path.join(save_path, "training_state.json"), - ) + for _, backend in StateTracker.get_data_backends().items(): + backend["sampler"].save_state( + state_path=os.path.join( + save_path, "training_state.json" + ), + ) logger.info(f"Saved state to {save_path}") logs = { diff --git a/train_sd2x.sh b/train_sd2x.sh index 32f4f210..cb12c7b6 100644 --- a/train_sd2x.sh +++ b/train_sd2x.sh @@ -57,4 +57,4 @@ train_sd21.py \ --state_path="${STATE_PATH}" \ --caption_dropout_probability="${CAPTION_DROPOUT_PROBABILITY}" \ 
--caption_strategy="${CAPTION_STRATEGY}" \ ---data_backend="${DATA_BACKEND}" ${TRAINER_EXTRA_ARGS} \ No newline at end of file +--data_backend_config="${DATALOADER_CONFIG}" ${TRAINER_EXTRA_ARGS} \ No newline at end of file diff --git a/train_sdxl.py b/train_sdxl.py index 66974f43..f6e3926f 100644 --- a/train_sdxl.py +++ b/train_sdxl.py @@ -27,7 +27,6 @@ from helpers.training.deepspeed import deepspeed_zero_init_disabled_context_manager from helpers.data_backend.factory import configure_multi_databackend from helpers.data_backend.factory import random_dataloader_iterator -from helpers.caching.vae import VAECache from helpers.caching.sdxl_embeds import TextEmbeddingCache from helpers.training.custom_schedule import ( get_polynomial_decay_schedule_with_warmup, @@ -92,7 +91,7 @@ from transformers.utils import ContextManagers # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.20.0.dev0") +check_min_version("0.25.0.dev0") logger = get_logger(__name__, log_level=os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) @@ -1161,7 +1160,7 @@ def main(): args.output_dir, f"checkpoint-{global_step}" ) accelerator.save_state(save_path) - for backend in StateTracker.get_data_backends(): + for _, backend in StateTracker.get_data_backends().items(): backend["sampler"].save_state( state_path=os.path.join( save_path, "training_state.json" From 94c0b324e9926faf49404d1dcc2e9e933c70251a Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 24 Dec 2023 23:01:56 -0600 Subject: [PATCH 17/22] SD 2.x: fix text embed caching/vae caching --- helpers/caching/sdxl_embeds.py | 4 ++-- helpers/caching/vae.py | 10 +++++++++ helpers/data_backend/local.py | 4 +++- helpers/multiaspect/bucket.py | 1 + helpers/training/collate.py | 39 ++++++++++++++++++++++------------ train_sd21.py | 13 ++++-------- train_sd2x.sh | 2 +- 7 files changed, 47 insertions(+), 26 deletions(-) diff --git a/helpers/caching/sdxl_embeds.py b/helpers/caching/sdxl_embeds.py index 85b06926..efbfc706 100644 --- a/helpers/caching/sdxl_embeds.py +++ b/helpers/caching/sdxl_embeds.py @@ -8,7 +8,7 @@ class TextEmbeddingCache: - prompts = None + prompts = {} def __init__( self, @@ -43,7 +43,7 @@ def encode_legacy_prompt(self, text_encoder, tokenizer, prompt): padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt", - ).input_ids + ).input_ids.to(text_encoder.device) output = text_encoder(input_tokens)[0] logger.debug(f"Legacy prompt shape: {output.shape}") logger.debug(f"Legacy prompt encoded: {output}") diff --git a/helpers/caching/vae.py b/helpers/caching/vae.py index 0b87ad81..323d1af7 100644 --- a/helpers/caching/vae.py +++ b/helpers/caching/vae.py @@ -224,10 +224,14 @@ def split_cache_between_processes(self): f"All unprocessed files: {all_unprocessed_files[:5]} (truncated)" ) # Use the accelerator to split the data + with self.accelerator.split_between_processes( all_unprocessed_files ) as split_files: self.local_unprocessed_files = split_files + self.debug_log( + f"Before splitting, we had {len(all_unprocessed_files)} unprocessed files. After splitting, we have {len(self.local_unprocessed_files)} unprocessed files." 
+ ) # Print the first 5 as a debug log: self.debug_log( f"Local unprocessed files: {self.local_unprocessed_files[:5]} (truncated)" @@ -255,6 +259,12 @@ def encode_images(self, images, filepaths, load_from_cache=True): for i, filename in enumerate(full_filenames) if not self.data_backend.exists(filename) ] + logger.debug( + f"Found {len(uncached_image_indices)} uncached images (truncated): {uncached_image_indices[:5]}" + ) + logger.debug( + f"Received full filenames {len(full_filenames)} (truncated): {full_filenames[:5]}" + ) uncached_images = [images[i] for i in uncached_image_indices] if len(uncached_image_indices) > 0 and load_from_cache: diff --git a/helpers/data_backend/local.py b/helpers/data_backend/local.py index 2116e5ad..ead4a8d3 100644 --- a/helpers/data_backend/local.py +++ b/helpers/data_backend/local.py @@ -56,7 +56,9 @@ def delete(self, filepath): def exists(self, filepath): """Check if the file exists.""" - return os.path.exists(filepath) + result = os.path.exists(filepath) + logger.debug(f"Checking if {filepath} exists = {result}") + return result def open_file(self, filepath, mode): """Open the file in the specified mode.""" diff --git a/helpers/multiaspect/bucket.py b/helpers/multiaspect/bucket.py index 8ee734f4..d45c5074 100644 --- a/helpers/multiaspect/bucket.py +++ b/helpers/multiaspect/bucket.py @@ -323,6 +323,7 @@ def split_buckets_between_processes(self, gradient_accumulation_steps=1): # Trim the list to a length that's divisible by the effective batch size num_batches = len(images) // effective_batch_size trimmed_images = images[: num_batches * effective_batch_size] + logger.debug(f"Trimmed from {len(images)} to {len(trimmed_images)}") with self.accelerator.split_between_processes( trimmed_images, apply_padding=False diff --git a/helpers/training/collate.py b/helpers/training/collate.py index 21cae78c..8918d135 100644 --- a/helpers/training/collate.py +++ b/helpers/training/collate.py @@ -83,7 +83,9 @@ def extract_filepaths(examples): def fetch_latent(fp, data_backend_id: str): """Worker method to fetch latent for a single image.""" - debug_log(f" -> pull latents from cache via data backend {data_backend_id}") + debug_log( + f" -> pull latents for fp {fp} from cache via data backend {data_backend_id}" + ) latent = StateTracker.get_vaecache(id=data_backend_id).retrieve_from_cache(fp) # Move to CPU and pin memory if it's not on the GPU @@ -111,13 +113,23 @@ def compute_latents(filepaths, data_backend_id: str): return torch.stack(latents) -def compute_prompt_embeddings(captions): +def compute_prompt_embeddings(captions, model_type: str = "sdxl"): debug_log(" -> get embed from cache") - ( - prompt_embeds_all, - add_text_embeds_all, - ) = StateTracker.get_embedcache().compute_embeddings_for_sdxl_prompts(captions) - debug_log(" -> concat embeds") + embedcache = StateTracker.get_embedcache() + if embedcache.model_type == "sdxl": + ( + prompt_embeds_all, + add_text_embeds_all, + ) = embedcache.compute_embeddings_for_sdxl_prompts(captions) + debug_log(" -> concat embeds") + else: + debug_log(" -> concat embeds") + prompt_embeds_all = embedcache.compute_embeddings_for_legacy_prompts(captions)[ + 0 + ] + print(f"Poop: {prompt_embeds_all}") + prompt_embeds_all = torch.concat([prompt_embeds_all for _ in range(1)], dim=0) + return prompt_embeds_all, None prompt_embeds_all = torch.concat([prompt_embeds_all for _ in range(1)], dim=0) add_text_embeds_all = torch.concat([add_text_embeds_all for _ in range(1)], dim=0) return prompt_embeds_all, add_text_embeds_all @@ -191,12 +203,13 
@@ def collate_fn(batch): captions = [example["instance_prompt_text"] for example in examples] debug_log("Pull cached text embeds") prompt_embeds_all, add_text_embeds_all = compute_prompt_embeddings(captions) - - debug_log("Compute and stack SDXL time ids") - batch_time_ids = gather_conditional_size_features( - examples, latent_batch, StateTracker.get_weight_dtype() - ) - debug_log(f"Time ids stacked to {batch_time_ids.shape}: {batch_time_ids}") + batch_time_ids = None + if add_text_embeds_all is not None: + debug_log("Compute and stack SDXL time ids") + batch_time_ids = gather_conditional_size_features( + examples, latent_batch, StateTracker.get_weight_dtype() + ) + debug_log(f"Time ids stacked to {batch_time_ids.shape}: {batch_time_ids}") return { "latent_batch": latent_batch, diff --git a/train_sd21.py b/train_sd21.py index d323ab99..767e2fbf 100644 --- a/train_sd21.py +++ b/train_sd21.py @@ -457,9 +457,8 @@ def main(args): logging.info("Moving VAE to GPU..") # Move vae and text_encoder to device and cast to weight_dtype vae.to(accelerator.device, dtype=weight_dtype) - if not args.train_text_encoder: - logging.info("Moving text encoder to GPU..") - text_encoder.to(accelerator.device, dtype=weight_dtype) + logging.info("Moving text encoder to GPU..") + text_encoder.to(accelerator.device, dtype=weight_dtype) if args.use_ema: logger.info("Moving EMA model weights to accelerator...") ema_unet.to(accelerator.device, dtype=weight_dtype) @@ -537,10 +536,11 @@ def collate_fn(batch): model_type="legacy", prompt_handler=prompt_handler, ) + StateTracker.set_embedcache(embed_cache) if "text" not in args.skip_file_discovery: logger.info(f"Pre-computing text embeds / updating cache.") all_captions = StateTracker.get_caption_files() - embed_cache.compute_embeddings_for_legacy_prompts() + embed_cache.compute_embeddings_for_legacy_prompts(return_concat=False) with accelerator.main_process_first(): ( validation_prompts, @@ -685,11 +685,6 @@ def collate_fn(batch): # Copy args into public_args: public_args = copy.deepcopy(args) # Remove the args that we don't want to track: - del public_args.aws_access_key_id - del public_args.aws_secret_access_key - del public_args.aws_bucket_name - del public_args.aws_region_name - del public_args.aws_endpoint_url project_name = args.tracker_project_name or "simpletuner-training" tracker_run_name = args.tracker_run_name or "simpletuner-training-run" public_args_hash = hashlib.md5( diff --git a/train_sd2x.sh b/train_sd2x.sh index cb12c7b6..45dc85f2 100644 --- a/train_sd2x.sh +++ b/train_sd2x.sh @@ -41,7 +41,7 @@ train_sd21.py \ --train_batch_size="${TRAIN_BATCH_SIZE}" \ --seed "${TRAINING_SEED}" \ --learning_rate="${LEARNING_RATE}" \ ---learning_rate_end="${LEARNING_RATE_END}" \ +--lr_end="${LEARNING_RATE_END}" \ --lr_scheduler="${LR_SCHEDULE}" \ --num_train_epochs="${NUM_EPOCHS}" \ --mixed_precision="${MIXED_PRECISION}" \ From 523487f8921fd4e123bdf22105618158a0f2e7ec Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 24 Dec 2023 23:15:07 -0600 Subject: [PATCH 18/22] SD 2.x: fix text embed caching/vae caching --- helpers/training/collate.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/helpers/training/collate.py b/helpers/training/collate.py index 8918d135..8620e104 100644 --- a/helpers/training/collate.py +++ b/helpers/training/collate.py @@ -113,7 +113,7 @@ def compute_latents(filepaths, data_backend_id: str): return torch.stack(latents) -def compute_prompt_embeddings(captions, model_type: str = "sdxl"): +def compute_prompt_embeddings(captions): 
debug_log(" -> get embed from cache") embedcache = StateTracker.get_embedcache() if embedcache.model_type == "sdxl": @@ -127,7 +127,6 @@ def compute_prompt_embeddings(captions, model_type: str = "sdxl"): prompt_embeds_all = embedcache.compute_embeddings_for_legacy_prompts(captions)[ 0 ] - print(f"Poop: {prompt_embeds_all}") prompt_embeds_all = torch.concat([prompt_embeds_all for _ in range(1)], dim=0) return prompt_embeds_all, None prompt_embeds_all = torch.concat([prompt_embeds_all for _ in range(1)], dim=0) From 7f4ab96b486f151828326aad23ce33f520513916 Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 24 Dec 2023 23:18:25 -0600 Subject: [PATCH 19/22] SDXL: Logging --- train_sdxl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train_sdxl.py b/train_sdxl.py index f6e3926f..91dd1b8e 100644 --- a/train_sdxl.py +++ b/train_sdxl.py @@ -1161,6 +1161,7 @@ def main(): ) accelerator.save_state(save_path) for _, backend in StateTracker.get_data_backends().items(): + logger.debug(f"Backend: {backend}") backend["sampler"].save_state( state_path=os.path.join( save_path, "training_state.json" From d6485ad73d922a1d453fdd9b34379ebee1ae7213 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 25 Dec 2023 00:32:58 -0600 Subject: [PATCH 20/22] StateTracker: fix delete method --- helpers/training/state_tracker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/helpers/training/state_tracker.py b/helpers/training/state_tracker.py index 562bf60c..7bfb6a0d 100644 --- a/helpers/training/state_tracker.py +++ b/helpers/training/state_tracker.py @@ -39,14 +39,14 @@ def delete_cache_files(cls): # Glob the directory for "all_image_files.*.json" and "all_vae_cache_files.*.json", and delete those too # This is a workaround for the fact that the cache files are named with the data_backend_id - filelist = Path(cls.args.output_dir).glob("all_image_files.*.json") + filelist = Path(cls.args.output_dir).glob("all_image_files_*.json") for file in filelist: try: file.unlink() except: pass - filelist = Path(cls.args.output_dir).glob("all_vae_cache_files.*.json") + filelist = Path(cls.args.output_dir).glob("all_vae_cache_files_*.json") for file in filelist: try: file.unlink() From cbf8b8d0f9dfa19314d47aadf27ab54dba3742bc Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 25 Dec 2023 14:23:46 -0600 Subject: [PATCH 21/22] Added backend config validations so that we do not accidentally mess up any caches in the future --- helpers/arguments.py | 24 ++++++ helpers/caching/sdxl_embeds.py | 2 +- helpers/caching/vae.py | 21 +++++ helpers/data_backend/factory.py | 66 +++++++++++++++- helpers/data_backend/local.py | 3 +- helpers/multiaspect/bucket.py | 122 ++++++++++++++++++++++++++++-- helpers/multiaspect/image.py | 33 +++++++- helpers/training/state_tracker.py | 6 ++ tests/test_bucket.py | 2 +- train_sdxl.py | 8 +- 10 files changed, 274 insertions(+), 13 deletions(-) diff --git a/helpers/arguments.py b/helpers/arguments.py index 41afdce2..e9f75bb3 100644 --- a/helpers/arguments.py +++ b/helpers/arguments.py @@ -148,6 +148,18 @@ def parse_args(input_args=None): " but if you are at that point of contention, it's possible that your GPU has too little RAM. Default: 4." ), ) + parser.add_argument( + "--vae_cache_behaviour", + type=str, + choices=["recreate", "sync"], + default="recreate", + help=( + "When a mismatched latent vector is detected, a scan will be initiated to locate inconsistencies and resolve them." + " The default setting 'recreate' will delete any inconsistent cache entries and rebuild it." 
+ " Alternatively, 'sync' will update the bucket configuration so that the image is in a bucket that matches its latent size." + " The recommended behaviour is to use the default value and allow the cache to be recreated." + ), + ) parser.add_argument( "--keep_vae_loaded", action="store_true", @@ -199,6 +211,18 @@ def parse_args(input_args=None): " Currently, cache is not stored in the dataset itself but rather, locally. This may change in a future release." ), ) + parser.add_argument( + "--override_dataset_config", + action="store_true", + default=False, + help=( + "When provided, the dataset's config will not be checked against the live backend config." + " This is useful if you want to simply update the behaviour of an existing dataset," + " but the recommendation is to not change the dataset configuration after caching has begun," + " as most options cannot be changed without unexpected behaviour later on. Additionally, it prevents" + " accidentally loading an SDXL configuration on a SD 2.x model and vice versa." + ), + ) parser.add_argument( "--cache_dir_text", type=str, diff --git a/helpers/caching/sdxl_embeds.py b/helpers/caching/sdxl_embeds.py index efbfc706..e1e32fb7 100644 --- a/helpers/caching/sdxl_embeds.py +++ b/helpers/caching/sdxl_embeds.py @@ -28,7 +28,7 @@ def __init__( os.makedirs(self.cache_dir, exist_ok=True) def create_hash(self, caption): - return hashlib.md5(caption.encode()).hexdigest() + return f"{hashlib.md5(caption.encode()).hexdigest()}-{self.model_type}" def save_to_cache(self, filename, embeddings): torch.save(embeddings, filename) diff --git a/helpers/caching/vae.py b/helpers/caching/vae.py index 323d1af7..355242cc 100644 --- a/helpers/caching/vae.py +++ b/helpers/caching/vae.py @@ -599,3 +599,24 @@ def process_buckets(self): except Exception as e: logger.error(f"Fatal error when processing bucket {bucket}: {e}") continue + + def scan_cache_contents(self): + """ + A generator method that iterates over the VAE cache, yielding each cache file's path and its contents. + + This is likely a very expensive operation for extra-large cloud datasets, but it could save time and + computational resources if finding a problem with surgical precision can prevent the need for removing + all cache entries in a dataset for a complete rebuild. + + Yields: + Tuple[str, Any]: A tuple containing the file path and its contents. 
+ """ + try: + all_cache_files = StateTracker.get_vae_cache_files(data_backend_id=self.id) + for cache_file in all_cache_files: + full_path = os.path.join(self.cache_dir, cache_file) + cache_content = self._read_from_storage(full_path) + yield (full_path, cache_content) + except Exception as e: + logger.error(f"Error in scan_cache_contents: {e}") + logging.debug(f"Error traceback: {traceback.format_exc()}") diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py index 5163228a..6c2e078d 100644 --- a/helpers/data_backend/factory.py +++ b/helpers/data_backend/factory.py @@ -20,10 +20,24 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict: output = {"id": backend["id"], "config": {}} if "crop" in backend: output["config"]["crop"] = backend["crop"] + else: + output["config"]["crop"] = args.crop if "crop_aspect" in backend: output["config"]["crop_aspect"] = backend["crop_aspect"] + else: + output["config"]["crop_aspect"] = args.crop_aspect if "crop_style" in backend: output["config"]["crop_style"] = backend["crop_style"] + else: + output["config"]["crop_style"] = args.crop_style + if "resolution" in backend: + output["config"]["resolution"] = backend["resolution"] + else: + output["config"]["resolution"] = args.resolution + if "resolution_type" in backend: + output["config"]["resolution_type"] = backend["resolution_type"] + else: + output["config"]["resolution_type"] = args.resolution_type return output @@ -111,6 +125,13 @@ def configure_multi_databackend(args: dict, accelerator): ), delete_problematic_images=args.delete_problematic_images or False, ) + logger.debug( + f"Loaded previous data backend config: {init_backend['bucket_manager'].config}" + ) + StateTracker.set_data_backend_config( + data_backend_id=init_backend["id"], + config=init_backend["bucket_manager"].config, + ) if init_backend["bucket_manager"].has_single_underfilled_bucket(): raise Exception( f"Cannot train using a dataset that has a single bucket with fewer than {args.train_batch_size} images." @@ -126,6 +147,31 @@ def configure_multi_databackend(args: dict, accelerator): gradient_accumulation_steps=args.gradient_accumulation_steps, ) + # Check if there is an existing 'config' in the bucket_manager.config + if init_backend["bucket_manager"].config != {}: + logger.debug( + f"Found existing config: {init_backend['bucket_manager'].config}" + ) + # Check if any values differ between the 'backend' values and the 'config' values: + for key, _ in init_backend["bucket_manager"].config.items(): + logger.debug(f"Checking config key: {key}") + if ( + key in backend + and init_backend["bucket_manager"].config[key] != backend[key] + ): + if not args.override_dataset_config: + raise Exception( + f"Dataset {init_backend['id']} has inconsistent config, and --override_dataset_config was not provided." + f"\n-> Expected value {key}={init_backend['bucket_manager'].config[key]} differs from current value={backend[key]}." 
+ f"\n-> Recommended action is to correct the current config values to match the values that were used to create this dataset:" + f"\n{init_backend['bucket_manager'].config}" + ) + else: + logger.warning( + f"Overriding config value {key}={init_backend['bucket_manager'].config[key]} with {backend[key]}" + ) + init_backend["bucket_manager"].config[key] = backend[key] + print_bucket_info(init_backend["bucket_manager"]) if len(init_backend["bucket_manager"]) == 0: raise Exception( @@ -222,8 +268,25 @@ def configure_multi_databackend(args: dict, accelerator): init_backend["vaecache"].discover_all_files() accelerator.wait_for_everyone() - if "metadata" not in args.skip_file_discovery and accelerator.is_main_process: + if ( + "metadata" not in args.skip_file_discovery + and accelerator.is_main_process + and backend.get("scan_for_errors", False) + ): + logger.info( + f"Beginning error scan for dataset {init_backend['id']}. Set 'scan_for_errors' to False in the dataset config to disable this." + ) + init_backend["bucket_manager"].handle_vae_cache_inconsistencies( + vae_cache=init_backend["vaecache"], + vae_cache_behavior=backend.get( + "vae_cache_behaviour", args.vae_cache_behaviour + ), + ) init_backend["bucket_manager"].scan_for_metadata() + elif not backend.get("scan_for_errors", False): + logger.info( + f"Skipping error scan for dataset {init_backend['id']}. Set 'scan_for_errors' to True in the dataset config to enable this if your training runs into mismatched latent dimensions." + ) accelerator.wait_for_everyone() if not accelerator.is_main_process: init_backend["bucket_manager"].load_image_metadata() @@ -235,6 +298,7 @@ def configure_multi_databackend(args: dict, accelerator): accelerator.wait_for_everyone() StateTracker.register_data_backend(init_backend) + init_backend["bucket_manager"].save_cache() # After configuring all backends, register their captions. 
StateTracker.set_caption_files(all_captions) diff --git a/helpers/data_backend/local.py b/helpers/data_backend/local.py index ead4a8d3..ec68b412 100644 --- a/helpers/data_backend/local.py +++ b/helpers/data_backend/local.py @@ -32,7 +32,7 @@ def write(self, filepath: str, data: Any) -> None: logger.debug(f"Writing a torch file to disk.") return self.torch_save(data, file) elif isinstance(data, str): - logger.debug(f"Writing a string to disk: {data}") + logger.debug(f"Writing a string to disk as {filepath}: {data}") data = data.encode("utf-8") else: logger.debug( @@ -47,6 +47,7 @@ def write(self, filepath: str, data: Any) -> None: def delete(self, filepath): """Delete the specified file.""" if os.path.exists(filepath): + logger.debug(f"Deleting file: {filepath}") os.remove(filepath) else: raise FileNotFoundError(f"{filepath} not found.") diff --git a/helpers/multiaspect/bucket.py b/helpers/multiaspect/bucket.py index d45c5074..92e61495 100644 --- a/helpers/multiaspect/bucket.py +++ b/helpers/multiaspect/bucket.py @@ -2,7 +2,7 @@ from helpers.multiaspect.image import MultiaspectImage from helpers.data_backend.base import BaseDataBackend from pathlib import Path -import json, logging, os, time +import json, logging, os, time, re from multiprocessing import Manager from PIL import Image from tqdm import tqdm @@ -123,9 +123,11 @@ def reload_cache(self): self.aspect_ratio_bucket_indices = cache_data.get( "aspect_ratio_bucket_indices", {} ) + self.config = cache_data.get("config", {}) + logger.debug(f"Setting config to {self.config}") self.instance_images_path = set(cache_data.get("instance_images_path", [])) - def _save_cache(self, enforce_constraints: bool = False): + def save_cache(self, enforce_constraints: bool = False): """ Save cache data to file. """ @@ -139,9 +141,13 @@ def _save_cache(self, enforce_constraints: bool = False): } # Encode the cache as JSON. cache_data = { + "config": StateTracker.get_data_backend_config( + data_backend_id=self.data_backend.id + ), "aspect_ratio_bucket_indices": aspect_ratio_bucket_indices_str, "instance_images_path": [str(path) for path in self.instance_images_path], } + logger.debug(f"save_cache has config to write: {cache_data['config']}") cache_data_str = json.dumps(cache_data) # Use our DataBackend to write the cache file. self.data_backend.write(self.cache_file, cache_data_str) @@ -289,7 +295,7 @@ def compute_aspect_ratio_bucket_indices(self): f"In-flight metadata update after {processing_duration} seconds. Saving {len(self.image_metadata)} metadata entries and {len(self.aspect_ratio_bucket_indices)} aspect bucket lists." 
) self.instance_images_path.update(written_files) - self._save_cache(enforce_constraints=False) + self.save_cache(enforce_constraints=False) self.save_image_metadata() last_write_time = current_time @@ -300,7 +306,7 @@ def compute_aspect_ratio_bucket_indices(self): self.instance_images_path.update(new_files) self.save_image_metadata() - self._save_cache(enforce_constraints=True) + self.save_cache(enforce_constraints=True) logger.info("Completed aspect bucket update.") def split_buckets_between_processes(self, gradient_accumulation_steps=1): @@ -383,7 +389,7 @@ def update_buckets_with_existing_files(self, existing_files: set): img for img in images if img in existing_files ] # Save the updated cache - self._save_cache() + self.save_cache() def refresh_buckets(self, rank: int = None): """ @@ -511,7 +517,7 @@ def handle_incorrect_bucket(self, image_path: str, bucket: str, actual_bucket: s else: logger.warning(f"Created new bucket for that pesky image.") self.aspect_ratio_bucket_indices[actual_bucket] = [image_path] - self._save_cache() + self.save_cache() def handle_small_image( self, image_path: str, bucket: str, delete_unwanted_images: bool @@ -699,5 +705,107 @@ def scan_for_metadata(self): worker.join() self.save_image_metadata() - self._save_cache(enforce_constraints=True) + self.save_cache(enforce_constraints=True) logger.info("Completed metadata update.") + + def handle_vae_cache_inconsistencies(self, vae_cache, vae_cache_behavior: str): + """ + Handles inconsistencies between the aspect buckets and the VAE cache. + + Args: + vae_cache: The VAECache object. + vae_cache_behavior (str): Behavior for handling inconsistencies ('sync' or 'recreate'). + """ + if vae_cache_behavior not in ["sync", "recreate"]: + raise ValueError("Invalid VAE cache behavior specified.") + + for cache_file, cache_content in vae_cache.scan_cache_contents(): + if vae_cache_behavior == "sync": + # Sync aspect buckets with the cache + expected_bucket = MultiaspectImage.determine_bucket_for_aspect_ratio( + self._get_aspect_ratio_from_tensor(cache_content) + ) + self._modify_cache_entry_bucket(cache_file, expected_bucket) + + elif vae_cache_behavior == "recreate": + # Delete the cache file if it doesn't match the aspect bucket indices + if self.is_cache_inconsistent(cache_file, cache_content): + self.data_backend.delete(cache_file) + + # Update any state or metadata post-processing + self.save_cache() + + def is_cache_inconsistent(self, cache_file, cache_content): + """ + Check if a cache file's content is inconsistent with the aspect ratio bucket indices. + + Args: + cache_file (str): The cache file path. + cache_content: The content of the cache file (PyTorch Tensor). + + Returns: + bool: True if the cache file is inconsistent, False otherwise. 
+ """ + actual_aspect_ratio = self._get_aspect_ratio_from_tensor(cache_content) + expected_bucket = MultiaspectImage.determine_bucket_for_aspect_ratio( + actual_aspect_ratio + ) + logger.debug( + f"Expected bucket for {cache_file}: {expected_bucket} vs actual {actual_aspect_ratio}" + ) + + # Extract the base filename without the extension + base_filename = os.path.splitext(os.path.basename(cache_file))[0] + base_filename_png = os.path.join( + self.instance_data_root, f"{base_filename}.png" + ) + base_filename_jpg = os.path.join( + self.instance_data_root, f"{base_filename}.jpg" + ) + # Check if the base filename is in the correct bucket + if any( + base_filename_png in self.aspect_ratio_bucket_indices.get(bucket, set()) + for bucket in [expected_bucket, str(expected_bucket)] + ): + logger.debug(f"File {base_filename} is in the correct bucket.") + return False + if any( + base_filename_jpg in self.aspect_ratio_bucket_indices.get(bucket, set()) + for bucket in [expected_bucket, str(expected_bucket)] + ): + logger.debug(f"File {base_filename} is in the correct bucket.") + return False + logger.debug(f"File {base_filename} was not found in the correct place.") + return True + + def _get_aspect_ratio_from_tensor(self, tensor): + """ + Calculate the aspect ratio from a PyTorch Tensor. + + Args: + tensor (torch.Tensor): The tensor representing the image. + + Returns: + float: The aspect ratio of the image. + """ + if tensor.dim() < 3: + raise ValueError( + "Tensor does not have enough dimensions to determine aspect ratio." + ) + # Assuming tensor is in CHW format (channel, height, width) + _, height, width = tensor.size() + return width / height + + def _modify_cache_entry_bucket(self, cache_file, expected_bucket): + """ + Update the bucket indices based on the cache file's actual aspect ratio. + + Args: + cache_file (str): The cache file path. + expected_bucket (str): The bucket that the cache file should belong to. + """ + for bucket, files in self.aspect_ratio_bucket_indices.items(): + if cache_file in files and str(bucket) != str(expected_bucket): + files.remove(cache_file) + self.aspect_ratio_bucket_indices[expected_bucket].append(cache_file) + break diff --git a/helpers/multiaspect/image.py b/helpers/multiaspect/image.py index f31ce590..79ff9226 100644 --- a/helpers/multiaspect/image.py +++ b/helpers/multiaspect/image.py @@ -53,7 +53,9 @@ def process_for_bucket( image_metadata["crop_coordinates"] = crop_coordinates image_metadata["target_size"] = image.size # Round to avoid excessive unique buckets - aspect_ratio = round(image.width / image.height, aspect_ratio_rounding) + aspect_ratio = MultiaspectImage.calculate_image_aspect_ratio( + image, aspect_ratio_rounding + ) image_metadata["aspect_ratio"] = aspect_ratio image_metadata["luminance"] = calculate_luminance(image) logger.debug( @@ -254,3 +256,32 @@ def calculate_new_size_by_pixel_area(W: int, H: int, megapixels: float): H_new = MultiaspectImage._round_to_nearest_multiple(H_new, 64) return W_new, H_new + + @staticmethod + def calculate_image_aspect_ratio(image, rounding: int = 2): + """ + Calculate the aspect ratio of an image and round it to a specified precision. + + Args: + image (PIL.Image): The image to calculate the aspect ratio for. + rounding (int): The number of decimal places to round the aspect ratio to. + + Returns: + float: The rounded aspect ratio of the image. 
+ """ + aspect_ratio = round(image.width / image.height, rounding) + return aspect_ratio + + @staticmethod + def determine_bucket_for_aspect_ratio(aspect_ratio, rounding: int = 2): + """ + Determine the correct bucket for a given aspect ratio. + + Args: + aspect_ratio (float): The aspect ratio of an image. + + Returns: + str: The bucket corresponding to the aspect ratio. + """ + # The logic for determining the bucket can be based on the aspect ratio directly + return str(round(aspect_ratio, rounding)) diff --git a/helpers/training/state_tracker.py b/helpers/training/state_tracker.py index 7bfb6a0d..a2413902 100644 --- a/helpers/training/state_tracker.py +++ b/helpers/training/state_tracker.py @@ -143,6 +143,12 @@ def get_data_backend(cls, id: str): def get_data_backend_config(cls, data_backend_id: str): return cls.data_backends.get(data_backend_id, {}).get("config", {}) + @classmethod + def set_data_backend_config(cls, data_backend_id: str, config: dict): + if data_backend_id not in cls.data_backends: + cls.data_backends[data_backend_id] = {} + cls.data_backends[data_backend_id]["config"] = config + @classmethod def get_data_backends(cls): return cls.data_backends diff --git a/tests/test_bucket.py b/tests/test_bucket.py index 2c68b212..0a5a0026 100644 --- a/tests/test_bucket.py +++ b/tests/test_bucket.py @@ -87,7 +87,7 @@ def test_save_cache(self): self.bucket_manager.aspect_ratio_bucket_indices = {"1.0": ["image1", "image2"]} self.bucket_manager.instance_images_path = ["image1", "image2"] with patch.object(self.data_backend, "write") as mock_write: - self.bucket_manager._save_cache() + self.bucket_manager.save_cache() mock_write.assert_called_once() # Add more tests for other methods as needed diff --git a/train_sdxl.py b/train_sdxl.py index 91dd1b8e..5a9d509d 100644 --- a/train_sdxl.py +++ b/train_sdxl.py @@ -357,7 +357,13 @@ def main(): logger.info(f"Loaded VAE into VRAM.") # Create a DataBackend, so that we can access our dataset. - configure_multi_databackend(args, accelerator) + try: + configure_multi_databackend(args, accelerator) + except Exception as e: + logging.error(f"{e}") + import sys + + sys.exit(0) prompt_handler = None if not args.disable_compel: From 0f35ef32a218d27f2003d49c609a6c0bde4ca36e Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 25 Dec 2023 14:45:51 -0600 Subject: [PATCH 22/22] update documentation --- OPTIONS.md | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/OPTIONS.md b/OPTIONS.md index 743b4882..62dcbea2 100644 --- a/OPTIONS.md +++ b/OPTIONS.md @@ -36,6 +36,17 @@ This guide provides a user-friendly breakdown of the command-line options availa - **Why**: Multiple datasets on different storage medium may be combined into a single training session. - **Example**: See (multidatabackend.json.example)[/multidatabackend.json.example] for an example configuration. +### `--override_dataset_config` + +- **What**: When provided, will allow SimpleTuner to ignore differences between the cached config inside the dataset and the current values. +- **Why**: When SimplerTuner is run for the first time on a dataset, it will create a cache document containing information about everything in that dataset. This includes the dataset config, including its "crop" and "resolution" related configuration values. 
Changing these arbitrarily or by accident could result in your training jobs crashing randomly, so it's highly recommended to not use this parameter, and instead resolve the differences you'd like to apply in your dataset some other way. + + +### `--vae_cache_behaviour` + +- **What**: Configure the behaviour of the integrity scan check. +- **Why**: A dataset could have incorrect settings applied at multiple points of training, eg. if you accidentally delete the `.json` cache files from your dataset and switch the data backend config to use square images rather than aspect-crops. This will result in an inconsistent data cache, which can be corrected by setting `scan_for_errors` to `true` in your `multidatabackend.json` configuration file. When this scan runs, it relies on the setting of `--vae_cache_behaviour` to determine how to resolve the inconsistency: `recreate` (the default) will remove the offending cache entry so that it can be recreated, and `sync` will update the bucket metadata to reflect the reality of the real training sample. Recommended value: `recreate`. + --- ## 🌈 Image and Text Processing @@ -166,10 +177,13 @@ usage: train_sdxl.py [-h] [--snr_gamma SNR_GAMMA] [--timestep_bias_end TIMESTEP_BIAS_END] [--timestep_bias_portion TIMESTEP_BIAS_PORTION] [--rescale_betas_zero_snr] [--vae_dtype VAE_DTYPE] - [--vae_batch_size VAE_BATCH_SIZE] [--keep_vae_loaded] + [--vae_batch_size VAE_BATCH_SIZE] + [--vae_cache_behaviour {recreate,sync}] + [--keep_vae_loaded] [--skip_file_discovery SKIP_FILE_DISCOVERY] [--revision REVISION] --instance_data_dir INSTANCE_DATA_DIR [--preserve_data_backend_cache] + [--override_dataset_config] [--cache_dir_text CACHE_DIR_TEXT] [--cache_dir_vae CACHE_DIR_VAE] --data_backend_config DATA_BACKEND_CONFIG [--write_batch_size WRITE_BATCH_SIZE] @@ -331,6 +345,16 @@ options: issues, but if you are at that point of contention, it's possible that your GPU has too little RAM. Default: 4. + --vae_cache_behaviour {recreate,sync} + When a mismatched latent vector is detected, a scan + will be initiated to locate inconsistencies and + resolve them. The default setting 'recreate' will + delete any inconsistent cache entries and rebuild it. + Alternatively, 'sync' will update the bucket + configuration so that the image is in a bucket that + matches its latent size. The recommended behaviour is + to use the default value and allow the cache to be + recreated. --keep_vae_loaded If set, will keep the VAE loaded in memory. This can reduce disk churn, but consumes VRAM during the forward pass. @@ -367,6 +391,16 @@ options: Currently, cache is not stored in the dataset itself but rather, locally. This may change in a future release. + --override_dataset_config + When provided, the dataset's config will not be + checked against the live backend config. This is + useful if you want to simply update the behaviour of + an existing dataset, but the recommendation is to not + change the dataset configuration after caching has + begun, as most options cannot be changed without + unexpected behaviour later on. Additionally, it + prevents accidentally loading an SDXL configuration on + a SD 2.x model and vice versa. --cache_dir_text CACHE_DIR_TEXT This is the path to a local directory that will contain your text embed cache.
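As a point of reference for the changes above: the file passed via `--data_backend_config` (the `multidatabackend_sd2x.json` exported in `sd21-env.sh.example`) is a JSON list of backend definitions. The sketch below is only a plausible illustration assembled from the keys these patches read in `helpers/data_backend/factory.py` (`id`, `type`, the `aws_*` credentials, `crop`, `crop_aspect`, `crop_style`, `resolution`, `resolution_type`, `scan_for_errors`, `vae_cache_behaviour`); the concrete values, the backend names, and any additional keys the project may accept are assumptions, not something these diffs confirm.

import json

# Hypothetical two-backend dataloader config. The keys mirror what
# configure_multi_databackend() and init_backend_config() look up in these
# patches; every value below is a placeholder, not a recommended setting.
example_config = [
    {
        "id": "local-photos",               # assumed identifier
        "type": "local",
        "resolution": 768,                  # crop/resolution keys fall back to args.* when omitted
        "resolution_type": "pixel",
        "crop": False,
        "scan_for_errors": True,            # enables the VAE cache consistency scan
        "vae_cache_behaviour": "recreate",  # or "sync"
    },
    {
        "id": "s3-archive",                 # assumed identifier
        "type": "aws",
        "aws_bucket_name": "my-training-bucket",
        "aws_region_name": "us-east-1",
        "aws_endpoint_url": "https://s3.example.com",
        "aws_access_key_id": "AKIA...",
        "aws_secret_access_key": "...",
        "crop_aspect": "square",            # placeholder; accepted values are not shown in these diffs
        "crop_style": "random",             # placeholder; accepted values are not shown in these diffs
    },
]

with open("multidatabackend_sd2x.json", "w") as f:
    json.dump(example_config, f, indent=4)

When the crop and resolution keys are omitted, `init_backend_config()` falls back to the matching command-line arguments, and the resolved configuration is persisted into each dataset's aspect bucket cache, which is what the `--override_dataset_config` check compares against on later runs.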