From a63048b849d8d362d7ba5b34474263dd8493fe3a Mon Sep 17 00:00:00 2001
From: bghira
Date: Mon, 15 Jan 2024 12:25:24 -0600
Subject: [PATCH 1/3] text embeds: retrieve from backend in parallel during collate_fn

---
 helpers/caching/sdxl_embeds.py  |  5 ++-
 helpers/data_backend/factory.py | 10 ++++-
 helpers/legacy/validation.py    |  1 +
 helpers/training/collate.py     | 65 +++++++++++++++++++--------------
 4 files changed, 51 insertions(+), 30 deletions(-)

diff --git a/helpers/caching/sdxl_embeds.py b/helpers/caching/sdxl_embeds.py
index 08c99626..fbb374a6 100644
--- a/helpers/caching/sdxl_embeds.py
+++ b/helpers/caching/sdxl_embeds.py
@@ -60,7 +60,10 @@ def save_to_cache(self, filename, embeddings):
         self.data_backend.torch_save(embeddings, filename)
 
     def load_from_cache(self, filename):
-        return self.data_backend.torch_load(filename)
+        logger.debug("Begin load from cache.")
+        result = self.data_backend.torch_load(filename)
+        logger.debug("Completed load from cache.")
+        return result
 
     def encode_legacy_prompt(self, text_encoder, tokenizer, prompt):
         input_tokens = tokenizer(
diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py
index 007ed00f..c1adf4d2 100644
--- a/helpers/data_backend/factory.py
+++ b/helpers/data_backend/factory.py
@@ -95,12 +95,19 @@ def configure_multi_databackend(
     text_embed_backends = {}
     default_text_embed_backend_id = None
 
-    logger.info("Initialise text embedding cache")
     for backend in data_backend_config:
         dataset_type = backend.get("dataset_type", None)
         if dataset_type is None or dataset_type != "text_embeds":
             # Skip configuration of image data backends. It is done earlier.
             continue
+        if ("disabled" in backend and backend["disabled"]) or (
+            "disable" in backend and backend["disable"]
+        ):
+            logger.info(
+                f"Skipping disabled data backend {backend['id']} in config file."
+            )
+            continue
+
         logger.info(f'Configuring text embed backend: {backend["id"]}')
         if backend.get("default", None):
             if default_text_embed_backend_id is not None:
@@ -133,7 +140,6 @@ def configure_multi_databackend(
             raise ValueError(f"Unknown data backend type: {backend['type']}")
 
         # Generate a TextEmbeddingCache object
-        logger.info("Loading the text embed management object")
         init_backend["text_embed_cache"] = TextEmbeddingCache(
             id=init_backend["id"],
             data_backend=init_backend["data_backend"],
diff --git a/helpers/legacy/validation.py b/helpers/legacy/validation.py
index 156b38e8..3ae4b50a 100644
--- a/helpers/legacy/validation.py
+++ b/helpers/legacy/validation.py
@@ -260,6 +260,7 @@ def log_validations(
                     **extra_validation_kwargs,
                 ).images
             )
+            logger.debug(f"Completed generating image: {validation_prompt}")
 
             for tracker in accelerator.trackers:
                 if tracker.name == "wandb":
diff --git a/helpers/training/collate.py b/helpers/training/collate.py
index 29403d51..73dfe73d 100644
--- a/helpers/training/collate.py
+++ b/helpers/training/collate.py
@@ -4,6 +4,7 @@
 from helpers.training.multi_process import rank_info
 from helpers.image_manipulation.brightness import calculate_batch_luminance
 from accelerate.logging import get_logger
+from concurrent.futures import ThreadPoolExecutor
 
 logger = logging.getLogger("collate_fn")
 logger.setLevel(environ.get("SIMPLETUNER_COLLATE_LOG_LEVEL", "INFO"))
@@ -61,19 +62,6 @@ def compute_time_ids(
     return add_time_ids
 
 
-def extract_pixel_values(examples):
-    pixel_values = []
-    for example in examples:
-        image_data = example["image_data"]
-        pixel_values.append(
-            to_tensor(image_data).to(
-                memory_format=torch.contiguous_format,
-                dtype=StateTracker.get_vae_dtype(),
-            )
-        )
-    return pixel_values
-
-
 def extract_filepaths(examples):
     filepaths = []
     for example in examples:
@@ -113,24 +101,47 @@ def compute_latents(filepaths, data_backend_id: str):
     return torch.stack(latents)
 
 
+def compute_single_embedding(caption, text_embed_cache, is_sdxl):
+    """Worker function to compute embedding for a single caption."""
+    if is_sdxl:
+        return text_embed_cache.compute_embeddings_for_sdxl_prompts([caption])[0]
+    else:
+        return text_embed_cache.compute_embeddings_for_legacy_prompts([caption])[0]
+
+
 def compute_prompt_embeddings(captions, text_embed_cache):
+    """
+    Retrieve / compute text embeds in parallel.
+    Args:
+        captions: List of strings
+        text_embed_cache: TextEmbedCache instance
+
+    Returns:
+        prompt_embeds_all: Tensor of shape (batch_size, 512)
+        add_text_embeds_all: Tensor of shape (batch_size, 512)
+    """
     debug_log(" -> get embed from cache")
-    if text_embed_cache.model_type == "sdxl":
-        (
-            prompt_embeds_all,
-            add_text_embeds_all,
-        ) = text_embed_cache.compute_embeddings_for_sdxl_prompts(captions)
-        debug_log(" -> concat embeds")
+    is_sdxl = text_embed_cache.model_type == "sdxl"
+
+    # Use a thread pool to compute embeddings concurrently
+    with ThreadPoolExecutor() as executor:
+        embeddings = list(
+            executor.map(
+                compute_single_embedding,
+                captions,
+                [text_embed_cache] * len(captions),
+                [is_sdxl] * len(captions),
+            )
+        )
+
+    if is_sdxl:
+        prompt_embeds_all, add_text_embeds_all = zip(*embeddings)
+        prompt_embeds_all = torch.concat(list(prompt_embeds_all), dim=0)
+        add_text_embeds_all = torch.concat(list(add_text_embeds_all), dim=0)
+        return prompt_embeds_all, add_text_embeds_all
     else:
-        debug_log(" -> concat embeds")
-        prompt_embeds_all = text_embed_cache.compute_embeddings_for_legacy_prompts(
-            captions
-        )[0]
-        prompt_embeds_all = torch.concat([prompt_embeds_all for _ in range(1)], dim=0)
+        prompt_embeds_all = torch.concat(embeddings, dim=0)
         return prompt_embeds_all, None
-    prompt_embeds_all = torch.concat([prompt_embeds_all for _ in range(1)], dim=0)
-    add_text_embeds_all = torch.concat([add_text_embeds_all for _ in range(1)], dim=0)
-    return prompt_embeds_all, add_text_embeds_all
 
 
 def gather_conditional_size_features(examples, latents, weight_dtype):

From 075ce768b5e87d7299632957c6d6cc70bf321794 Mon Sep 17 00:00:00 2001
From: bghira
Date: Mon, 15 Jan 2024 13:04:43 -0600
Subject: [PATCH 2/3] more collate changes for testing

---
 helpers/training/collate.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/helpers/training/collate.py b/helpers/training/collate.py
index 73dfe73d..af5c9daa 100644
--- a/helpers/training/collate.py
+++ b/helpers/training/collate.py
@@ -104,9 +104,19 @@ def compute_latents(filepaths, data_backend_id: str):
 def compute_single_embedding(caption, text_embed_cache, is_sdxl):
     """Worker function to compute embedding for a single caption."""
     if is_sdxl:
-        return text_embed_cache.compute_embeddings_for_sdxl_prompts([caption])[0]
+        (
+            prompt_embeds,
+            add_text_embeds,
+        ) = text_embed_cache.compute_embeddings_for_sdxl_prompts([caption])
+        return (
+            prompt_embeds[0],
+            add_text_embeds[0],
+        )  # Unpack the first (and only) element
     else:
-        return text_embed_cache.compute_embeddings_for_legacy_prompts([caption])[0]
+        prompt_embeds = text_embed_cache.compute_embeddings_for_legacy_prompts(
+            [caption]
+        )
+        return prompt_embeds[0], None  # Unpack and return None for the second element
 
 
 def compute_prompt_embeddings(captions, text_embed_cache):
@@ -134,6 +144,7 @@ def compute_prompt_embeddings(captions, text_embed_cache):
             )
         )
 
+    logger.debug(f"Got embeddings: {embeddings}")
     if is_sdxl:
         prompt_embeds_all, add_text_embeds_all = zip(*embeddings)
         prompt_embeds_all = torch.concat(list(prompt_embeds_all), dim=0)

From 1e833d422de354ee7f496323f1188ec923dcdd59 Mon Sep 17 00:00:00 2001
From: bghira
Date: Mon, 15 Jan 2024 13:20:23 -0600
Subject: [PATCH 3/3] sdxl text embeds need to be stacked separately

---
 helpers/training/collate.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/helpers/training/collate.py b/helpers/training/collate.py
index af5c9daa..85c8a5db 100644
--- a/helpers/training/collate.py
+++ b/helpers/training/collate.py
@@ -146,10 +146,10 @@ def compute_prompt_embeddings(captions, text_embed_cache):
 
     logger.debug(f"Got embeddings: {embeddings}")
     if is_sdxl:
-        prompt_embeds_all, add_text_embeds_all = zip(*embeddings)
-        prompt_embeds_all = torch.concat(list(prompt_embeds_all), dim=0)
-        add_text_embeds_all = torch.concat(list(add_text_embeds_all), dim=0)
-        return prompt_embeds_all, add_text_embeds_all
+        # Separate the tuples
+        prompt_embeds = [t[0] for t in embeddings]
+        add_text_embeds = [t[1] for t in embeddings]
+        return (torch.stack(prompt_embeds), torch.stack(add_text_embeds))
     else:
        prompt_embeds_all = torch.concat(embeddings, dim=0)
        return prompt_embeds_all, None
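
Taken together, the three patches land on a compute_prompt_embeddings that fans
each caption out to a thread-pool worker and stacks the per-caption results.
Below is a minimal, self-contained sketch of that end state for the SDXL path.
FakeEmbedCache and its tensor shapes are hypothetical stand-ins for the real
TextEmbeddingCache, and the legacy path is omitted; this is an illustration of
the pattern, not part of the patches.

# Illustrative sketch only: FakeEmbedCache is a hypothetical stand-in for
# TextEmbeddingCache, and the tensor shapes are assumed for demonstration.
from concurrent.futures import ThreadPoolExecutor

import torch


class FakeEmbedCache:
    """Stand-in cache that returns SDXL-shaped random tensors."""

    model_type = "sdxl"

    def compute_embeddings_for_sdxl_prompts(self, prompts):
        # The real cache would load these from the configured data backend.
        prompt_embeds = [torch.randn(77, 2048) for _ in prompts]
        add_text_embeds = [torch.randn(1280) for _ in prompts]
        return prompt_embeds, add_text_embeds


def compute_single_embedding(caption, text_embed_cache, is_sdxl):
    # One caption per worker, mirroring PATCH 2/3's unpacking of the
    # single-element result lists.
    if is_sdxl:
        (
            prompt_embeds,
            add_text_embeds,
        ) = text_embed_cache.compute_embeddings_for_sdxl_prompts([caption])
        return prompt_embeds[0], add_text_embeds[0]
    raise NotImplementedError("legacy path omitted from this sketch")


def compute_prompt_embeddings(captions, text_embed_cache):
    is_sdxl = text_embed_cache.model_type == "sdxl"
    # executor.map preserves input order, so embeddings[i] belongs to
    # captions[i] even though the lookups run concurrently.
    with ThreadPoolExecutor() as executor:
        embeddings = list(
            executor.map(
                compute_single_embedding,
                captions,
                [text_embed_cache] * len(captions),
                [is_sdxl] * len(captions),
            )
        )
    # PATCH 3/3: stack each tuple field separately rather than zip + concat.
    prompt_embeds_all = torch.stack([t[0] for t in embeddings])
    add_text_embeds_all = torch.stack([t[1] for t in embeddings])
    return prompt_embeds_all, add_text_embeds_all


if __name__ == "__main__":
    embeds, pooled = compute_prompt_embeddings(
        ["a photo of a cat", "a photo of a dog"], FakeEmbedCache()
    )
    print(embeds.shape)  # torch.Size([2, 77, 2048])
    print(pooled.shape)  # torch.Size([2, 1280])

A thread pool, rather than a process pool, fits this workload because each
worker mostly waits on I/O against the embed cache's data backend, which is
the same latency PATCH 1/3's load_from_cache debug logging makes visible.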