Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

text embedding cache speed-up for slow backends (eg. S3 or spinning disks) #271

Merged
merged 14 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion documentation/DATALOADER.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,7 @@ Here is an example dataloader configuration file, as `multidatabackend.example.j
- Specifies the number of times all samples in the dataset are seen during an epoch. Useful for giving more impact to smaller datasets or maximizing the usage of VAE cache objects.

### `vae_cache_clear_each_epoch`
- When enabled, all VAE cache objects are deleted from the filesystem at the end of each dataset repeat cycle. This can be resource-intensive for large datasets.
- When enabled, all VAE cache objects are deleted from the filesystem at the end of each dataset repeat cycle. This can be resource-intensive for large datasets.

### `ignore_epochs`
- When enabled, this dataset will not hold up the rest of the datasets from completing an epoch. This inherently makes the value for the current epoch inaccurate, as it reflects only the number of times the datasets *without* this flag have completed all of their repeats. The state of the ignored dataset isn't reset when the next epoch begins; it is simply ignored. It will eventually run out of samples, as a dataset typically does, at which point it is removed from consideration until the next natural epoch completes.
44 changes: 9 additions & 35 deletions helpers/caching/sdxl_embeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

logger = logging.getLogger("TextEmbeddingCache")
logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))
logger.setLevel("DEBUG")


class TextEmbeddingCache:
Expand All @@ -28,7 +27,7 @@ def __init__(
cache_dir: str = "cache",
model_type: str = "sdxl",
prompt_handler: PromptHandler = None,
write_batch_size: int = 25,
write_batch_size: int = 128,
read_batch_size: int = 25,
process_queue_size: int = 16,
text_encoder_batch_size: int = 4,
Expand Down Expand Up @@ -80,48 +79,30 @@ def discover_all_files(self, directory: str = None):
)
)
self.debug_log(" -> done listing all text embed cache entries")
self.debug_log(
f" -> {StateTracker.get_text_cache_files(data_backend_id=self.id)}"
)

def save_to_cache(self, filename, embeddings):
    """Queue an embedding for a later batched write instead of saving it inline."""
    # Tuple order is (embeddings, filename): process_write_batch unpacks each
    # queued item straight into data_backend.torch_save(*args).
    pending_write = (embeddings, filename)
    self.write_queue.put(pending_write)
    queue_depth = self.write_queue.qsize()
    self.debug_log(
        f"Pushing cache object into write queue. We have {queue_depth} items in the queue."
    )

def batch_write_embeddings(self):
    """Drain the write queue in batches until asked to stop.

    This is the target of the background writer thread. Items are pulled
    from ``self.write_queue`` in groups of at most ``self.write_batch_size``
    and handed to ``self.process_write_batch``.

    The loop exits only once ``self.process_write_batches`` is falsy AND the
    queue is empty, so embeddings that were queued before shutdown was
    requested are still flushed rather than silently dropped.
    """
    while True:
        batch = []
        while not self.write_queue.empty() and len(batch) < self.write_batch_size:
            batch.append(self.write_queue.get())

        # Anything collected is either a full batch or everything that was
        # queued at the time. Flush it unconditionally: re-checking
        # queue.empty() here would discard a partial batch whenever a
        # producer enqueued another item in the meantime.
        if batch:
            self.process_write_batch(batch)

        if not self.process_write_batches and self.write_queue.empty():
            # End the loop only when we are done AND nothing is pending.
            break
        # Short pause so an idle writer does not busy-wait, without
        # throttling throughput the way a one-second sleep would.
        time.sleep(0.01)

def process_write_batch(self, batch):
"""Write a batch of embeddings to the cache."""
self.debug_log(
f"Processing write batch of {len(batch)} embeds via process_write_batch"
)
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = [
executor.submit(self.data_backend.torch_save, *args) for args in batch
Expand All @@ -130,9 +111,7 @@ def process_write_batch(self, batch):
future.result() # Wait for all writes to complete

def load_from_cache(self, filename):
    """Fetch a cached object for ``filename`` through the data backend's ``torch_load``."""
    logger.debug("Begin load from cache.")
    cached_embeds = self.data_backend.torch_load(filename)
    logger.debug("Completed load from cache.")
    return cached_embeds

def encode_legacy_prompt(self, text_encoder, tokenizer, prompt):
Expand Down Expand Up @@ -219,7 +198,6 @@ def encode_sdxl_prompts(
pooled_prompt_embeds_all = []

for prompt in prompts:
self.debug_log(f"Encoding prompt: {prompt}")
prompt_embeds, pooled_prompt_embeds = self.encode_sdxl_prompt(
text_encoders, tokenizers, prompt, is_validation
)
Expand All @@ -246,12 +224,15 @@ def compute_embeddings_for_prompts(
return_concat: bool = True,
is_validation: bool = False,
):
if not self.batch_write_thread.is_alive():
# Start the thread again.
self.process_write_batches = True
self.batch_write_thread = Thread(target=self.batch_write_embeddings)
self.batch_write_thread.start()
existing_cache_filenames = list(
StateTracker.get_text_cache_files(data_backend_id=self.id).keys()
)
all_cache_filenames = [f"{self.create_hash(p)}.pt" for p in all_prompts]
self.debug_log(f"Existing cache filenames: {existing_cache_filenames}")
self.debug_log(f"All cache filenames: {all_cache_filenames}")
# Check if we have all the files in the cache
if (
not is_validation
Expand Down Expand Up @@ -316,14 +297,7 @@ def compute_embeddings_for_sdxl_prompts(
filename = os.path.join(
self.cache_dir, self.create_hash(prompt) + ".pt"
)
self.debug_log(f"Checking for cache file: {filename}")
if (
self.data_backend.exists(filename)
and load_from_cache
and not return_concat
):
continue
if self.data_backend.exists(filename) and load_from_cache:
if return_concat and load_from_cache:
prompt_embeds, add_text_embeds = self.load_from_cache(filename)
else:
self.debug_log(f"Encoding prompt: {prompt}")
Expand Down
1 change: 1 addition & 0 deletions helpers/data_backend/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def __init__(
def exists(self, s3_key) -> bool:
"""Determine whether a file exists in S3."""
try:
logger.debug(f"Checking if file exists: {s3_key}")
self.client.head_object(
Bucket=self.bucket_name, Key=self._convert_path_to_key(str(s3_key))
)
Expand Down
80 changes: 46 additions & 34 deletions helpers/data_backend/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict:
]
if "probability" in backend:
output["config"]["probability"] = backend["probability"]
if "ignore_epochs" in backend:
output["config"]["ignore_epochs"] = backend["ignore_epochs"]
if "repeats" in backend:
output["config"]["repeats"] = backend["repeats"]
if "crop" in backend:
Expand Down Expand Up @@ -159,9 +161,9 @@ def configure_multi_databackend(
):
logger.info("Pre-computing null embedding for caption dropout")
with accelerator.main_process_first():
init_backend[
"text_embed_cache"
].compute_embeddings_for_sdxl_prompts([""], return_concat=False)
init_backend["text_embed_cache"].compute_embeddings_for_prompts(
[""], return_concat=False
)
accelerator.wait_for_everyone()
else:
logger.warning(
Expand Down Expand Up @@ -501,6 +503,39 @@ def get_aws_backend(
step = None


def select_dataloader_index(step, backends):
    """Randomly pick a backend id, weighted by each backend's current weight.

    Args:
        step: Current step counter, forwarded to ``get_backend_weight``.
        backends: Mapping of backend id -> backend object.

    Returns:
        The chosen backend id, or ``None`` when ``backends`` is empty or
        the weights sum to zero (nothing is eligible for sampling).
    """
    backend_ids = []
    weights = []
    for backend_id, backend in backends.items():
        backend_ids.append(backend_id)
        weights.append(get_backend_weight(backend_id, backend, step))

    # Force a floating dtype: a list made only of integer weights (e.g. all
    # zeros) would otherwise produce an int64 tensor, which cannot be divided
    # into probabilities or passed to torch.multinomial.
    weights = torch.tensor(weights, dtype=torch.float)
    total = weights.sum()

    # Check the total BEFORE normalizing. Dividing by a zero total yields NaN,
    # which never compares equal to 0 and makes torch.multinomial raise.
    if total <= 0:
        return None

    # Sample a backend index based on the normalized weights.
    chosen_index = torch.multinomial(weights / total, 1).item()
    return backend_ids[chosen_index]


def get_backend_weight(backend_id, backend, step):
    """Return the sampling weight of one backend at the given step.

    The configured ``probability`` (default 1) decays linearly towards zero
    as ``step`` approaches ``disable_after_epoch_step`` (default: infinity,
    i.e. no decay), and is zero from that step onwards.

    Args:
        backend_id: Id used to look up the backend's config via
            ``StateTracker.get_data_backend_config``.
        backend: Not used by this function; kept so the signature matches
            what ``select_dataloader_index`` passes in.
        step: Current step counter.

    Returns:
        A non-negative number; 0 means the backend must not be sampled.
    """
    backend_config = StateTracker.get_data_backend_config(backend_id)
    prob = backend_config.get("probability", 1)
    disable_step = backend_config.get("disable_after_epoch_step", float("inf"))

    # `>=` instead of `>`: at step == disable_step the linear decay is already
    # zero, so the result is unchanged for positive thresholds, and it avoids
    # a 0 / 0 ZeroDivisionError when disable_after_epoch_step is 0 at step 0.
    if step >= disable_step:
        return 0
    return max(0, prob * (1 - step / disable_step))


def random_dataloader_iterator(backends: dict):
global step
if step is None:
Expand Down Expand Up @@ -547,39 +582,16 @@ def random_dataloader_iterator(backends: dict):
StateTracker.backend_exhausted(chosen_backend_id)
StateTracker.set_repeats(data_backend_id=chosen_backend_id, repeats=0)
finally:
if not backends:
if not backends or all(
[
StateTracker.get_data_backend_config(backend_id).get(
"ignore_epochs", False
)
for backend_id in backends
]
):
logger.debug(
"All dataloaders exhausted. Moving to next epoch in main training loop."
)
StateTracker.clear_exhausted_buckets()
return (step, None)


def select_dataloader_index(step, backends):
    """Pick a backend id at random, weighted by its step-adjusted probability.

    Each backend's configured ``probability`` (default 1) decays linearly
    towards zero as ``step`` approaches ``disable_after_epoch_step``
    (default: infinity) and is zero once ``step`` exceeds it.

    Args:
        step: Current step counter.
        backends: Mapping of backend id -> dataloader; only the ids are used.

    Returns:
        The selected backend id, or ``None`` when every adjusted probability
        is zero (or ``backends`` is empty).
    """
    adjusted_probabilities = {}
    logger.debug(f"Selecting from backends: {backends.keys()}")
    for backend_id in backends:
        config = StateTracker.get_data_backend(backend_id)["config"]
        # The lookup key must be "probability"; a garbage key here silently
        # ignored the configured value and always fell back to the default.
        prob = config.get("probability", 1)
        disable_step = config.get("disable_after_epoch_step", float("inf"))
        adjusted_probabilities[backend_id] = (
            0 if step > disable_step else max(0, prob * (1 - step / disable_step))
        )

    # Shuffle the backends so iteration order does not bias the draw.
    items = list(adjusted_probabilities.items())
    random.shuffle(items)
    total_prob = sum(prob for _, prob in items)
    if total_prob == 0:
        return None

    rnd = random.uniform(0, total_prob)
    cumulative_prob = 0
    last_eligible = None
    for backend_id, prob in items:  # Use shuffled order
        cumulative_prob += prob
        if prob > 0:
            last_eligible = backend_id
        if rnd < cumulative_prob:
            return backend_id

    # random.uniform may return its upper endpoint, in which case no bucket
    # satisfies `rnd < cumulative_prob`; fall back to the last backend with a
    # non-zero weight instead of reporting that nothing is selectable.
    return last_eligible
31 changes: 1 addition & 30 deletions helpers/data_backend/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ def write(self, filepath: str, data: Any) -> None:
f"Received an unknown data type to write to disk. Doing our best: {type(data)}"
)
file.write(data)
# Check if file exists:
if not self.exists(filepath):
raise Exception(f"Failed to write to {filepath}")
logger.debug(f"Completed write()")

def delete(self, filepath):
"""Delete the specified file."""
Expand All @@ -58,7 +54,7 @@ def delete(self, filepath):
def exists(self, filepath):
    """Report whether ``filepath`` exists according to ``os.path.exists``."""
    found = os.path.exists(filepath)
    logger.debug(f"Checking if {filepath} exists = {found}")
    return found

def open_file(self, filepath, mode):
Expand Down Expand Up @@ -153,7 +149,6 @@ def read_image_batch(
def create_directory(self, directory_path):
    """Create ``directory_path`` and any missing parents; an existing directory is not an error."""
    logger.debug(f"Creating directory: {directory_path}")
    # exist_ok=True makes repeated calls for the same directory harmless.
    os.makedirs(directory_path, exist_ok=True)
    logger.debug("Created directory.")

def torch_load(self, filename):
# Check if file exists:
Expand All @@ -164,39 +159,15 @@ def torch_load(self, filename):
)

def torch_save(self, data, original_location):
    """Serialize ``data`` with ``torch.save`` to a path or an open file object.

    Args:
        data: Object to serialize.
        original_location: Either a ``str`` path, which is opened here via
            ``self.open_file(path, "wb")`` and closed again after saving, or
            a writable file-like object, which is used as-is and left open
            for the caller.

    Raises:
        Exception: If the target (the path, or a file object's ``name``)
            does not exist after saving. File objects without a ``name``
            attribute cannot be verified and are accepted as-is.
    """
    logger.debug("Calling torch_save on Local backend.")
    opened_here = isinstance(original_location, str)
    if opened_here:
        # A file path was given. Open it.
        logger.debug(f"Using file path: {original_location}")
        location = self.open_file(original_location, "wb")
    else:
        # A file object was given. Use it.
        logger.debug(f"Using file object: {original_location}")
        location = original_location
    # Log only the payload type: formatting the payload itself into the
    # message is expensive for large tensors, even when DEBUG is disabled.
    logger.debug(
        f"Torch location {original_location} save was given data of type {type(data).__name__}"
    )
    try:
        torch.save(data, location)
    finally:
        if opened_here:
            # Release the handle we opened so it is flushed and not leaked.
            # NOTE(review): assumes open_file returns a closable file object — confirm.
            location.close()

    # Check whether the file was created, when there is a path to check.
    written_path = (
        original_location if opened_here else getattr(original_location, "name", None)
    )
    if written_path is not None and not self.exists(written_path):
        raise Exception(f"Failed to write to {written_path}")

def write_batch(self, filepaths: list, data_list: list) -> None:
    """Write each item of ``data_list`` to the path at the same position in ``filepaths``.

    Every file is checked with ``self.exists`` right after it is written; the
    first missing file aborts the batch with an ``Exception``.
    """
    logger.debug("Reached write_batch in LocalDataBackend.")
    for target_path, payload in zip(filepaths, data_list):
        self.write(target_path, payload)
        # Confirm the file landed on disk before moving on to the next one.
        if self.exists(target_path):
            logger.debug(f"Succesfully validated file creation: {target_path}")
            continue
        raise Exception(f"Failed to write to {target_path}")
3 changes: 2 additions & 1 deletion multidatabackend.example.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
"aws_data_prefix": "",
"cache_dir_vae": "s3prefix/for/vaecache",
"vae_cache_clear_each_epoch": true,
"repeats": 2
"repeats": 2,
"ignore_epochs": false
},
{
"id": "an example backend for text embeds.",
Expand Down
Loading