diff --git a/folder_paths.py b/folder_paths.py index 71faa2df4db..3db1da61a94 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -1,13 +1,13 @@ +from __future__ import annotations + import os import time import logging -from typing import Set, List, Dict, Tuple +from collections.abc import Collection -supported_pt_extensions: Set[str] = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft']) +supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'} -SupportedFileExtensionsType = Set[str] -ScanPathType = List[str] -folder_names_and_paths: Dict[str, Tuple[ScanPathType, SupportedFileExtensionsType]] = {} +folder_names_and_paths: dict[str, tuple[list[str], set[str]]] = {} base_path = os.path.dirname(os.path.realpath(__file__)) models_dir = os.path.join(base_path, "models") @@ -42,7 +42,7 @@ input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user") -filename_list_cache = {} +filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {} if not os.path.exists(input_directory): try: @@ -50,33 +50,33 @@ except: logging.error("Failed to create input directory") -def set_output_directory(output_dir): +def set_output_directory(output_dir: str) -> None: global output_directory output_directory = output_dir -def set_temp_directory(temp_dir): +def set_temp_directory(temp_dir: str) -> None: global temp_directory temp_directory = temp_dir -def set_input_directory(input_dir): +def set_input_directory(input_dir: str) -> None: global input_directory input_directory = input_dir -def get_output_directory(): +def get_output_directory() -> str: global output_directory return output_directory -def get_temp_directory(): +def get_temp_directory() -> str: global temp_directory return temp_directory -def get_input_directory(): +def get_input_directory() -> str: global input_directory return input_directory #NOTE: used in http server so don't put folders that should not be accessed remotely -def get_directory_by_type(type_name): +def get_directory_by_type(type_name: str) -> str | None: if type_name == "output": return get_output_directory() if type_name == "temp": @@ -88,7 +88,7 @@ def get_directory_by_type(type_name): # determine base_dir rely on annotation if name is 'filename.ext [annotation]' format # otherwise use default_path as base_dir -def annotated_filepath(name): +def annotated_filepath(name: str) -> tuple[str, str | None]: if name.endswith("[output]"): base_dir = get_output_directory() name = name[:-9] @@ -104,7 +104,7 @@ def annotated_filepath(name): return name, base_dir -def get_annotated_filepath(name, default_dir=None): +def get_annotated_filepath(name: str, default_dir: str | None=None) -> str: name, base_dir = annotated_filepath(name) if base_dir is None: @@ -116,7 +116,7 @@ def get_annotated_filepath(name, default_dir=None): return os.path.join(base_dir, name) -def exists_annotated_filepath(name): +def exists_annotated_filepath(name) -> bool: name, base_dir = annotated_filepath(name) if base_dir is None: @@ -126,17 +126,17 @@ def exists_annotated_filepath(name): return os.path.exists(filepath) -def add_model_folder_path(folder_name, full_folder_path): +def add_model_folder_path(folder_name: str, full_folder_path: str) -> None: global folder_names_and_paths if folder_name in folder_names_and_paths: folder_names_and_paths[folder_name][0].append(full_folder_path) else: folder_names_and_paths[folder_name] = ([full_folder_path], set()) -def get_folder_paths(folder_name): +def get_folder_paths(folder_name: str) -> list[str]: return folder_names_and_paths[folder_name][0][:] -def recursive_search(directory, excluded_dir_names=None): +def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]: if not os.path.isdir(directory): return [], {} @@ -153,6 +153,10 @@ def recursive_search(directory, excluded_dir_names=None): logging.warning(f"Warning: Unable to access {directory}. Skipping this path.") logging.debug("recursive file list on directory {}".format(directory)) + dirpath: str + subdirs: list[str] + filenames: list[str] + for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True): subdirs[:] = [d for d in subdirs if d not in excluded_dir_names] for file_name in filenames: @@ -160,7 +164,7 @@ def recursive_search(directory, excluded_dir_names=None): result.append(relative_path) for d in subdirs: - path = os.path.join(dirpath, d) + path: str = os.path.join(dirpath, d) try: dirs[path] = os.path.getmtime(path) except FileNotFoundError: @@ -169,12 +173,12 @@ def recursive_search(directory, excluded_dir_names=None): logging.debug("found {} files".format(len(result))) return result, dirs -def filter_files_extensions(files, extensions): +def filter_files_extensions(files: Collection[str], extensions: Collection[str]) -> list[str]: return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files))) -def get_full_path(folder_name, filename): +def get_full_path(folder_name: str, filename: str) -> str | None: global folder_names_and_paths if folder_name not in folder_names_and_paths: return None @@ -189,7 +193,7 @@ def get_full_path(folder_name, filename): return None -def get_filename_list_(folder_name): +def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]: global folder_names_and_paths output_list = set() folders = folder_names_and_paths[folder_name] @@ -199,9 +203,9 @@ def get_filename_list_(folder_name): output_list.update(filter_files_extensions(files, folders[1])) output_folders = {**output_folders, **folders_all} - return (sorted(list(output_list)), output_folders, time.perf_counter()) + return sorted(list(output_list)), output_folders, time.perf_counter() -def cached_filename_list_(folder_name): +def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None: global filename_list_cache global folder_names_and_paths if folder_name not in filename_list_cache: @@ -222,7 +226,7 @@ def cached_filename_list_(folder_name): return out -def get_filename_list(folder_name): +def get_filename_list(folder_name: str) -> list[str]: out = cached_filename_list_(folder_name) if out is None: out = get_filename_list_(folder_name) @@ -230,17 +234,17 @@ def get_filename_list(folder_name): filename_list_cache[folder_name] = out return list(out[0]) -def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0): - def map_filename(filename): +def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]: + def map_filename(filename: str) -> tuple[int, str]: prefix_len = len(os.path.basename(filename_prefix)) prefix = filename[:prefix_len + 1] try: digits = int(filename[prefix_len + 1:].split('_')[0]) except: digits = 0 - return (digits, prefix) + return digits, prefix - def compute_vars(input, image_width, image_height): + def compute_vars(input: str, image_width: int, image_height: int) -> str: input = input.replace("%width%", str(image_width)) input = input.replace("%height%", str(image_height)) return input