diff --git a/helpers/caching/vae.py b/helpers/caching/vae.py index 8b6ef5a8..be9ab9a2 100644 --- a/helpers/caching/vae.py +++ b/helpers/caching/vae.py @@ -86,6 +86,8 @@ def __init__( self.vae = vae self.accelerator = accelerator self.cache_dir = cache_dir + if self.cache_data_backend.type == "local": + self.cache_dir = os.path.abspath(self.cache_dir) if len(self.cache_dir) > 0 and self.cache_dir[-1] == "/": # Remove trailing slash self.cache_dir = self.cache_dir[:-1] @@ -127,9 +129,8 @@ def debug_log(self, msg: str): def generate_vae_cache_filename(self, filepath: str) -> tuple: """Get the cache filename for a given image filepath and its base name.""" - if self.instance_data_root not in filepath: - if self.cache_dir in filepath and filepath.endswith(".pt"): - return filepath, os.path.basename(filepath) + if filepath.endswith(".pt"): + return filepath, os.path.basename(filepath) # Extract the base name from the filepath and replace the image extension with .pt base_filename = os.path.splitext(os.path.basename(filepath))[0] if self.hash_filenames: @@ -137,17 +138,29 @@ def generate_vae_cache_filename(self, filepath: str) -> tuple: base_filename = str(base_filename) + ".pt" # Find the subfolders the sample was in, and replace the instance_data_root with the cache_dir subfolders = os.path.dirname(filepath).replace(self.instance_data_root, "") - if len(subfolders) > 0 and subfolders[0] == "/": + if len(subfolders) > 0 and subfolders[0] == "/" and self.cache_dir[0] != "/": subfolders = subfolders[1:] full_filename = os.path.join(self.cache_dir, subfolders, base_filename) + # logger.debug( + # f"full_filename: {full_filename} = os.path.join({self.cache_dir}, {subfolders}, {base_filename})" + # ) else: full_filename = os.path.join(self.cache_dir, base_filename) + # logger.debug( + # f"full_filename: {full_filename} = os.path.join({self.cache_dir}, {base_filename})" + # ) return full_filename, base_filename def _image_filename_from_vaecache_filename(self, filepath: str) -> tuple[str, str]: test_filepath, _ = self.generate_vae_cache_filename(filepath) + result = self.vae_path_to_image_path.get(test_filepath, None) + if result is None: + logger.debug(f"Mapping: {self.vae_path_to_image_path}") + raise ValueError( + f"Could not find image path for cache file {filepath} (test_filepath: {test_filepath}). Is the map built? {True if self.vae_path_to_image_path != {} else False}" + ) - return self.vae_path_to_image_path.get(test_filepath, None) + return result def build_vae_cache_filename_map(self, all_image_files: list): """Build a map of image filepaths to their corresponding cache filenames.""" @@ -155,6 +168,8 @@ def build_vae_cache_filename_map(self, all_image_files: list): self.vae_path_to_image_path = {} for image_file in all_image_files: cache_filename, _ = self.generate_vae_cache_filename(image_file) + if self.cache_data_backend.type == "local": + cache_filename = os.path.abspath(cache_filename) self.image_path_to_vae_path[image_file] = cache_filename self.vae_path_to_image_path[cache_filename] = image_file