Skip to content

Commit

Permalink
Merge pull request #607 from bghira/feature/config-file-versioning
Browse files Browse the repository at this point in the history
Config-file versioning, allowing defaults to be updated without breaking backwards compatibility
  • Loading branch information
bghira authored Jul 29, 2024
2 parents 11074cf + f0b400d commit 0653855
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 45 deletions.
5 changes: 0 additions & 5 deletions helpers/caching/text_embeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,11 +1020,6 @@ def compute_embeddings_for_sd3_prompts(

return prompt_embeds_all, add_text_embeds_all

def split_cache_between_processes(self, prompts: list):
    """Distribute the prompt list across accelerator processes.

    Each process receives its own shard of ``prompts`` via the
    accelerator, and that shard is stored on ``self.prompts`` for the
    subsequent embedding-computation pass.
    """
    with self.accelerator.split_between_processes(prompts) as shard:
        self.prompts = shard

def __del__(self):
"""Ensure that the batch write thread is properly closed."""
if self.batch_write_thread.is_alive():
Expand Down
26 changes: 4 additions & 22 deletions helpers/caching/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,8 @@ def _image_filename_from_vaecache_filename(self, filepath: str) -> tuple[str, st
test_filepath, _ = self.generate_vae_cache_filename(filepath)
result = self.vae_path_to_image_path.get(test_filepath, None)
if result is None:
logger.debug(f"Mapping: {self.vae_path_to_image_path}")
raise ValueError(
f"Could not find image path for cache file {filepath} (test_filepath: {test_filepath}). Is the map built? {True if self.vae_path_to_image_path != {} else False}"
f"Could not find image path for cache file {filepath} (test_filepath: {test_filepath}). This occurs when you toggle the value for hashed_filenames without clearing your VAE cache. If it still occurs after clearing the cache, please open an issue: https://github.com/bghira/simpletuner/issues"
)

return result
Expand Down Expand Up @@ -288,7 +287,7 @@ def rebuild_cache(self):
self.debug_log("-> Clearing cache objects")
self.clear_cache()
self.debug_log("-> Split tasks between GPU(s)")
self.split_cache_between_processes()
self.discover_unprocessed_files()
self.debug_log("-> Load VAE")
self.init_vae()
if StateTracker.get_args().vae_cache_preprocess:
Expand Down Expand Up @@ -366,13 +365,13 @@ def discover_unprocessed_files(self, directory: str = None):
}

# Identify unprocessed files
unprocessed_files = [
self.local_unprocessed_files = [
file
for file in all_image_files
if os.path.splitext(file)[0] not in existing_image_filenames
]

return unprocessed_files
return self.local_unprocessed_files

def _reduce_bucket(
self,
Expand Down Expand Up @@ -416,23 +415,6 @@ def _reduce_bucket(
# )
return relevant_files

def split_cache_between_processes(self):
    """Record this process's unprocessed files from the cache directory.

    Historically the VAE cache itself was split between GPU processes
    here, but splitting now happens at the bucket level instead: slicing
    the cache as well caused a misalignment between the slice assigned to
    this GPU and its slice of already-processed images. The method is
    kept (name unchanged, for callers) but now only records the locally
    unprocessed files.
    """
    # NOTE: the original carried the dead per-process splitting code as a
    # commented-out artifact; it has been removed, and the explanatory
    # text moved into a real docstring (it previously sat after the first
    # statement, where it was a discarded string expression).
    self.local_unprocessed_files = self.discover_unprocessed_files(self.cache_dir)

def encode_images(self, images, filepaths, load_from_cache=True):
"""
Encode a batch of input images. Images must be the same dimension.
Expand Down
57 changes: 39 additions & 18 deletions helpers/data_backend/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from helpers.data_backend.aws import S3DataBackend
from helpers.data_backend.csv import CSVDataBackend
from helpers.data_backend.base import BaseDataBackend
from helpers.training.default_settings import default, latest_config_version
from helpers.caching.text_embeds import TextEmbeddingCache

from helpers.training.exceptions import MultiDatasetExhausted
Expand Down Expand Up @@ -138,6 +139,8 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict:
output["config"]["instance_data_dir"] = backend.get(
"instance_data_dir", backend.get("aws_data_prefix", "")
)
if "hash_filenames" in backend:
output["config"]["hash_filenames"] = backend["hash_filenames"]
if "shorten_filenames" in backend and backend.get("type") == "csv":
output["config"]["shorten_filenames"] = backend["shorten_filenames"]

Expand Down Expand Up @@ -697,30 +700,43 @@ def configure_multi_databackend(
"target_downsample_size",
"parquet",
]
# we will set the latest version by default.
current_config_version = latest_config_version()
if init_backend["metadata_backend"].config != {}:
prev_config = init_backend["metadata_backend"].config
logger.debug(f"Found existing config: {prev_config}")
# if the prev config used an old default config version, we will update defaults here.
current_config_version = prev_config.get("config_version", None)
if current_config_version is None:
# backwards compatibility for non-versioned config files, so that we do not enable life-changing options.
current_config_version = 1
logger.debug(
f"Found existing config (version={current_config_version}): {prev_config}"
)
logger.debug(f"Comparing against new config: {init_backend['config']}")
# Check if any values differ between the 'backend' values and the 'config' values:
for key, _ in prev_config.items():
logger.debug(f"Checking config key: {key}")
if (
key in backend
and prev_config[key] != backend[key]
and key not in excluded_keys
):
if not args.override_dataset_config:
raise Exception(
f"Dataset {init_backend['id']} has inconsistent config, and --override_dataset_config was not provided."
f"\n-> Expected value {key}={prev_config[key]} differs from current value={backend[key]}."
f"\n-> Recommended action is to correct the current config values to match the values that were used to create this dataset:"
f"\n{prev_config}"
)
else:
if key not in excluded_keys:
if key in backend and prev_config[key] != backend[key]:
if not args.override_dataset_config:
raise Exception(
f"Dataset {init_backend['id']} has inconsistent config, and --override_dataset_config was not provided."
f"\n-> Expected value {key}={prev_config.get(key)} differs from current value={backend.get(key)}."
f"\n-> Recommended action is to correct the current config values to match the values that were used to create this dataset:"
f"\n{prev_config}"
)
else:
logger.warning(
f"Overriding config value {key}={prev_config[key]} with {backend[key]}"
)
prev_config[key] = backend[key]
elif key not in backend:
logger.warning(
f"Overriding config value {key}={prev_config[key]} with {backend[key]}"
f"Key {key} not found in the current backend config, using the existing value {prev_config[key]}."
)
prev_config[key] = backend[key]
init_backend["config"][key] = prev_config[key]

init_backend["config"]["config_version"] = current_config_version
StateTracker.set_data_backend_config(init_backend["id"], init_backend["config"])
info_log(f"Configured backend: {init_backend}")

Expand Down Expand Up @@ -841,6 +857,11 @@ def configure_multi_databackend(
# Register the backend here so the sampler can be found.
StateTracker.register_data_backend(init_backend)

default_hash_option = default("hash_filenames", current_config_version)
hash_filenames = init_backend.get("config", {}).get(
"hash_filenames", default_hash_option
)

if "deepfloyd" not in StateTracker.get_args().model_type:
info_log(f"(id={init_backend['id']}) Creating VAE latent cache.")
init_backend["vaecache"] = VAECache(
Expand Down Expand Up @@ -876,7 +897,7 @@ def configure_multi_databackend(
max_workers=backend.get("max_workers", 32),
process_queue_size=backend.get("process_queue_size", 64),
vae_cache_preprocess=args.vae_cache_preprocess,
hash_filenames=backend.get("hash_filenames", False),
hash_filenames=hash_filenames,
)

if args.vae_cache_preprocess:
Expand Down Expand Up @@ -921,7 +942,7 @@ def configure_multi_databackend(
and "vae" not in args.skip_file_discovery
and "vae" not in backend.get("skip_file_discovery", "")
):
init_backend["vaecache"].split_cache_between_processes()
init_backend["vaecache"].discover_unprocessed_files()
if args.vae_cache_preprocess:
init_backend["vaecache"].process_buckets()
logger.debug(
Expand Down
15 changes: 15 additions & 0 deletions helpers/training/default_settings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Newest dataset/backend config schema version produced by this build.
CURRENT_VERSION = 2

# Per-version default values: configs created under an older schema keep
# their original behaviour, while new configs pick up the newer defaults.
LATEST_DEFAULTS = {1: {"hash_filenames": False}, 2: {"hash_filenames": True}}


def default(setting: str, current_version: int = None, default_value=None):
    """Return the default value of *setting* for a given config version.

    Args:
        setting: Name of the setting to look up (e.g. ``"hash_filenames"``).
        current_version: Config schema version to resolve against. ``None``
            or a non-positive value falls back to ``CURRENT_VERSION``.
        default_value: Returned when the version or setting is unknown.
    """
    # BUGFIX: test for None BEFORE the numeric comparison. The original
    # ordering (``current_version <= 0 or current_version is None``)
    # evaluated ``None <= 0`` first, which raises TypeError on Python 3,
    # so the None fallback could never be reached.
    if current_version is None or current_version <= 0:
        current_version = CURRENT_VERSION
    if current_version in LATEST_DEFAULTS:
        return LATEST_DEFAULTS[current_version].get(setting, default_value)
    return default_value


def latest_config_version():
    """Return the newest config schema version (``CURRENT_VERSION``) that this build writes by default."""
    return CURRENT_VERSION

0 comments on commit 0653855

Please sign in to comment.