From 06d34a2e4432475b44bf8d54ffde275bbd509100 Mon Sep 17 00:00:00 2001
From: bghira
Date: Mon, 15 Jan 2024 15:39:50 -0600
Subject: [PATCH 01/14] text embedding cache: reduce file existence checks

---
 helpers/data_backend/local.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/helpers/data_backend/local.py b/helpers/data_backend/local.py
index d3bfdbee..3b307871 100644
--- a/helpers/data_backend/local.py
+++ b/helpers/data_backend/local.py
@@ -58,7 +58,7 @@ def delete(self, filepath):
     def exists(self, filepath):
         """Check if the file exists."""
         result = os.path.exists(filepath)
-        # logger.debug(f"Checking if {filepath} exists = {result}")
+        logger.debug(f"Checking if {filepath} exists = {result}")
         return result

     def open_file(self, filepath, mode):
@@ -153,7 +153,6 @@ def read_image_batch(
     def create_directory(self, directory_path):
         logger.debug(f"Creating directory: {directory_path}")
         os.makedirs(directory_path, exist_ok=True)
-        logger.debug(f"Created directory.")

     def torch_load(self, filename):
         # Check if file exists:
@@ -164,16 +163,12 @@ def torch_load(self, filename):
         )

     def torch_save(self, data, original_location):
-        logger.debug("Calling torch_save on Local backend.")
         if type(original_location) == str:
             # A file path was given. Open it.
-            logger.debug(f"Using file path: {original_location}")
             location = self.open_file(original_location, "wb")
         else:
             # A file object was given. Use it.
-            logger.debug(f"Using file object: {original_location}")
             location = original_location
-        logger.debug(f"Torch location {original_location} save was given data: {data}")
         torch.save(data, location)
         # Check whether the file created:
         if type(original_location) == str:

From aba313ffc24433e53d41d40f79605353b3c4ec36 Mon Sep 17 00:00:00 2001
From: bghira
Date: Mon, 15 Jan 2024 15:49:07 -0600
Subject: [PATCH 02/14] text embedding cache: start the thread when re-scanning a new subset of a dataset

---
 helpers/caching/sdxl_embeds.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/helpers/caching/sdxl_embeds.py b/helpers/caching/sdxl_embeds.py
index 19ff8854..24757a5d 100644
--- a/helpers/caching/sdxl_embeds.py
+++ b/helpers/caching/sdxl_embeds.py
@@ -246,6 +246,11 @@ def compute_embeddings_for_prompts(
         return_concat: bool = True,
         is_validation: bool = False,
     ):
+        if not self.batch_write_thread.is_alive():
+            # Start the thread again.
+            self.process_write_batches = True
+            self.batch_write_thread = Thread(target=self.batch_write_embeddings)
+            self.batch_write_thread.start()
         existing_cache_filenames = list(
             StateTracker.get_text_cache_files(data_backend_id=self.id).keys()
         )
@@ -318,12 +323,10 @@ def compute_embeddings_for_sdxl_prompts(
                 )
                 self.debug_log(f"Checking for cache file: {filename}")
                 if (
-                    self.data_backend.exists(filename)
+                    return_concat
+                    and self.data_backend.exists(filename)
                     and load_from_cache
-                    and not return_concat
                 ):
-                    continue
-                if self.data_backend.exists(filename) and load_from_cache:
                     prompt_embeds, add_text_embeds = self.load_from_cache(filename)
                 else:
                     self.debug_log(f"Encoding prompt: {prompt}")

From 7475accff303a3cb419810802df16c7f7f7f0f8e Mon Sep 17 00:00:00 2001
From: bghira
Date: Mon, 15 Jan 2024 15:51:52 -0600
Subject: [PATCH 03/14] text embedding cache: do not log as much

---
 helpers/caching/sdxl_embeds.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/helpers/caching/sdxl_embeds.py b/helpers/caching/sdxl_embeds.py
index 24757a5d..edf3f52c 100644
--- a/helpers/caching/sdxl_embeds.py
+++ b/helpers/caching/sdxl_embeds.py
@@ -255,8 +255,6 @@ def compute_embeddings_for_prompts(
             StateTracker.get_text_cache_files(data_backend_id=self.id).keys()
         )
         all_cache_filenames = [f"{self.create_hash(p)}.pt" for p in all_prompts]
-        self.debug_log(f"Existing cache filenames: {existing_cache_filenames}")
-        self.debug_log(f"All cache filenames: {all_cache_filenames}")
         # Check if we have all the files in the cache
         if (
             not is_validation

From 7b30a211475da55a0885fea2e1fe3d8351527832 Mon Sep 17 00:00:00 2001
From: bghira
Date: Mon, 15 Jan 2024 16:03:19 -0600
Subject: [PATCH 04/14] text embedding cache: write until we exhaust it

---
 helpers/caching/sdxl_embeds.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/helpers/caching/sdxl_embeds.py b/helpers/caching/sdxl_embeds.py
index edf3f52c..a42157d7 100644
--- a/helpers/caching/sdxl_embeds.py
+++ b/helpers/caching/sdxl_embeds.py
@@ -112,7 +112,7 @@ def batch_write_embeddings(self):
                 )
                 self.process_write_batch(batch)

-            if not self.process_write_batches:
+            if not self.process_write_batches and self.write_queue.empty():
                 # End the loop if we are done.
                 break
             time.sleep(1)  # Prevents the thread from being too busy-waiting

From dfb4f1554dce283442a9445161fb8800fc909526 Mon Sep 17 00:00:00 2001
From: bghira
Date: Mon, 15 Jan 2024 16:11:18 -0600
Subject: [PATCH 05/14] text embedding cache: write until we exhaust it, faster

---
 helpers/caching/sdxl_embeds.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/helpers/caching/sdxl_embeds.py b/helpers/caching/sdxl_embeds.py
index a42157d7..48ac016d 100644
--- a/helpers/caching/sdxl_embeds.py
+++ b/helpers/caching/sdxl_embeds.py
@@ -115,7 +115,7 @@ def batch_write_embeddings(self):
             if not self.process_write_batches and self.write_queue.empty():
                 # End the loop if we are done.
                 break
-            time.sleep(1)  # Prevents the thread from being too busy-waiting
+            time.sleep(0.01)  # Prevents the thread from being too busy-waiting

     def process_write_batch(self, batch):
         """Write a batch of embeddings to the cache."""

From b5457bc598a2216928767ed99147a32014c16b30 Mon Sep 17 00:00:00 2001
From: bghira
Date: Mon, 15 Jan 2024 16:12:18 -0600
Subject: [PATCH 06/14] text embedding cache: write until we exhaust it, faster

---
 helpers/caching/sdxl_embeds.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/helpers/caching/sdxl_embeds.py b/helpers/caching/sdxl_embeds.py
index 48ac016d..1ecbfd95 100644
--- a/helpers/caching/sdxl_embeds.py
+++ b/helpers/caching/sdxl_embeds.py
@@ -28,7 +28,7 @@ def __init__(
         cache_dir: str = "cache",
         model_type: str = "sdxl",
         prompt_handler: PromptHandler = None,
-        write_batch_size: int = 25,
+        write_batch_size: int = 128,
         read_batch_size: int = 25,
         process_queue_size: int = 16,
         text_encoder_batch_size: int = 4,

From 822619a298030c8affff762110e1ebb0c2c49a25 Mon Sep 17 00:00:00 2001
From: bghira
Date: Mon, 15 Jan 2024 16:14:39 -0600
Subject: [PATCH 07/14] text embedding cache: do not log the list

---
 helpers/caching/sdxl_embeds.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/helpers/caching/sdxl_embeds.py b/helpers/caching/sdxl_embeds.py
index 1ecbfd95..75388d60 100644
--- a/helpers/caching/sdxl_embeds.py
+++ b/helpers/caching/sdxl_embeds.py
@@ -80,9 +80,6 @@ def discover_all_files(self, directory: str = None):
             )
         )
         self.debug_log(" -> done listing all text embed cache entries")
-        self.debug_log(
-            f" -> {StateTracker.get_text_cache_files(data_backend_id=self.id)}"
-        )

     def save_to_cache(self, filename, embeddings):
         """Add write requests to the queue instead of writing directly."""

From 2e8da41f5b9222519d9b61090b987e247d5a4745 Mon Sep 17 00:00:00 2001
From: bghira
Date: Mon, 15 Jan 2024 16:24:31 -0600
Subject: [PATCH 08/14] text embedding cache: do not log as much, speed up further

---
 helpers/caching/sdxl_embeds.py | 18 ------------------
 1 file changed, 18 deletions(-)

diff --git a/helpers/caching/sdxl_embeds.py b/helpers/caching/sdxl_embeds.py
index 75388d60..bc554120 100644
--- a/helpers/caching/sdxl_embeds.py
+++ b/helpers/caching/sdxl_embeds.py
@@ -84,29 +84,17 @@ def discover_all_files(self, directory: str = None):
     def save_to_cache(self, filename, embeddings):
         """Add write requests to the queue instead of writing directly."""
         self.write_queue.put((embeddings, filename))
-        self.debug_log(
-            f"Pushing cache object into write queue. We have {self.write_queue.qsize()} items in the queue."
-        )

     def batch_write_embeddings(self):
         """Process write requests in batches."""
         while True:
             batch = []
             while not self.write_queue.empty() and len(batch) < self.write_batch_size:
-                self.debug_log(
-                    f"Adding to batch, currently at {len(batch)} embeds. Waiting for {self.write_batch_size} embeds before we process"
-                )
                 batch.append(self.write_queue.get())

             if len(batch) >= self.write_batch_size:
-                self.debug_log(
-                    f"Processing batch of {len(batch)} embeds, as we reached our threshold of {self.write_batch_size}"
-                )
                 self.process_write_batch(batch)
             elif self.write_queue.empty() and len(batch) > 0:
-                self.debug_log(
-                    f"Processing batch of {len(batch)} embeds, as the queue is empty."
-                )
                 self.process_write_batch(batch)

             if not self.process_write_batches and self.write_queue.empty():
@@ -116,9 +104,6 @@ def batch_write_embeddings(self):

     def process_write_batch(self, batch):
         """Write a batch of embeddings to the cache."""
-        self.debug_log(
-            f"Processing write batch of {len(batch)} embeds via process_write_batch"
-        )
         with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
             futures = [
                 executor.submit(self.data_backend.torch_save, *args) for args in batch
@@ -127,9 +112,7 @@ def process_write_batch(self, batch):
             future.result()  # Wait for all writes to complete

     def load_from_cache(self, filename):
-        logger.debug("Begin load from cache.")
         result = self.data_backend.torch_load(filename)
-        logger.debug("Completed load from cache.")
         return result

     def encode_legacy_prompt(self, text_encoder, tokenizer, prompt):
@@ -216,7 +199,6 @@ def encode_sdxl_prompts(
         pooled_prompt_embeds_all = []

         for prompt in prompts:
-            self.debug_log(f"Encoding prompt: {prompt}")
             prompt_embeds, pooled_prompt_embeds = self.encode_sdxl_prompt(
                 text_encoders, tokenizers, prompt, is_validation
             )

From 5d59aaf734f3aaff0ee3c93d3df411621d6de700 Mon Sep 17 00:00:00 2001
From: bghira
Date: Mon, 15 Jan 2024 16:33:17 -0600
Subject: [PATCH 09/14] Use unified method which checks for the existence of the prompts in the cache

---
 helpers/data_backend/factory.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py
index 1a57e149..d0599c55 100644
--- a/helpers/data_backend/factory.py
+++ b/helpers/data_backend/factory.py
@@ -159,9 +159,9 @@ def configure_multi_databackend(
         ):
             logger.info("Pre-computing null embedding for caption dropout")
             with accelerator.main_process_first():
-                init_backend[
-                    "text_embed_cache"
-                ].compute_embeddings_for_sdxl_prompts([""], return_concat=False)
+                init_backend["text_embed_cache"].compute_embeddings_for_prompts(
+                    [""], return_concat=False
+                )
             accelerator.wait_for_everyone()
         else:
             logger.warning(

From bfdf03b234318c604adb841da0e4dbabde4c66c7 Mon Sep 17 00:00:00 2001
From: bghira
Date: Mon, 15 Jan 2024 16:53:10 -0600
Subject: [PATCH 10/14] AWS: log when we check for files, since this is an expensive operation

---
 helpers/data_backend/aws.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/helpers/data_backend/aws.py b/helpers/data_backend/aws.py
index fdf72854..76ed7451 100644
--- a/helpers/data_backend/aws.py
+++ b/helpers/data_backend/aws.py
@@ -82,6 +82,7 @@ def __init__(
     def exists(self, s3_key) -> bool:
         """Determine whether a file exists in S3."""
         try:
+            logger.debug("Checking if file exists: {s3_key}")
             self.client.head_object(
                 Bucket=self.bucket_name, Key=self._convert_path_to_key(str(s3_key))
             )

From 982b080872e230e4f438512a044ea1c106a6a945 Mon Sep 17 00:00:00 2001
From: bghira
Date: Mon, 15 Jan 2024 18:23:04 -0600
Subject: [PATCH 11/14] text embedding cache: switch log level back to user-provided

---
 helpers/caching/sdxl_embeds.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/helpers/caching/sdxl_embeds.py b/helpers/caching/sdxl_embeds.py
index bc554120..311d5769 100644
--- a/helpers/caching/sdxl_embeds.py
+++ b/helpers/caching/sdxl_embeds.py
@@ -10,7 +10,6 @@

 logger = logging.getLogger("TextEmbeddingCache")
 logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))
-logger.setLevel("DEBUG")


 class TextEmbeddingCache:
@@ -297,7 +297,6 @@ def compute_embeddings_for_sdxl_prompts(
                 filename = os.path.join(
                     self.cache_dir, self.create_hash(prompt) + ".pt"
                 )
-                self.debug_log(f"Checking for cache file: {filename}")
                 if (
                     return_concat
                     and self.data_backend.exists(filename)

From 20fe88b57d04ba06cfab6ab5cd12ad81ceddb038 Mon Sep 17 00:00:00 2001
From: bghira
Date: Mon, 15 Jan 2024 21:29:19 -0600
Subject: [PATCH 12/14] multidatabackend: fix the probability implementation, and add ignore_epochs parameter to dataset

---
 helpers/caching/sdxl_embeds.py  |  6 +--
 helpers/data_backend/aws.py     |  2 +-
 helpers/data_backend/factory.py | 72 +++++++++++++++++++--------------
 helpers/data_backend/local.py   | 24 -----------
 multidatabackend.example.json   |  3 +-
 5 files changed, 45 insertions(+), 62 deletions(-)

diff --git a/helpers/caching/sdxl_embeds.py b/helpers/caching/sdxl_embeds.py
index 311d5769..48bc374b 100644
--- a/helpers/caching/sdxl_embeds.py
+++ b/helpers/caching/sdxl_embeds.py
@@ -297,11 +297,7 @@ def compute_embeddings_for_sdxl_prompts(
                 filename = os.path.join(
                     self.cache_dir, self.create_hash(prompt) + ".pt"
                 )
-                if (
-                    return_concat
-                    and self.data_backend.exists(filename)
-                    and load_from_cache
-                ):
+                if return_concat and load_from_cache:
                     prompt_embeds, add_text_embeds = self.load_from_cache(filename)
                 else:
                     self.debug_log(f"Encoding prompt: {prompt}")
diff --git a/helpers/data_backend/aws.py b/helpers/data_backend/aws.py
index 76ed7451..114f92e2 100644
--- a/helpers/data_backend/aws.py
+++ b/helpers/data_backend/aws.py
@@ -82,7 +82,7 @@ def __init__(
     def exists(self, s3_key) -> bool:
         """Determine whether a file exists in S3."""
         try:
-            logger.debug("Checking if file exists: {s3_key}")
+            logger.debug(f"Checking if file exists: {s3_key}")
             self.client.head_object(
                 Bucket=self.bucket_name, Key=self._convert_path_to_key(str(s3_key))
             )
diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py
index d0599c55..24b317dd 100644
--- a/helpers/data_backend/factory.py
+++ b/helpers/data_backend/factory.py
@@ -501,6 +501,39 @@ def get_aws_backend(
 step = None


+def select_dataloader_index(step, backends):
+    # Generate weights for each backend based on some criteria
+    weights = []
+    backend_ids = []
+    for backend_id, backend in backends.items():
+        weight = get_backend_weight(backend_id, backend, step)
+        weights.append(weight)
+        backend_ids.append(backend_id)
+
+    # Convert to a torch tensor for easy sampling
+    weights = torch.tensor(weights)
+    weights /= weights.sum()  # Normalize the weights
+
+    if weights.sum() == 0:
+        return None
+
+    # Sample a backend index based on the weights
+    chosen_index = torch.multinomial(weights, 1).item()
+    return backend_ids[chosen_index]
+
+
+def get_backend_weight(backend_id, backend, step):
+    # Implement your logic to determine the weight for each backend
+    # For example, a simple linear decay based on the step count
+    backend_config = StateTracker.get_data_backend_config(backend_id)
+    prob = backend_config.get("probability", 1)
+    disable_step = backend_config.get("disable_after_epoch_step", float("inf"))
+    adjusted_prob = (
+        0 if step > disable_step else max(0, prob * (1 - step / disable_step))
+    )
+    return adjusted_prob
+
+
 def random_dataloader_iterator(backends: dict):
     global step
     if step is None:
@@ -547,39 +580,16 @@ def random_dataloader_iterator(backends: dict):
             StateTracker.backend_exhausted(chosen_backend_id)
             StateTracker.set_repeats(data_backend_id=chosen_backend_id, repeats=0)
     finally:
-        if not backends:
+        if not backends or all(
+            [
+                StateTracker.get_data_backend_config(backend_id).get(
+                    "ignore_epochs", False
+                )
+                for backend_id in backends
+            ]
+        ):
             logger.debug(
                 "All dataloaders exhausted. Moving to next epoch in main training loop."
             )
             StateTracker.clear_exhausted_buckets()
             return (step, None)
-
-
-def select_dataloader_index(step, backends):
-    adjusted_probabilities = {}
-    logger.debug(f"Selecting from backends: {backends.keys()}")
-    for backend_id, dataloader in backends.items():
-        backend = StateTracker.get_data_backend(backend_id)
-        prob = backend["config"].get("probability", 1)
-        disable_step = backend["config"].get("disable_after_epoch_step", float("inf"))
-
-        adjusted_prob = (
-            0 if step > disable_step else max(0, prob * (1 - step / disable_step))
-        )
-        adjusted_probabilities[backend_id] = adjusted_prob
-
-    # Shuffle the backends
-    items = list(adjusted_probabilities.items())
-    random.shuffle(items)
-    total_prob = sum(prob for _, prob in items)
-    if total_prob == 0:
-        return None
-
-    rnd = random.uniform(0, total_prob)
-    cumulative_prob = 0
-    for backend_id, prob in items:  # Use shuffled order
-        cumulative_prob += prob
-        if rnd < cumulative_prob:
-            return backend_id
-
-    return None
diff --git a/helpers/data_backend/local.py b/helpers/data_backend/local.py
index 3b307871..9662731c 100644
--- a/helpers/data_backend/local.py
+++ b/helpers/data_backend/local.py
@@ -39,10 +39,6 @@ def write(self, filepath: str, data: Any) -> None:
                 f"Received an unknown data type to write to disk. Doing our best: {type(data)}"
             )
             file.write(data)
-        # Check if file exists:
-        if not self.exists(filepath):
-            raise Exception(f"Failed to write to {filepath}")
-        logger.debug(f"Completed write()")

     def delete(self, filepath):
         """Delete the specified file."""
@@ -170,28 +166,8 @@ def torch_save(self, data, original_location):
             # A file object was given. Use it.
             location = original_location
         torch.save(data, location)
-        # Check whether the file created:
-        if type(original_location) == str:
-            # A file path was given. Check it.
-            if not self.exists(original_location):
-                raise Exception(f"Failed to write to {original_location}")
-        elif hasattr(original_location, "name"):
-            # A file object was given. Check it.
-            if not self.exists(original_location.name):
-                raise Exception(f"Failed to write to {original_location.name}")
-        else:
-            import traceback
-
-            raise Exception(
-                f"Unknown error writing to {original_location}, traceback: {traceback.format_exc()}"
-            )

     def write_batch(self, filepaths: list, data_list: list) -> None:
         """Write a batch of data to the specified filepaths."""
-        logger.debug(f"Reached write_batch in LocalDataBackend.")
         for filepath, data in zip(filepaths, data_list):
             self.write(filepath, data)
-        # Check if file was written:
-        if not self.exists(filepath):
-            raise Exception(f"Failed to write to {filepath}")
-        logger.debug(f"Succesfully validated file creation: {filepath}")
diff --git a/multidatabackend.example.json b/multidatabackend.example.json
index 03938fea..8dc40eff 100644
--- a/multidatabackend.example.json
+++ b/multidatabackend.example.json
@@ -30,7 +30,8 @@
         "aws_data_prefix": "",
         "cache_dir_vae": "s3prefix/for/vaecache",
         "vae_cache_clear_each_epoch": true,
-        "repeats": 2
+        "repeats": 2,
+        "ignore_epochs": false
     },
     {
         "id": "an example backend for text embeds.",

From 77412a77a7329eb50ec161f10137ad61ab64b1f2 Mon Sep 17 00:00:00 2001
From: bghira
Date: Mon, 15 Jan 2024 21:39:41 -0600
Subject: [PATCH 13/14] multidatabackend: fix the probability implementation, and add ignore_epochs parameter to dataset (fix)

---
 helpers/data_backend/factory.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py
index 24b317dd..97954903 100644
--- a/helpers/data_backend/factory.py
+++ b/helpers/data_backend/factory.py
@@ -33,6 +33,8 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict:
         ]
     if "probability" in backend:
         output["config"]["probability"] = backend["probability"]
+    if "ignore_epochs" in backend:
+        output["config"]["ignore_epochs"] = backend["ignore_epochs"]
     if "repeats" in backend:
         output["config"]["repeats"] = backend["repeats"]
     if "crop" in backend:

From e0b1bb59d36f4e4424493a5f13beb9334f2a63be Mon Sep 17 00:00:00 2001
From: bghira
Date: Mon, 15 Jan 2024 21:54:50 -0600
Subject: [PATCH 14/14] multidatabackend: better docs and implementation for ignore_epochs

---
 documentation/DATALOADER.md | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/documentation/DATALOADER.md b/documentation/DATALOADER.md
index 2e8e8e40..cfa94442 100644
--- a/documentation/DATALOADER.md
+++ b/documentation/DATALOADER.md
@@ -107,4 +107,7 @@ Here is an example dataloader configuration file, as `multidatabackend.example.j
 - Specifies the number of times all samples in the dataset are seen during an epoch. Useful for giving more impact to smaller datasets or maximizing the usage of VAE cache objects.

 ### `vae_cache_clear_each_epoch`
-- When enabled, all VAE cache objects are deleted from the filesystem at the end of each dataset repeat cycle. This can be resource-intensive for large datasets.
\ No newline at end of file
+- When enabled, all VAE cache objects are deleted from the filesystem at the end of each dataset repeat cycle. This can be resource-intensive for large datasets.
+
+### `ignore_epochs`
+- When enabled, this dataset will not hold up the rest of the datasets from completing an epoch. This will inherently make the value for the current epoch inaccurate, as it reflects only the number of times any datasets *without* this flag have completed all of their repeats. The state of the ignored dataset isn't reset upon the next epoch, it is simply ignored. It will eventually run out of samples as a dataset typically does. At that time it will be removed from consideration until the next natural epoch completes.
\ No newline at end of file
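
For reference, the dataset-level options touched by this series (`probability`, `disable_after_epoch_step`, `repeats`, and the new `ignore_epochs`) all sit on a single entry in the `multidatabackend.example.json` list. A minimal sketch of such an entry is shown below; the `id` and the numeric values are illustrative only, and any keys not shown (backend type, data paths, cache directories, and so on) follow the pattern of the example file rather than a complete schema:

```json
{
  "id": "a small dataset that is sampled heavily early on and never gates epochs",
  "probability": 0.5,
  "disable_after_epoch_step": 10000,
  "repeats": 2,
  "ignore_epochs": true
}
```

With the weighting introduced in PATCH 12, this backend is drawn via `torch.multinomial` with weight `probability * (1 - step / disable_after_epoch_step)`, clamped to zero once `step` passes `disable_after_epoch_step`; and because `ignore_epochs` is set, exhausting this dataset early does not hold back the epoch boundary for the other datasets, as documented in DATALOADER.md above.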