Skip to content

Commit

Permalink
Merge pull request #603 from bghira/debug/mgpu-cache
Browse files Browse the repository at this point in the history
fix for relative cache directories with NoneType being unsubscriptable
  • Loading branch information
bghira authored Jul 27, 2024
2 parents f9845f8 + f63fc19 commit 5cd9b34
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions helpers/caching/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def __init__(
self.vae = vae
self.accelerator = accelerator
self.cache_dir = cache_dir
if self.cache_data_backend.type == "local":
self.cache_dir = os.path.abspath(self.cache_dir)
if len(self.cache_dir) > 0 and self.cache_dir[-1] == "/":
# Remove trailing slash
self.cache_dir = self.cache_dir[:-1]
Expand Down Expand Up @@ -127,34 +129,47 @@ def debug_log(self, msg: str):

def generate_vae_cache_filename(self, filepath: str) -> tuple[str, str]:
    """Get the cache filename for a given image filepath and its base name.

    Args:
        filepath: Path of an image sample, or of an existing ``.pt`` cache
            file.

    Returns:
        A ``(full_filename, base_filename)`` pair. ``base_filename`` is the
        image's stem (or the SHA-256 hex digest of the stem when
        ``self.hash_filenames`` is truthy) with a ``.pt`` extension, and
        ``full_filename`` places it under ``self.cache_dir`` inside the same
        subfolders the image occupied below ``self.instance_data_root``.
        A ``filepath`` that already ends in ``.pt`` is returned unchanged
        together with its basename.
    """
    # A path that already names a cache file needs no translation.
    if filepath.endswith(".pt"):
        return filepath, os.path.basename(filepath)

    # Swap the image extension for .pt, optionally hashing the stem first.
    stem = os.path.splitext(os.path.basename(filepath))[0]
    if self.hash_filenames:
        stem = sha256(stem.encode()).hexdigest()
    base_filename = f"{stem}.pt"

    # Find the subfolders the sample was in, relative to instance_data_root.
    separators = "/" + os.sep
    directory = os.path.dirname(filepath)
    data_root = self.instance_data_root
    if data_root is None:
        # Without a data root there is no relative layout to mirror.
        subfolders = ""
    else:
        data_root = data_root.rstrip(separators)
        if directory.startswith(data_root):
            # Drop only the leading root; a later repeat of the root's name
            # inside the path is a genuine folder name and must survive.
            subfolders = directory[len(data_root):]
        else:
            # Root is not a prefix: keep the historical substitution.
            subfolders = directory.replace(data_root, "")

    # os.path.join() discards every earlier component once it meets an
    # absolute one, so the leading separator must always be removed, no
    # matter whether cache_dir itself is absolute or relative.
    subfolders = subfolders.lstrip(separators)

    # An empty subfolders component is a no-op for os.path.join(), so samples
    # stored directly in the data root land directly in cache_dir.
    full_filename = os.path.join(self.cache_dir, subfolders, base_filename)
    return full_filename, base_filename

def _image_filename_from_vaecache_filename(self, filepath: str) -> str:
    """Return the image path recorded for a VAE cache file.

    Args:
        filepath: Path of a cache file (or of an image, which is first
            translated through ``generate_vae_cache_filename``).

    Returns:
        The value stored in ``self.vae_path_to_image_path`` for the
        resolved cache path.

    Raises:
        ValueError: If the resolved cache path has no entry in
            ``self.vae_path_to_image_path``.
    """
    test_filepath, _ = self.generate_vae_cache_filename(filepath)
    if self.cache_data_backend.type == "local":
        # build_vae_cache_filename_map() stores absolute paths as keys for
        # local backends, so the lookup key must be normalised the same way
        # or a relative cache path could never match.
        test_filepath = os.path.abspath(test_filepath)

    result = self.vae_path_to_image_path.get(test_filepath)
    if result is None:
        logger.debug(f"Mapping: {self.vae_path_to_image_path}")
        raise ValueError(
            f"Could not find image path for cache file {filepath} (test_filepath: {test_filepath}). Is the map built? {bool(self.vae_path_to_image_path)}"
        )
    return result

def build_vae_cache_filename_map(self, all_image_files: list):
    """Rebuild the lookup tables linking image paths and cache paths.

    ``self.image_path_to_vae_path`` and ``self.vae_path_to_image_path`` are
    both reset, then filled with one entry per item of ``all_image_files``
    using the path returned by ``generate_vae_cache_filename``.
    """
    self.image_path_to_vae_path = {}
    self.vae_path_to_image_path = {}
    for source_path in all_image_files:
        cache_path, _base_name = self.generate_vae_cache_filename(source_path)
        # Local backends key both tables on the absolute cache path.
        cache_path = (
            os.path.abspath(cache_path)
            if self.cache_data_backend.type == "local"
            else cache_path
        )
        self.vae_path_to_image_path[cache_path] = source_path
        self.image_path_to_vae_path[source_path] = cache_path

Expand Down

0 comments on commit 5cd9b34

Please sign in to comment.