collate_fn: multi-threaded retrieval of SDXL text embeds #265

Merged · 3 commits · Jan 15, 2024
5 changes: 4 additions & 1 deletion helpers/caching/sdxl_embeds.py
@@ -60,7 +60,10 @@ def save_to_cache(self, filename, embeddings):
self.data_backend.torch_save(embeddings, filename)

def load_from_cache(self, filename):
return self.data_backend.torch_load(filename)
logger.debug("Begin load from cache.")
result = self.data_backend.torch_load(filename)
logger.debug("Completed load from cache.")
return result

def encode_legacy_prompt(self, text_encoder, tokenizer, prompt):
input_tokens = tokenizer(
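The new debug statements bracket the blocking `torch_load` call, which makes slow cache reads visible in the logs. A minimal sketch (not part of this PR) of the same bracketing pattern with an explicit timer; the `data_backend` handle and logger name here are assumptions:

```python
import logging
import time

logger = logging.getLogger("sdxl_embeds")

def load_from_cache_timed(data_backend, filename):
    # Bracket the blocking load with debug logs, as the diff above does,
    # and additionally report how long the read took.
    start = time.monotonic()
    logger.debug("Begin load from cache.")
    result = data_backend.torch_load(filename)
    logger.debug("Completed load from cache in %.3fs.", time.monotonic() - start)
    return result
```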
10 changes: 8 additions & 2 deletions helpers/data_backend/factory.py
@@ -95,12 +95,19 @@ def configure_multi_databackend(

text_embed_backends = {}
default_text_embed_backend_id = None
logger.info("Initialise text embedding cache")
for backend in data_backend_config:
dataset_type = backend.get("dataset_type", None)
if dataset_type is None or dataset_type != "text_embeds":
# Skip configuration of image data backends. It is done earlier.
continue
if ("disabled" in backend and backend["disabled"]) or (
"disable" in backend and backend["disable"]
):
logger.info(
f"Skipping disabled data backend {backend['id']} in config file."
)
continue

logger.info(f'Configuring text embed backend: {backend["id"]}')
if backend.get("default", None):
if default_text_embed_backend_id is not None:
@@ -133,7 +140,6 @@ def configure_multi_databackend(
raise ValueError(f"Unknown data backend type: {backend['type']}")

# Generate a TextEmbeddingCache object
logger.info("Loading the text embed management object")
init_backend["text_embed_cache"] = TextEmbeddingCache(
id=init_backend["id"],
data_backend=init_backend["data_backend"],
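The new guard lets a text-embed backend be switched off from the dataloader config without deleting its entry. A sketch of what such an entry and check might look like, expressed as a Python dict; the exact config schema beyond the `id`, `dataset_type`, and `disabled`/`disable` keys is an assumption:

```python
# Illustrative config entry; only the keys actually checked in factory.py
# (id, dataset_type, disabled/disable) are taken from this PR.
backend_entry = {
    "id": "alt-text-embeds",
    "dataset_type": "text_embeds",
    "disabled": True,  # configure_multi_databackend() now skips this backend
}

# Equivalent truthiness check to the one added above:
if backend_entry.get("disabled") or backend_entry.get("disable"):
    print(f"Skipping disabled data backend {backend_entry['id']} in config file.")
```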
1 change: 1 addition & 0 deletions helpers/legacy/validation.py
@@ -260,6 +260,7 @@ def log_validations(
**extra_validation_kwargs,
).images
)
logger.debug(f"Completed generating image: {validation_prompt}")

for tracker in accelerator.trackers:
if tracker.name == "wandb":
76 changes: 49 additions & 27 deletions helpers/training/collate.py
@@ -4,6 +4,7 @@
from helpers.training.multi_process import rank_info
from helpers.image_manipulation.brightness import calculate_batch_luminance
from accelerate.logging import get_logger
from concurrent.futures import ThreadPoolExecutor

logger = logging.getLogger("collate_fn")
logger.setLevel(environ.get("SIMPLETUNER_COLLATE_LOG_LEVEL", "INFO"))
@@ -61,19 +62,6 @@ def compute_time_ids(
return add_time_ids


def extract_pixel_values(examples):
pixel_values = []
for example in examples:
image_data = example["image_data"]
pixel_values.append(
to_tensor(image_data).to(
memory_format=torch.contiguous_format,
dtype=StateTracker.get_vae_dtype(),
)
)
return pixel_values


def extract_filepaths(examples):
filepaths = []
for example in examples:
@@ -113,24 +101,58 @@ def compute_latents(filepaths, data_backend_id: str):
return torch.stack(latents)


def compute_single_embedding(caption, text_embed_cache, is_sdxl):
"""Worker function to compute embedding for a single caption."""
if is_sdxl:
(
prompt_embeds,
add_text_embeds,
) = text_embed_cache.compute_embeddings_for_sdxl_prompts([caption])
return (
prompt_embeds[0],
add_text_embeds[0],
) # Unpack the first (and only) element
else:
prompt_embeds = text_embed_cache.compute_embeddings_for_legacy_prompts(
[caption]
)
return prompt_embeds[0], None # Unpack and return None for the second element


def compute_prompt_embeddings(captions, text_embed_cache):
"""
Retrieve / compute text embeds in parallel.
Args:
captions: List of strings
text_embed_cache: TextEmbeddingCache instance

Returns:
prompt_embeds_all: Tensor of stacked prompt embeds, one row per caption
add_text_embeds_all: Tensor of stacked pooled embeds (SDXL only; otherwise None)
"""
debug_log(" -> get embed from cache")
if text_embed_cache.model_type == "sdxl":
(
prompt_embeds_all,
add_text_embeds_all,
) = text_embed_cache.compute_embeddings_for_sdxl_prompts(captions)
debug_log(" -> concat embeds")
is_sdxl = text_embed_cache.model_type == "sdxl"

# Use a thread pool to compute embeddings concurrently
with ThreadPoolExecutor() as executor:
embeddings = list(
executor.map(
compute_single_embedding,
captions,
[text_embed_cache] * len(captions),
[is_sdxl] * len(captions),
)
)

logger.debug(f"Got embeddings: {embeddings}")
if is_sdxl:
# Separate the tuples
prompt_embeds = [t[0] for t in embeddings]
add_text_embeds = [t[1] for t in embeddings]
return (torch.stack(prompt_embeds), torch.stack(add_text_embeds))
else:
debug_log(" -> concat embeds")
prompt_embeds_all = text_embed_cache.compute_embeddings_for_legacy_prompts(
captions
)[0]
prompt_embeds_all = torch.concat([prompt_embeds_all for _ in range(1)], dim=0)
prompt_embeds_all = torch.concat(embeddings, dim=0)
return prompt_embeds_all, None
prompt_embeds_all = torch.concat([prompt_embeds_all for _ in range(1)], dim=0)
add_text_embeds_all = torch.concat([add_text_embeds_all for _ in range(1)], dim=0)
return prompt_embeds_all, add_text_embeds_all


def gather_conditional_size_features(examples, latents, weight_dtype):
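The core of this PR fans each caption out to its own thread: `executor.map` is fed parallel lists so every worker call receives the same cache and model flag. An equivalent sketch using `functools.partial` instead of the repeated-argument lists — a design alternative for illustration, not what the PR ships:

```python
from concurrent.futures import ThreadPoolExecutor
from functools import partial

def compute_prompt_embeddings_alt(captions, text_embed_cache):
    is_sdxl = text_embed_cache.model_type == "sdxl"
    # Bind the shared arguments once instead of building
    # [text_embed_cache] * len(captions) and [is_sdxl] * len(captions).
    worker = partial(
        compute_single_embedding,
        text_embed_cache=text_embed_cache,
        is_sdxl=is_sdxl,
    )
    # Threads (rather than processes) fit here because each worker is
    # dominated by cache I/O or by torch ops that release the GIL.
    with ThreadPoolExecutor() as executor:
        return list(executor.map(worker, captions))
```

As in the diff, each returned tuple holds `(prompt_embed, pooled_embed)` for SDXL, and the second element is `None` for legacy models, matching the stacking logic above.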