Skip to content

Commit

Permalink
Merge pull request #604 from bghira/main
Browse files Browse the repository at this point in the history
multigpu / relative path fixes for caching
  • Loading branch information
bghira authored Jul 27, 2024
2 parents 6e7fcb3 + 5cd9b34 commit 679c6c0
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
28 changes: 21 additions & 7 deletions helpers/caching/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def __init__(
self.vae = vae
self.accelerator = accelerator
self.cache_dir = cache_dir
if self.cache_data_backend.type == "local":
self.cache_dir = os.path.abspath(self.cache_dir)
if len(self.cache_dir) > 0 and self.cache_dir[-1] == "/":
# Remove trailing slash
self.cache_dir = self.cache_dir[:-1]
Expand Down Expand Up @@ -127,34 +129,47 @@ def debug_log(self, msg: str):

def generate_vae_cache_filename(self, filepath: str) -> tuple:
"""Get the cache filename for a given image filepath and its base name."""
if self.instance_data_root not in filepath:
if self.cache_dir in filepath and filepath.endswith(".pt"):
return filepath, os.path.basename(filepath)
if filepath.endswith(".pt"):
return filepath, os.path.basename(filepath)
# Extract the base name from the filepath and replace the image extension with .pt
base_filename = os.path.splitext(os.path.basename(filepath))[0]
if self.hash_filenames:
base_filename = str(sha256(str(base_filename).encode()).hexdigest())
base_filename = str(base_filename) + ".pt"
# Find the subfolders the sample was in, and replace the instance_data_root with the cache_dir
subfolders = os.path.dirname(filepath).replace(self.instance_data_root, "")
if len(subfolders) > 0 and subfolders[0] == "/":
if len(subfolders) > 0 and subfolders[0] == "/" and self.cache_dir[0] != "/":
subfolders = subfolders[1:]
full_filename = os.path.join(self.cache_dir, subfolders, base_filename)
# logger.debug(
# f"full_filename: {full_filename} = os.path.join({self.cache_dir}, {subfolders}, {base_filename})"
# )
else:
full_filename = os.path.join(self.cache_dir, base_filename)
# logger.debug(
# f"full_filename: {full_filename} = os.path.join({self.cache_dir}, {base_filename})"
# )
return full_filename, base_filename

def _image_filename_from_vaecache_filename(self, filepath: str) -> tuple[str, str]:
test_filepath, _ = self.generate_vae_cache_filename(filepath)
result = self.vae_path_to_image_path.get(test_filepath, None)
if result is None:
logger.debug(f"Mapping: {self.vae_path_to_image_path}")
raise ValueError(
f"Could not find image path for cache file {filepath} (test_filepath: {test_filepath}). Is the map built? {True if self.vae_path_to_image_path != {} else False}"
)

return self.vae_path_to_image_path.get(test_filepath, None)
return result

def _build_vae_cache_filename_map(self, all_image_files: list):
def build_vae_cache_filename_map(self, all_image_files: list):
"""Build a map of image filepaths to their corresponding cache filenames."""
self.image_path_to_vae_path = {}
self.vae_path_to_image_path = {}
for image_file in all_image_files:
cache_filename, _ = self.generate_vae_cache_filename(image_file)
if self.cache_data_backend.type == "local":
cache_filename = os.path.abspath(cache_filename)
self.image_path_to_vae_path[image_file] = cache_filename
self.vae_path_to_image_path[cache_filename] = image_file

Expand Down Expand Up @@ -219,7 +234,6 @@ def discover_all_files(self):
),
data_backend_id=self.id,
)
self._build_vae_cache_filename_map(all_image_files)
# This isn't returned, because we merely check if it's stored, or, store it.
(
StateTracker.get_vae_cache_files(data_backend_id=self.id)
Expand Down
6 changes: 6 additions & 0 deletions helpers/data_backend/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,12 @@ def configure_multi_databackend(
if accelerator.is_local_main_process:
init_backend["vaecache"].discover_all_files()
accelerator.wait_for_everyone()
all_image_files = StateTracker.get_image_files(
data_backend_id=init_backend["id"]
)
init_backend["vaecache"].build_vae_cache_filename_map(
all_image_files=all_image_files
)

if (
(
Expand Down

0 comments on commit 679c6c0

Please sign in to comment.