From 6a95a2906539b4d8b2095b0ad1d0581923bbb832 Mon Sep 17 00:00:00 2001 From: Lucain Date: Tue, 13 Aug 2024 12:31:32 +0200 Subject: [PATCH] Update Hugging Face utilities MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Hello 👋 This PR updates some attributes and methods that are now deprecated in the `huggingface_hub` library. - `modelId` is getting deprecated https://github.com/huggingface/huggingface_hub/issues/2408 - `ModelFilter` has been deprecated for some time (see https://github.com/huggingface/huggingface_hub/issues/2028) This PR will make the codebase more future-proof while being compatible with existing versions of `huggingface_hub`. Let me know if you have any questions :hugs: --- generator_process/actions/huggingface_hub.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/generator_process/actions/huggingface_hub.py b/generator_process/actions/huggingface_hub.py index b3dc7eb1..f2dc9929 100644 --- a/generator_process/actions/huggingface_hub.py +++ b/generator_process/actions/huggingface_hub.py @@ -32,24 +32,23 @@ def hf_list_models( query: str, token: str, ) -> list[Model]: - from huggingface_hub import HfApi, ModelFilter + from huggingface_hub import HfApi if hasattr(self, "huggingface_hub_api"): api: HfApi = self.huggingface_hub_api else: api = HfApi() setattr(self, "huggingface_hub_api", api) - - filter = ModelFilter(tags="diffusers") + models = api.list_models( - filter=filter, + tags="diffusers", search=query, - use_auth_token=token + token=token, ) return [ - Model(m.modelId, m.author or "", m.tags, m.likes if hasattr(m, "likes") else 0, getattr(m, "downloads", -1), ModelType.UNKNOWN) + Model(m.id, m.author or "", m.tags, m.likes if hasattr(m, "likes") else 0, getattr(m, "downloads", -1), ModelType.UNKNOWN) for m in models - if m.modelId is not None and m.tags is not None and 'diffusers' in (m.tags or {}) + if m.id is not None and m.tags is not None and 'diffusers' in (m.tags or {}) ] def hf_list_installed_models(self) -> list[Model]: @@ -177,7 +176,7 @@ def hf_snapshot_download( _, variant_files = variant_compatible_siblings(files, variant=variant) StableDiffusionPipeline.download( model, - use_auth_token=token, + token=token, variant=variant if len(variant_files) > 0 else None, resume_download=resume_download, ) @@ -204,4 +203,4 @@ def hf_snapshot_download( else: raise ValueError(f"{model} doesn't appear to be a pipeline or model") - future.set_done() \ No newline at end of file + future.set_done()