diff --git a/documentation/DATALOADER.md b/documentation/DATALOADER.md
index 2e8e8e40..cfa94442 100644
--- a/documentation/DATALOADER.md
+++ b/documentation/DATALOADER.md
@@ -107,4 +107,16 @@ Here is an example dataloader configuration file, as `multidatabackend.example.j
 - Specifies the number of times all samples in the dataset are seen during an epoch. Useful for giving more impact to smaller datasets or maximizing the usage of VAE cache objects.
 
 ### `vae_cache_clear_each_epoch`
-- When enabled, all VAE cache objects are deleted from the filesystem at the end of each dataset repeat cycle. This can be resource-intensive for large datasets.
\ No newline at end of file
+- When enabled, all VAE cache objects are deleted from the filesystem at the end of each dataset repeat cycle. This can be resource-intensive for large datasets.
+
+### `ignore_epochs`
+- When enabled, this dataset will not block the other datasets from completing an epoch. This inherently makes the reported epoch count inaccurate, as it then reflects only the number of times the datasets *without* this flag have completed all of their repeats. The ignored dataset's state is not reset at the start of the next epoch; it is simply ignored. It will eventually run out of samples, as any dataset does, and at that point it is removed from consideration until the next natural epoch completes.
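+- For example, a very large auxiliary dataset can be flagged this way so that the smaller datasets are able to finish their epochs without waiting for it. A minimal, illustrative backend entry (the `id` value here is invented; the other keys are documented above) might look like:
+
+```json
+{
+    "id": "supplementary-data",
+    "repeats": 2,
+    "ignore_epochs": true
+}
+```
\ No newline at end of file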
diff --git a/helpers/caching/sdxl_embeds.py b/helpers/caching/sdxl_embeds.py
index 19ff8854..48bc374b 100644
--- a/helpers/caching/sdxl_embeds.py
+++ b/helpers/caching/sdxl_embeds.py
@@ -10,7 +10,6 @@
 
 logger = logging.getLogger("TextEmbeddingCache")
 logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))
-logger.setLevel("DEBUG")
 
 
 class TextEmbeddingCache:
@@ -28,7 +27,7 @@ def __init__(
         cache_dir: str = "cache",
         model_type: str = "sdxl",
         prompt_handler: PromptHandler = None,
-        write_batch_size: int = 25,
+        write_batch_size: int = 128,
         read_batch_size: int = 25,
         process_queue_size: int = 16,
         text_encoder_batch_size: int = 4,
@@ -80,48 +79,30 @@ def discover_all_files(self, directory: str = None):
             )
         )
         self.debug_log(" -> done listing all text embed cache entries")
-        self.debug_log(
-            f" -> {StateTracker.get_text_cache_files(data_backend_id=self.id)}"
-        )
 
     def save_to_cache(self, filename, embeddings):
         """Add write requests to the queue instead of writing directly."""
         self.write_queue.put((embeddings, filename))
-        self.debug_log(
-            f"Pushing cache object into write queue. We have {self.write_queue.qsize()} items in the queue."
-        )
 
     def batch_write_embeddings(self):
         """Process write requests in batches."""
         while True:
             batch = []
             while not self.write_queue.empty() and len(batch) < self.write_batch_size:
-                self.debug_log(
-                    f"Adding to batch, currently at {len(batch)} embeds. Waiting for {self.write_batch_size} embeds before we process"
-                )
                 batch.append(self.write_queue.get())
 
             if len(batch) >= self.write_batch_size:
-                self.debug_log(
-                    f"Processing batch of {len(batch)} embeds, as we reached our threshold of {self.write_batch_size}"
-                )
                 self.process_write_batch(batch)
             elif self.write_queue.empty() and len(batch) > 0:
-                self.debug_log(
-                    f"Processing batch of {len(batch)} embeds, as the queue is empty."
-                )
                 self.process_write_batch(batch)
 
-            if not self.process_write_batches:
+            if not self.process_write_batches and self.write_queue.empty():
                 # End the loop if we are done.
                 break
-            time.sleep(1)  # Prevents the thread from being too busy-waiting
+            time.sleep(0.01)  # Prevents the thread from busy-waiting
 
     def process_write_batch(self, batch):
         """Write a batch of embeddings to the cache."""
-        self.debug_log(
-            f"Processing write batch of {len(batch)} embeds via process_write_batch"
-        )
         with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
             futures = [
                 executor.submit(self.data_backend.torch_save, *args) for args in batch
@@ -130,9 +111,7 @@ def process_write_batch(self, batch):
             future.result()  # Wait for all writes to complete
 
     def load_from_cache(self, filename):
-        logger.debug("Begin load from cache.")
         result = self.data_backend.torch_load(filename)
-        logger.debug("Completed load from cache.")
         return result
 
     def encode_legacy_prompt(self, text_encoder, tokenizer, prompt):
@@ -219,7 +198,6 @@ def encode_sdxl_prompts(
         pooled_prompt_embeds_all = []
 
         for prompt in prompts:
-            self.debug_log(f"Encoding prompt: {prompt}")
             prompt_embeds, pooled_prompt_embeds = self.encode_sdxl_prompt(
                 text_encoders, tokenizers, prompt, is_validation
             )
@@ -246,12 +224,15 @@ def compute_embeddings_for_prompts(
         return_concat: bool = True,
         is_validation: bool = False,
     ):
+        if not self.batch_write_thread.is_alive():
+            # Start the writer thread again if it has exited.
+            self.process_write_batches = True
+            self.batch_write_thread = Thread(target=self.batch_write_embeddings)
+            self.batch_write_thread.start()
         existing_cache_filenames = list(
             StateTracker.get_text_cache_files(data_backend_id=self.id).keys()
         )
         all_cache_filenames = [f"{self.create_hash(p)}.pt" for p in all_prompts]
-        self.debug_log(f"Existing cache filenames: {existing_cache_filenames}")
-        self.debug_log(f"All cache filenames: {all_cache_filenames}")
         # Check if we have all the files in the cache
         if (
             not is_validation
@@ -316,14 +297,7 @@ def compute_embeddings_for_sdxl_prompts(
                 filename = os.path.join(
                     self.cache_dir, self.create_hash(prompt) + ".pt"
                 )
-                self.debug_log(f"Checking for cache file: {filename}")
-                if (
-                    self.data_backend.exists(filename)
-                    and load_from_cache
-                    and not return_concat
-                ):
-                    continue
-                if self.data_backend.exists(filename) and load_from_cache:
+                if return_concat and load_from_cache:
                     prompt_embeds, add_text_embeds = self.load_from_cache(filename)
                 else:
                     self.debug_log(f"Encoding prompt: {prompt}")
diff --git a/helpers/data_backend/aws.py b/helpers/data_backend/aws.py
index fdf72854..114f92e2 100644
--- a/helpers/data_backend/aws.py
+++ b/helpers/data_backend/aws.py
@@ -82,6 +82,7 @@ def __init__(
     def exists(self, s3_key) -> bool:
         """Determine whether a file exists in S3."""
         try:
+            logger.debug(f"Checking if file exists: {s3_key}")
             self.client.head_object(
                 Bucket=self.bucket_name, Key=self._convert_path_to_key(str(s3_key))
             )
diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py
index 1a57e149..97954903 100644
--- a/helpers/data_backend/factory.py
+++ b/helpers/data_backend/factory.py
@@ -33,6 +33,8 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict:
         ]
     if "probability" in backend:
         output["config"]["probability"] = backend["probability"]
+    if "ignore_epochs" in backend:
+        output["config"]["ignore_epochs"] = backend["ignore_epochs"]
     if "repeats" in backend:
         output["config"]["repeats"] = backend["repeats"]
     if "crop" in backend:
@@ -159,9 +161,9 @@ def configure_multi_databackend(
     ):
         logger.info("Pre-computing null embedding for caption dropout")
         with accelerator.main_process_first():
-            init_backend[
-                "text_embed_cache"
-            ].compute_embeddings_for_sdxl_prompts([""], return_concat=False)
+            init_backend["text_embed_cache"].compute_embeddings_for_prompts(
+                [""], return_concat=False
+            )
         accelerator.wait_for_everyone()
     else:
         logger.warning(
@@ -501,6 +503,39 @@ def get_aws_backend(
 step = None
 
 
+def select_dataloader_index(step, backends):
+    # Build a sampling weight for each backend from its configuration.
+    weights = []
+    backend_ids = []
+    for backend_id, backend in backends.items():
+        weight = get_backend_weight(backend_id, backend, step)
+        weights.append(weight)
+        backend_ids.append(backend_id)
+
+    # Convert to a torch tensor for easy sampling
+    weights = torch.tensor(weights)
+    if weights.sum() == 0:
+        # No backend is currently eligible for sampling.
+        return None
+    weights /= weights.sum()  # Normalize the weights
+
+    # Sample a backend index based on the weights
+    chosen_index = torch.multinomial(weights, 1).item()
+    return backend_ids[chosen_index]
+
+
+def get_backend_weight(backend_id, backend, step):
+    # Weight each backend by its configured probability, applying a linear
+    # decay to zero as the step count approaches `disable_after_epoch_step`.
+    backend_config = StateTracker.get_data_backend_config(backend_id)
+    prob = backend_config.get("probability", 1)
+    disable_step = backend_config.get("disable_after_epoch_step", float("inf"))
+    adjusted_prob = (
+        0 if step > disable_step else max(0, prob * (1 - step / disable_step))
+    )
+    return adjusted_prob
+
+
 def random_dataloader_iterator(backends: dict):
     global step
     if step is None:
@@ -547,39 +582,16 @@ def random_dataloader_iterator(backends: dict):
             StateTracker.backend_exhausted(chosen_backend_id)
             StateTracker.set_repeats(data_backend_id=chosen_backend_id, repeats=0)
         finally:
-            if not backends:
+            if not backends or all(
+                [
+                    StateTracker.get_data_backend_config(backend_id).get(
+                        "ignore_epochs", False
+                    )
+                    for backend_id in backends
+                ]
+            ):
                 logger.debug(
                     "All dataloaders exhausted. Moving to next epoch in main training loop."
                 )
                 StateTracker.clear_exhausted_buckets()
                 return (step, None)
-
-
-def select_dataloader_index(step, backends):
-    adjusted_probabilities = {}
-    logger.debug(f"Selecting from backends: {backends.keys()}")
-    for backend_id, dataloader in backends.items():
-        backend = StateTracker.get_data_backend(backend_id)
-        prob = backend["config"].get("probability", 1)
-        disable_step = backend["config"].get("disable_after_epoch_step", float("inf"))
-
-        adjusted_prob = (
-            0 if step > disable_step else max(0, prob * (1 - step / disable_step))
-        )
-        adjusted_probabilities[backend_id] = adjusted_prob
-
-    # Shuffle the backends
-    items = list(adjusted_probabilities.items())
-    random.shuffle(items)
-    total_prob = sum(prob for _, prob in items)
-    if total_prob == 0:
-        return None
-
-    rnd = random.uniform(0, total_prob)
-    cumulative_prob = 0
-    for backend_id, prob in items:  # Use shuffled order
-        cumulative_prob += prob
-        if rnd < cumulative_prob:
-            return backend_id
-
-    return None
diff --git a/helpers/data_backend/local.py b/helpers/data_backend/local.py
index d3bfdbee..9662731c 100644
--- a/helpers/data_backend/local.py
+++ b/helpers/data_backend/local.py
@@ -39,10 +39,6 @@ def write(self, filepath: str, data: Any) -> None:
                 f"Received an unknown data type to write to disk. Doing our best: {type(data)}"
             )
             file.write(data)
-        # Check if file exists:
-        if not self.exists(filepath):
-            raise Exception(f"Failed to write to {filepath}")
-        logger.debug(f"Completed write()")
 
     def delete(self, filepath):
         """Delete the specified file."""
@@ -58,7 +54,7 @@ def delete(self, filepath):
     def exists(self, filepath):
         """Check if the file exists."""
         result = os.path.exists(filepath)
-        # logger.debug(f"Checking if {filepath} exists = {result}")
+        logger.debug(f"Checking if {filepath} exists = {result}")
         return result
 
     def open_file(self, filepath, mode):
@@ -153,7 +149,6 @@ def read_image_batch(
     def create_directory(self, directory_path):
         logger.debug(f"Creating directory: {directory_path}")
         os.makedirs(directory_path, exist_ok=True)
-        logger.debug(f"Created directory.")
 
     def torch_load(self, filename):
         # Check if file exists:
@@ -164,39 +159,15 @@ def torch_load(self, filename):
         )
 
     def torch_save(self, data, original_location):
-        logger.debug("Calling torch_save on Local backend.")
         if type(original_location) == str:
             # A file path was given. Open it.
-            logger.debug(f"Using file path: {original_location}")
             location = self.open_file(original_location, "wb")
         else:
             # A file object was given. Use it.
-            logger.debug(f"Using file object: {original_location}")
             location = original_location
-        logger.debug(f"Torch location {original_location} save was given data: {data}")
         torch.save(data, location)
-        # Check whether the file created:
-        if type(original_location) == str:
-            # A file path was given. Check it.
-            if not self.exists(original_location):
-                raise Exception(f"Failed to write to {original_location}")
-        elif hasattr(original_location, "name"):
-            # A file object was given. Check it.
-            if not self.exists(original_location.name):
-                raise Exception(f"Failed to write to {original_location.name}")
-        else:
-            import traceback
-
-            raise Exception(
-                f"Unknown error writing to {original_location}, traceback: {traceback.format_exc()}"
-            )
 
     def write_batch(self, filepaths: list, data_list: list) -> None:
         """Write a batch of data to the specified filepaths."""
-        logger.debug(f"Reached write_batch in LocalDataBackend.")
         for filepath, data in zip(filepaths, data_list):
             self.write(filepath, data)
-            # Check if file was written:
-            if not self.exists(filepath):
-                raise Exception(f"Failed to write to {filepath}")
-            logger.debug(f"Succesfully validated file creation: {filepath}")
diff --git a/multidatabackend.example.json b/multidatabackend.example.json
index 03938fea..8dc40eff 100644
--- a/multidatabackend.example.json
+++ b/multidatabackend.example.json
@@ -30,7 +30,8 @@
         "aws_data_prefix": "",
         "cache_dir_vae": "s3prefix/for/vaecache",
         "vae_cache_clear_each_epoch": true,
-        "repeats": 2
+        "repeats": 2,
+        "ignore_epochs": false
     },
     {
         "id": "an example backend for text embeds.",
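
A note on the `batch_write_embeddings` change above: the loop now exits only once shutdown has been requested *and* the queue has drained, so queued embeddings are flushed rather than dropped on exit. The sketch below is a minimal, self-contained model of that lifecycle, not SimpleTuner's actual class; the `BatchWriter` name and the print-based writer are invented for illustration.

```python
import queue
import time
from threading import Thread


class BatchWriter:
    """Toy model of the queue-draining writer thread (illustrative only)."""

    def __init__(self, write_batch_size: int = 128):
        self.write_batch_size = write_batch_size
        self.write_queue = queue.Queue()
        self.process_write_batches = True
        self.thread = Thread(target=self._batch_write_loop)
        self.thread.start()

    def _batch_write_loop(self):
        while True:
            batch = []
            # Drain up to write_batch_size items from the queue.
            while not self.write_queue.empty() and len(batch) < self.write_batch_size:
                batch.append(self.write_queue.get())
            if batch:
                self._process(batch)
            # Exit only once shutdown was requested AND the queue is empty --
            # the flush guarantee added by `and self.write_queue.empty()`.
            if not self.process_write_batches and self.write_queue.empty():
                break
            time.sleep(0.01)  # Short sleep keeps the loop from busy-waiting.

    def _process(self, batch):
        for item in batch:
            print(f"writing {item}")

    def stop(self):
        self.process_write_batches = False
        self.thread.join()


writer = BatchWriter(write_batch_size=4)
for i in range(10):
    writer.write_queue.put(f"embed-{i}.pt")
writer.stop()  # Blocks until all ten queued writes have been flushed.
```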
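
For the new backend-selection helpers in `factory.py`, the weight formula is `probability * (1 - step / disable_after_epoch_step)`, clamped at zero and zeroed entirely once the step passes `disable_after_epoch_step`. Here is a quick, self-contained numeric check of that decay and the multinomial draw; the backend configs are made up for illustration.

```python
import torch

configs = {
    "primary": {"probability": 1.0},  # no disable step: never decays
    "aux": {"probability": 0.5, "disable_after_epoch_step": 100},  # decays to 0 by step 100
}


def backend_weight(config: dict, step: int) -> float:
    prob = config.get("probability", 1)
    disable_step = config.get("disable_after_epoch_step", float("inf"))
    # Linear decay toward zero; fully disabled past disable_after_epoch_step.
    return 0 if step > disable_step else max(0, prob * (1 - step / disable_step))


for step in (0, 50, 100, 150):
    weights = torch.tensor([backend_weight(c, step) for c in configs.values()])
    if weights.sum() == 0:
        print(f"step {step}: no backend eligible")
        continue
    weights /= weights.sum()
    chosen = torch.multinomial(weights, 1).item()
    print(f"step {step}: weights={weights.tolist()} -> chose {list(configs)[chosen]}")
```

At step 0 the weights normalize to roughly [0.67, 0.33]; by step 100 the `aux` backend's weight has decayed to zero and only `primary` can be drawn.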