
Commit 354979f

Merge pull request #265 from bghira/feature/multithreaded-collate-textembeds

Feature/multithreaded collate textembeds
bghira authored Jan 15, 2024
2 parents 8547e7d + 1e833d4 commit 354979f
Showing 4 changed files with 62 additions and 30 deletions.
5 changes: 4 additions & 1 deletion helpers/caching/sdxl_embeds.py
@@ -60,7 +60,10 @@ def save_to_cache(self, filename, embeddings):
         self.data_backend.torch_save(embeddings, filename)

     def load_from_cache(self, filename):
-        return self.data_backend.torch_load(filename)
+        logger.debug("Begin load from cache.")
+        result = self.data_backend.torch_load(filename)
+        logger.debug("Completed load from cache.")
+        return result

     def encode_legacy_prompt(self, text_encoder, tokenizer, prompt):
         input_tokens = tokenizer(
10 changes: 8 additions & 2 deletions helpers/data_backend/factory.py
@@ -95,12 +95,19 @@ def configure_multi_databackend(

     text_embed_backends = {}
     default_text_embed_backend_id = None
+    logger.info("Initialise text embedding cache")
     for backend in data_backend_config:
         dataset_type = backend.get("dataset_type", None)
         if dataset_type is None or dataset_type != "text_embeds":
             # Skip configuration of image data backends. It is done earlier.
             continue
+        if ("disabled" in backend and backend["disabled"]) or (
+            "disable" in backend and backend["disable"]
+        ):
+            logger.info(
+                f"Skipping disabled data backend {backend['id']} in config file."
+            )
+            continue

         logger.info(f'Configuring text embed backend: {backend["id"]}')
         if backend.get("default", None):
             if default_text_embed_backend_id is not None:
@@ -133,7 +140,6 @@ def configure_multi_databackend(
             raise ValueError(f"Unknown data backend type: {backend['type']}")

         # Generate a TextEmbeddingCache object
-        logger.info("Loading the text embed management object")
         init_backend["text_embed_cache"] = TextEmbeddingCache(
             id=init_backend["id"],
             data_backend=init_backend["data_backend"],
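
For reference, below is a minimal sketch of what a text_embeds entry in the data backend config might look like. Only the keys the loop above actually checks ("id", "dataset_type", "type", "default" and "disabled"/"disable") come from the diff; the values shown are hypothetical.

example_text_embed_backend = {
    "id": "text-embeds-example",    # hypothetical identifier
    "dataset_type": "text_embeds",  # required, or this loop skips the entry
    "type": "local",                # assumed backend type; must be one the factory recognises
    "default": True,                # marks this backend as the default text embed backend
    "disabled": False,              # set True (or use "disable") and the new check skips it
}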
1 change: 1 addition & 0 deletions helpers/legacy/validation.py
@@ -260,6 +260,7 @@ def log_validations(
                         **extra_validation_kwargs,
                     ).images
                 )
+                logger.debug(f"Completed generating image: {validation_prompt}")

                 for tracker in accelerator.trackers:
                     if tracker.name == "wandb":
76 changes: 49 additions & 27 deletions helpers/training/collate.py
@@ -4,6 +4,7 @@
 from helpers.training.multi_process import rank_info
 from helpers.image_manipulation.brightness import calculate_batch_luminance
 from accelerate.logging import get_logger
+from concurrent.futures import ThreadPoolExecutor

 logger = logging.getLogger("collate_fn")
 logger.setLevel(environ.get("SIMPLETUNER_COLLATE_LOG_LEVEL", "INFO"))
@@ -61,19 +62,6 @@ def compute_time_ids(
     return add_time_ids


-def extract_pixel_values(examples):
-    pixel_values = []
-    for example in examples:
-        image_data = example["image_data"]
-        pixel_values.append(
-            to_tensor(image_data).to(
-                memory_format=torch.contiguous_format,
-                dtype=StateTracker.get_vae_dtype(),
-            )
-        )
-    return pixel_values
-
-
 def extract_filepaths(examples):
     filepaths = []
     for example in examples:
@@ -113,24 +101,58 @@ def compute_latents(filepaths, data_backend_id: str):
     return torch.stack(latents)


+def compute_single_embedding(caption, text_embed_cache, is_sdxl):
+    """Worker function to compute embedding for a single caption."""
+    if is_sdxl:
+        (
+            prompt_embeds,
+            add_text_embeds,
+        ) = text_embed_cache.compute_embeddings_for_sdxl_prompts([caption])
+        return (
+            prompt_embeds[0],
+            add_text_embeds[0],
+        )  # Unpack the first (and only) element
+    else:
+        prompt_embeds = text_embed_cache.compute_embeddings_for_legacy_prompts(
+            [caption]
+        )
+        return prompt_embeds[0], None  # Unpack and return None for the second element
+
+
 def compute_prompt_embeddings(captions, text_embed_cache):
+    """
+    Retrieve / compute text embeds in parallel.
+    Args:
+        captions: List of strings
+        text_embed_cache: TextEmbedCache instance
+    Returns:
+        prompt_embeds_all: Tensor of shape (batch_size, 512)
+        add_text_embeds_all: Tensor of shape (batch_size, 512)
+    """
     debug_log(" -> get embed from cache")
-    if text_embed_cache.model_type == "sdxl":
-        (
-            prompt_embeds_all,
-            add_text_embeds_all,
-        ) = text_embed_cache.compute_embeddings_for_sdxl_prompts(captions)
-        debug_log(" -> concat embeds")
+    is_sdxl = text_embed_cache.model_type == "sdxl"
+
+    # Use a thread pool to compute embeddings concurrently
+    with ThreadPoolExecutor() as executor:
+        embeddings = list(
+            executor.map(
+                compute_single_embedding,
+                captions,
+                [text_embed_cache] * len(captions),
+                [is_sdxl] * len(captions),
+            )
+        )
+
+    logger.debug(f"Got embeddings: {embeddings}")
+    if is_sdxl:
+        # Separate the tuples
+        prompt_embeds = [t[0] for t in embeddings]
+        add_text_embeds = [t[1] for t in embeddings]
+        return (torch.stack(prompt_embeds), torch.stack(add_text_embeds))
     else:
-        debug_log(" -> concat embeds")
-        prompt_embeds_all = text_embed_cache.compute_embeddings_for_legacy_prompts(
-            captions
-        )[0]
-        prompt_embeds_all = torch.concat([prompt_embeds_all for _ in range(1)], dim=0)
+        prompt_embeds_all = torch.concat(embeddings, dim=0)
         return prompt_embeds_all, None
-    prompt_embeds_all = torch.concat([prompt_embeds_all for _ in range(1)], dim=0)
-    add_text_embeds_all = torch.concat([add_text_embeds_all for _ in range(1)], dim=0)
-    return prompt_embeds_all, add_text_embeds_all


 def gather_conditional_size_features(examples, latents, weight_dtype):
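
For illustration, here is a minimal, self-contained sketch of the fan-out pattern the new compute_prompt_embeddings uses: executor.map walks the caption list while the repeated lists hand the shared cache and flag to every call. The stub embedding function and the tensor shapes are hypothetical stand-ins, not SimpleTuner APIs. Threads (rather than processes) suit this workload because most of the per-caption work is loading cached embeddings from storage, which releases the GIL.

from concurrent.futures import ThreadPoolExecutor

import torch


def stub_embed(caption, cache, is_sdxl):
    # Stand-in for compute_single_embedding(); a real call would read from the text embed cache.
    prompt_embed = torch.zeros(77, 2048)  # illustrative shape only
    pooled_embed = torch.zeros(1280) if is_sdxl else None
    return prompt_embed, pooled_embed


captions = ["a photo of a cat", "a photo of a dog"]
is_sdxl = True

with ThreadPoolExecutor() as executor:
    embeddings = list(
        executor.map(
            stub_embed,
            captions,
            [None] * len(captions),     # shared text_embed_cache placeholder
            [is_sdxl] * len(captions),  # same flag repeated once per caption
        )
    )

# Same unpacking as the SDXL branch above: stack per-caption results into batch tensors.
prompt_embeds = torch.stack([e[0] for e in embeddings])
add_text_embeds = torch.stack([e[1] for e in embeddings])
print(prompt_embeds.shape, add_text_embeds.shape)  # torch.Size([2, 77, 2048]) torch.Size([2, 1280])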
