Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

multigpu / relative path fixes for caching #604

Merged
merged 14 commits into from
Jul 27, 2024
28 changes: 21 additions & 7 deletions helpers/caching/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def __init__(
self.vae = vae
self.accelerator = accelerator
self.cache_dir = cache_dir
if self.cache_data_backend.type == "local":
self.cache_dir = os.path.abspath(self.cache_dir)
if len(self.cache_dir) > 0 and self.cache_dir[-1] == "/":
# Remove trailing slash
self.cache_dir = self.cache_dir[:-1]
Expand Down Expand Up @@ -127,34 +129,47 @@ def debug_log(self, msg: str):

def generate_vae_cache_filename(self, filepath: str) -> tuple:
"""Get the cache filename for a given image filepath and its base name."""
if self.instance_data_root not in filepath:
if self.cache_dir in filepath and filepath.endswith(".pt"):
return filepath, os.path.basename(filepath)
if filepath.endswith(".pt"):
return filepath, os.path.basename(filepath)
# Extract the base name from the filepath and replace the image extension with .pt
base_filename = os.path.splitext(os.path.basename(filepath))[0]
if self.hash_filenames:
base_filename = str(sha256(str(base_filename).encode()).hexdigest())
base_filename = str(base_filename) + ".pt"
# Find the subfolders the sample was in, and replace the instance_data_root with the cache_dir
subfolders = os.path.dirname(filepath).replace(self.instance_data_root, "")
if len(subfolders) > 0 and subfolders[0] == "/":
if len(subfolders) > 0 and subfolders[0] == "/" and self.cache_dir[0] != "/":
subfolders = subfolders[1:]
full_filename = os.path.join(self.cache_dir, subfolders, base_filename)
# logger.debug(
# f"full_filename: {full_filename} = os.path.join({self.cache_dir}, {subfolders}, {base_filename})"
# )
else:
full_filename = os.path.join(self.cache_dir, base_filename)
# logger.debug(
# f"full_filename: {full_filename} = os.path.join({self.cache_dir}, {base_filename})"
# )
return full_filename, base_filename

def _image_filename_from_vaecache_filename(self, filepath: str) -> tuple[str, str]:
test_filepath, _ = self.generate_vae_cache_filename(filepath)
result = self.vae_path_to_image_path.get(test_filepath, None)
if result is None:
logger.debug(f"Mapping: {self.vae_path_to_image_path}")
raise ValueError(
f"Could not find image path for cache file {filepath} (test_filepath: {test_filepath}). Is the map built? {True if self.vae_path_to_image_path != {} else False}"
)

return self.vae_path_to_image_path.get(test_filepath, None)
return result

def _build_vae_cache_filename_map(self, all_image_files: list):
def build_vae_cache_filename_map(self, all_image_files: list):
"""Build a map of image filepaths to their corresponding cache filenames."""
self.image_path_to_vae_path = {}
self.vae_path_to_image_path = {}
for image_file in all_image_files:
cache_filename, _ = self.generate_vae_cache_filename(image_file)
if self.cache_data_backend.type == "local":
cache_filename = os.path.abspath(cache_filename)
self.image_path_to_vae_path[image_file] = cache_filename
self.vae_path_to_image_path[cache_filename] = image_file

Expand Down Expand Up @@ -219,7 +234,6 @@ def discover_all_files(self):
),
data_backend_id=self.id,
)
self._build_vae_cache_filename_map(all_image_files)
# This isn't returned, because we merely check if it's stored, or, store it.
(
StateTracker.get_vae_cache_files(data_backend_id=self.id)
Expand Down
6 changes: 6 additions & 0 deletions helpers/data_backend/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,12 @@ def configure_multi_databackend(
if accelerator.is_local_main_process:
init_backend["vaecache"].discover_all_files()
accelerator.wait_for_everyone()
all_image_files = StateTracker.get_image_files(
data_backend_id=init_backend["id"]
)
init_backend["vaecache"].build_vae_cache_filename_map(
all_image_files=all_image_files
)

if (
(
Expand Down
Loading