Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to folder_paths.py #4191

Merged
merged 3 commits into from
Aug 7, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 36 additions & 30 deletions folder_paths.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

import os
import time
import logging
from typing import Set, List, Dict, Tuple
from collections.abc import Collection

supported_pt_extensions: Set[str] = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'])
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}

SupportedFileExtensionsType = Set[str]
ScanPathType = List[str]
folder_names_and_paths: Dict[str, Tuple[ScanPathType, SupportedFileExtensionsType]] = {}
SupportedFileExtensionsType = set[str]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should inline these 2 variables. My bad.

Traceback (most recent call last):
  File "main.py", line 6, in <module>
    import folder_paths
  File "/home/runner/work/ComfyUI/ComfyUI/ComfyUI/folder_paths.py", line 10, in <module>
    SupportedFileExtensionsType = set[str]
TypeError: 'type' object is not subscriptable

ScanPathType = list[str]
folder_names_and_paths: dict[str, tuple[ScanPathType, SupportedFileExtensionsType]] = {}

base_path = os.path.dirname(os.path.realpath(__file__))
models_dir = os.path.join(base_path, "models")
Expand Down Expand Up @@ -42,41 +44,41 @@
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user")

filename_list_cache = {}
filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}

if not os.path.exists(input_directory):
try:
os.makedirs(input_directory)
except:
logging.error("Failed to create input directory")

def set_output_directory(output_dir):
def set_output_directory(output_dir: str) -> None:
global output_directory
output_directory = output_dir

def set_temp_directory(temp_dir):
def set_temp_directory(temp_dir: str) -> None:
global temp_directory
temp_directory = temp_dir

def set_input_directory(input_dir):
def set_input_directory(input_dir: str) -> None:
global input_directory
input_directory = input_dir

def get_output_directory():
def get_output_directory() -> str:
global output_directory
return output_directory

def get_temp_directory():
def get_temp_directory() -> str:
global temp_directory
return temp_directory

def get_input_directory():
def get_input_directory() -> str:
global input_directory
return input_directory


#NOTE: used in http server so don't put folders that should not be accessed remotely
def get_directory_by_type(type_name):
def get_directory_by_type(type_name: str) -> str | None:
if type_name == "output":
return get_output_directory()
if type_name == "temp":
Expand All @@ -88,7 +90,7 @@ def get_directory_by_type(type_name):

# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
# otherwise use default_path as base_dir
def annotated_filepath(name):
def annotated_filepath(name: str) -> tuple[str, str | None]:
if name.endswith("[output]"):
base_dir = get_output_directory()
name = name[:-9]
Expand All @@ -104,7 +106,7 @@ def annotated_filepath(name):
return name, base_dir


def get_annotated_filepath(name, default_dir=None):
def get_annotated_filepath(name: str, default_dir: str | None=None) -> str:
name, base_dir = annotated_filepath(name)

if base_dir is None:
Expand All @@ -116,7 +118,7 @@ def get_annotated_filepath(name, default_dir=None):
return os.path.join(base_dir, name)


def exists_annotated_filepath(name):
def exists_annotated_filepath(name) -> bool:
name, base_dir = annotated_filepath(name)

if base_dir is None:
Expand All @@ -126,17 +128,17 @@ def exists_annotated_filepath(name):
return os.path.exists(filepath)


def add_model_folder_path(folder_name, full_folder_path):
def add_model_folder_path(folder_name: str, full_folder_path: str) -> None:
global folder_names_and_paths
if folder_name in folder_names_and_paths:
folder_names_and_paths[folder_name][0].append(full_folder_path)
else:
folder_names_and_paths[folder_name] = ([full_folder_path], set())

def get_folder_paths(folder_name):
def get_folder_paths(folder_name: str) -> list[str]:
return folder_names_and_paths[folder_name][0][:]

def recursive_search(directory, excluded_dir_names=None):
def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]:
if not os.path.isdir(directory):
return [], {}

Expand All @@ -153,14 +155,18 @@ def recursive_search(directory, excluded_dir_names=None):
logging.warning(f"Warning: Unable to access {directory}. Skipping this path.")

logging.debug("recursive file list on directory {}".format(directory))
dirpath: str
subdirs: list[str]
filenames: list[str]

for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
for file_name in filenames:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)

for d in subdirs:
path = os.path.join(dirpath, d)
path: str = os.path.join(dirpath, d)
try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
Expand All @@ -169,12 +175,12 @@ def recursive_search(directory, excluded_dir_names=None):
logging.debug("found {} files".format(len(result)))
return result, dirs

def filter_files_extensions(files, extensions):
def filter_files_extensions(files: Collection[str], extensions: Collection[str]) -> list[str]:
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))



def get_full_path(folder_name, filename):
def get_full_path(folder_name: str, filename: str) -> str | None:
global folder_names_and_paths
if folder_name not in folder_names_and_paths:
return None
Expand All @@ -189,7 +195,7 @@ def get_full_path(folder_name, filename):

return None

def get_filename_list_(folder_name):
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
global folder_names_and_paths
output_list = set()
folders = folder_names_and_paths[folder_name]
Expand All @@ -199,9 +205,9 @@ def get_filename_list_(folder_name):
output_list.update(filter_files_extensions(files, folders[1]))
output_folders = {**output_folders, **folders_all}

return (sorted(list(output_list)), output_folders, time.perf_counter())
return sorted(list(output_list)), output_folders, time.perf_counter()

def cached_filename_list_(folder_name):
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
global filename_list_cache
global folder_names_and_paths
if folder_name not in filename_list_cache:
Expand All @@ -222,25 +228,25 @@ def cached_filename_list_(folder_name):

return out

def get_filename_list(folder_name):
def get_filename_list(folder_name: str) -> list[str]:
out = cached_filename_list_(folder_name)
if out is None:
out = get_filename_list_(folder_name)
global filename_list_cache
filename_list_cache[folder_name] = out
return list(out[0])

def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
def map_filename(filename):
def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
def map_filename(filename: str) -> tuple[int, str]:
prefix_len = len(os.path.basename(filename_prefix))
prefix = filename[:prefix_len + 1]
try:
digits = int(filename[prefix_len + 1:].split('_')[0])
except:
digits = 0
return (digits, prefix)
return digits, prefix

def compute_vars(input, image_width, image_height):
def compute_vars(input: str, image_width: int, image_height: int) -> str:
input = input.replace("%width%", str(image_width))
input = input.replace("%height%", str(image_height))
return input
Expand Down
Loading