Skip to content

Commit

Permalink
Merge pull request #24 from jug-dev/main
Browse files Browse the repository at this point in the history
fix: re-add taint models feature, better dependency pins
  • Loading branch information
tazlin authored Jun 22, 2023
2 parents e080427 + 3457cf0 commit 6dafdf3
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 5 deletions.
44 changes: 41 additions & 3 deletions hordelib/model_manager/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class BaseModelManager(ABC):
available_models: list[str] # XXX rework as a property?
_loaded_models: dict[str, dict]
"""The models available for immediate use."""
tainted_models: list
"""Models which seem to be corrupted and should be deleted when the correct replacement is downloaded."""
models_db_name: str
models_db_path: Path
cuda_available: bool
Expand Down Expand Up @@ -89,6 +91,7 @@ def __init__(
self.model_reference = {}
self.available_models = []
self._loaded_models = {}
self.tainted_models = []
self.pkg = importlib_resources.files("hordelib") # XXX Remove
self.models_db_name = MODEL_DB_NAMES[model_category_name]
self.models_db_path = Path(get_hordelib_path()).joinpath(
Expand Down Expand Up @@ -500,6 +503,16 @@ def unload_all_models(self):
self.unload_model(model)
return True

def taint_model(self, model_name: str):
"""Marks a model as not valid by removing it from available_models"""
if model_name in self.available_models:
self.available_models.remove(model_name)
self.tainted_models.append(model_name)

def taint_models(self, models: list[str]) -> None:
for model in models:
self.taint_model(model)

def validate_model(self, model_name: str, skip_checksum: bool = False) -> bool | None:
"""Check the if the model file is on disk and, optionally, also if the checksum is correct.
Expand Down Expand Up @@ -798,7 +811,8 @@ def download_model(self, model_name: str):
- unzip file
"""
# XXX this function is wacky in its premise and needs to be reworked
if model_name in self.available_models:
is_model_tainted = model_name in self.tainted_models
if not is_model_tainted and model_name in self.available_models:
logger.debug(f"{model_name} is already available.")
return True
download = self.get_model_download(model_name)
Expand All @@ -809,6 +823,9 @@ def download_model(self, model_name: str):
if "file_path" in download[i]
else files[i]["path"]
)
download_url = None
download_name = None
download_path = None

if "file_url" in download[i]:
download_url = download[i]["file_url"]
Expand All @@ -828,6 +845,11 @@ def download_model(self, model_name: str):
if "file_content" in download[i]:
file_content = download[i]["file_content"]
logger.info(f"writing {file_content} to {file_path}")
if not download_path or not download_name:
raise RuntimeError(
f"download_path and download_name are required for file_content download type for "
f"{model_name}",
)
os.makedirs(
os.path.join(self.modelFolderPath, download_path),
exist_ok=True,
Expand All @@ -840,6 +862,10 @@ def download_model(self, model_name: str):
elif "symlink" in download[i]:
logger.info(f"symlink {file_path} to {download[i]['symlink']}")
symlink = download[i]["symlink"]
if not download_path or not download_name:
raise RuntimeError(
f"download_path and download_name are required for symlink download type for " f"{model_name}",
)
os.makedirs(
os.path.join(self.modelFolderPath, download_path),
exist_ok=True,
Expand All @@ -864,7 +890,10 @@ def download_model(self, model_name: str):
zip_path = f"{self.modelFolderPath}/{download_name}.zip"
temp_path = f"{self.modelFolderPath}/{str(uuid4())}/"
os.makedirs(temp_path, exist_ok=True)

if not download_url or not download_path:
raise RuntimeError(
f"download_url and download_path are required for unzip download type for {model_name}",
)
download_succeeded = self.download_file(download_url, zip_path)
if not download_succeeded:
return False
Expand All @@ -885,8 +914,17 @@ def download_model(self, model_name: str):
logger.info(f"delete {temp_path}")
shutil.rmtree(temp_path)
else:
if not self.check_file_available(file_path):
if not self.check_file_available(file_path) or is_model_tainted:
logger.debug(f"Downloading {download_url} to {file_path}")
if is_model_tainted:
logger.debug(f"Model {model_name} is tainted.")

if not download_url:
logger.error(
f"download_url is required for download type for {model_name}",
)
return False

download_succeeded = self.download_file(download_url, file_path)
if not download_succeeded:
return False
Expand Down
13 changes: 13 additions & 0 deletions hordelib/model_manager/hyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,19 @@ def validate_model(
return model_manager.validate_model(model_name, skip_checksum)
return None

def taint_models(self, models: list[str]) -> None:
"""Marks a list of models to be unavailable.
Args:
models (list[str]): The list of models to mark.
"""
for model_manager_type in MODEL_MANAGERS_TYPE_LOOKUP:
model_manager: BaseModelManager = getattr(self, model_manager_type)
if model_manager is None:
continue
if any(model in model_manager.model_reference for model in models):
model_manager.taint_models(models)

def unload_model(self, model_name: str) -> bool | None:
"""Unloads the target model.
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Add this in for tox, comment out for build
--extra-index-url https://download.pytorch.org/whl/cu118
horde_model_reference>=0.1.1
horde_model_reference~=0.2.0
pydantic
torch>=2.0.0
xformers>=0.0.19
Expand Down Expand Up @@ -40,4 +40,4 @@ scikit-image
mediapipe>=0.9.1.0
unidecode
fuzzywuzzy
horde_clipfree>=0.0.2
horde_clipfree==0.0.2

0 comments on commit 6dafdf3

Please sign in to comment.