From c6fb00b307f863a9303b2e8c54ba8433c0e52841 Mon Sep 17 00:00:00 2001 From: tazlin Date: Thu, 22 Jun 2023 07:52:01 -0400 Subject: [PATCH 1/2] fix: correctly require correct versions of horde_* deps --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 172f5bcb..2f1a560e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # Add this in for tox, comment out for build --extra-index-url https://download.pytorch.org/whl/cu118 -horde_model_reference>=0.1.1 +horde_model_reference~=0.2.0 pydantic torch>=2.0.0 xformers>=0.0.19 @@ -40,4 +40,4 @@ scikit-image mediapipe>=0.9.1.0 unidecode fuzzywuzzy -horde_clipfree>=0.0.2 +horde_clipfree==0.0.2 From 2b7dbbde472ed52a82404b52854bbdcc5b64d290 Mon Sep 17 00:00:00 2001 From: tazlin Date: Thu, 22 Jun 2023 08:49:45 -0400 Subject: [PATCH 2/2] fix: re-add taint models (unintentionally removed in 1.6.0) --- hordelib/model_manager/base.py | 44 ++++++++++++++++++++++++++++++--- hordelib/model_manager/hyper.py | 13 ++++++++++ 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/hordelib/model_manager/base.py b/hordelib/model_manager/base.py index 2b699cc6..727544f5 100644 --- a/hordelib/model_manager/base.py +++ b/hordelib/model_manager/base.py @@ -39,6 +39,8 @@ class BaseModelManager(ABC): available_models: list[str] # XXX rework as a property? _loaded_models: dict[str, dict] """The models available for immediate use.""" + tainted_models: list + """Models which seem to be corrupted and should be deleted when the correct replacement is downloaded.""" models_db_name: str models_db_path: Path cuda_available: bool @@ -89,6 +91,7 @@ def __init__( self.model_reference = {} self.available_models = [] self._loaded_models = {} + self.tainted_models = [] self.pkg = importlib_resources.files("hordelib") # XXX Remove self.models_db_name = MODEL_DB_NAMES[model_category_name] self.models_db_path = Path(get_hordelib_path()).joinpath( @@ -500,6 +503,16 @@ def unload_all_models(self): self.unload_model(model) return True + def taint_model(self, model_name: str): + """Marks a model as not valid by removing it from available_models""" + if model_name in self.available_models: + self.available_models.remove(model_name) + self.tainted_models.append(model_name) + + def taint_models(self, models: list[str]) -> None: + for model in models: + self.taint_model(model) + def validate_model(self, model_name: str, skip_checksum: bool = False) -> bool | None: """Check the if the model file is on disk and, optionally, also if the checksum is correct. @@ -798,7 +811,8 @@ def download_model(self, model_name: str): - unzip file """ # XXX this function is wacky in its premise and needs to be reworked - if model_name in self.available_models: + is_model_tainted = model_name in self.tainted_models + if not is_model_tainted and model_name in self.available_models: logger.debug(f"{model_name} is already available.") return True download = self.get_model_download(model_name) @@ -809,6 +823,9 @@ def download_model(self, model_name: str): if "file_path" in download[i] else files[i]["path"] ) + download_url = None + download_name = None + download_path = None if "file_url" in download[i]: download_url = download[i]["file_url"] @@ -828,6 +845,11 @@ def download_model(self, model_name: str): if "file_content" in download[i]: file_content = download[i]["file_content"] logger.info(f"writing {file_content} to {file_path}") + if not download_path or not download_name: + raise RuntimeError( + f"download_path and download_name are required for file_content download type for " + f"{model_name}", + ) os.makedirs( os.path.join(self.modelFolderPath, download_path), exist_ok=True, @@ -840,6 +862,10 @@ def download_model(self, model_name: str): elif "symlink" in download[i]: logger.info(f"symlink {file_path} to {download[i]['symlink']}") symlink = download[i]["symlink"] + if not download_path or not download_name: + raise RuntimeError( + f"download_path and download_name are required for symlink download type for " f"{model_name}", + ) os.makedirs( os.path.join(self.modelFolderPath, download_path), exist_ok=True, @@ -864,7 +890,10 @@ def download_model(self, model_name: str): zip_path = f"{self.modelFolderPath}/{download_name}.zip" temp_path = f"{self.modelFolderPath}/{str(uuid4())}/" os.makedirs(temp_path, exist_ok=True) - + if not download_url or not download_path: + raise RuntimeError( + f"download_url and download_path are required for unzip download type for {model_name}", + ) download_succeeded = self.download_file(download_url, zip_path) if not download_succeeded: return False @@ -885,8 +914,17 @@ def download_model(self, model_name: str): logger.info(f"delete {temp_path}") shutil.rmtree(temp_path) else: - if not self.check_file_available(file_path): + if not self.check_file_available(file_path) or is_model_tainted: logger.debug(f"Downloading {download_url} to {file_path}") + if is_model_tainted: + logger.debug(f"Model {model_name} is tainted.") + + if not download_url: + logger.error( + f"download_url is required for download type for {model_name}", + ) + return False + download_succeeded = self.download_file(download_url, file_path) if not download_succeeded: return False diff --git a/hordelib/model_manager/hyper.py b/hordelib/model_manager/hyper.py index 1ba376a8..a88ae460 100644 --- a/hordelib/model_manager/hyper.py +++ b/hordelib/model_manager/hyper.py @@ -285,6 +285,19 @@ def validate_model( return model_manager.validate_model(model_name, skip_checksum) return None + def taint_models(self, models: list[str]) -> None: + """Marks a list of models to be unavailable. + + Args: + models (list[str]): The list of models to mark. + """ + for model_manager_type in MODEL_MANAGERS_TYPE_LOOKUP: + model_manager: BaseModelManager = getattr(self, model_manager_type) + if model_manager is None: + continue + if any(model in model_manager.model_reference for model in models): + model_manager.taint_models(models) + def unload_model(self, model_name: str) -> bool | None: """Unloads the target model.