diff --git a/horde_model_reference/meta_consts.py b/horde_model_reference/meta_consts.py index 9a76858..b89dfcb 100644 --- a/horde_model_reference/meta_consts.py +++ b/horde_model_reference/meta_consts.py @@ -103,3 +103,25 @@ class STABLE_DIFFUSION_BASELINE_CATEGORY(StrEnum): MODEL_REFERENCE_CATEGORY.stable_diffusion: MODEL_PURPOSE.image_generation, MODEL_REFERENCE_CATEGORY.miscellaneous: MODEL_PURPOSE.miscellaneous, } + +STABLE_DIFFUSION_BASELINE_NATIVE_RESOLUTION_LOOKUP: dict[STABLE_DIFFUSION_BASELINE_CATEGORY, int] = { + STABLE_DIFFUSION_BASELINE_CATEGORY.stable_diffusion_1: 512, + STABLE_DIFFUSION_BASELINE_CATEGORY.stable_diffusion_2_768: 768, + STABLE_DIFFUSION_BASELINE_CATEGORY.stable_diffusion_2_512: 512, + STABLE_DIFFUSION_BASELINE_CATEGORY.stable_diffusion_xl: 1024, + STABLE_DIFFUSION_BASELINE_CATEGORY.stable_cascade: 1024, +} +"""The single-side preferred resolution for each known stable diffusion baseline.""" + + +def get_baseline_native_resolution(baseline: STABLE_DIFFUSION_BASELINE_CATEGORY) -> int: + """ + Get the native resolution of a stable diffusion baseline. + + Args: + baseline: The stable diffusion baseline. + + Returns: + The native resolution of the baseline. + """ + return STABLE_DIFFUSION_BASELINE_NATIVE_RESOLUTION_LOOKUP[baseline] diff --git a/horde_model_reference/model_reference_manager.py b/horde_model_reference/model_reference_manager.py index c933136..d7591f6 100644 --- a/horde_model_reference/model_reference_manager.py +++ b/horde_model_reference/model_reference_manager.py @@ -6,7 +6,10 @@ from horde_model_reference.legacy.download_live_legacy_dbs import LegacyReferenceDownloadManager from horde_model_reference.model_reference_records import ( MODEL_REFERENCE_TYPE_LOOKUP, + CLIP_ModelReference, + ControlNet_ModelReference, Generic_ModelReference, + StableDiffusion_ModelReference, ) from horde_model_reference.path_consts import MODEL_REFERENCE_CATEGORY @@ -128,3 +131,139 @@ def get_all_model_references( return_dict[reference_type] = reference.model_copy(deep=True) return return_dict + + @property + def blip(self) -> Generic_ModelReference: + """ + Get the BLIP model reference. + + Returns: + The BLIP model reference. + """ + blip = self.all_model_references[MODEL_REFERENCE_CATEGORY.blip] + if blip is None: + raise ValueError("BLIP model reference not found.") + + return blip + + @property + def clip(self) -> CLIP_ModelReference: + """ + Get the CLIP model reference. + + Returns: + The CLIP model reference. + """ + clip = self.all_model_references[MODEL_REFERENCE_CATEGORY.clip] + if clip is None: + raise ValueError("CLIP model reference not found.") + + if not isinstance(clip, CLIP_ModelReference): + raise TypeError("CLIP model reference is not of the correct type.") + + return clip + + @property + def codeformer(self) -> Generic_ModelReference: + """ + Get the codeformer model reference. + + Returns: + The codeformer model reference. + """ + codeformer = self.all_model_references[MODEL_REFERENCE_CATEGORY.codeformer] + if codeformer is None: + raise ValueError("Codeformer model reference not found.") + + return codeformer + + @property + def controlnet(self) -> ControlNet_ModelReference: + """ + Get the controlnet model reference. + + Returns: + The controlnet model reference. + """ + controlnet = self.all_model_references[MODEL_REFERENCE_CATEGORY.controlnet] + if controlnet is None: + raise ValueError("ControlNet model reference not found.") + + if not isinstance(controlnet, ControlNet_ModelReference): + raise TypeError("ControlNet model reference is not of the correct type.") + + return controlnet + + @property + def esrgan(self) -> Generic_ModelReference: + """ + Get the ESRGAN model reference. + + Returns: + The ESRGAN model reference. + """ + esrgan = self.all_model_references[MODEL_REFERENCE_CATEGORY.esrgan] + if esrgan is None: + raise ValueError("ESRGAN model reference not found.") + + return esrgan + + @property + def gfpgan(self) -> Generic_ModelReference: + """ + Get the GfPGAN model reference. + + Returns: + The GfPGAN model reference. + """ + gfpgan = self.all_model_references[MODEL_REFERENCE_CATEGORY.gfpgan] + if gfpgan is None: + raise ValueError("GfPGAN model reference not found.") + + return gfpgan + + @property + def safety_checker(self) -> Generic_ModelReference: + """ + Get the safety checker model reference. + + Returns: + The safety checker model reference. + """ + safety_checker = self.all_model_references[MODEL_REFERENCE_CATEGORY.safety_checker] + if safety_checker is None: + raise ValueError("Safety checker model reference not found.") + + return safety_checker + + @property + def stable_diffusion(self) -> StableDiffusion_ModelReference: + """ + Get the stable diffusion model reference. + + Returns: + The stable diffusion model reference. + """ + stable_diffusion = self.all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion] + + if stable_diffusion is None: + raise ValueError("Stable diffusion model reference not found.") + + if not isinstance(stable_diffusion, StableDiffusion_ModelReference): + raise TypeError("Stable diffusion model reference is not of the correct type.") + + return stable_diffusion + + @property + def miscellaneous(self) -> Generic_ModelReference: + """ + Get the miscellaneous model reference. + + Returns: + The miscellaneous model reference. + """ + miscellaneous = self.all_model_references[MODEL_REFERENCE_CATEGORY.miscellaneous] + if miscellaneous is None: + raise ValueError("Miscellaneous model reference not found.") + + return miscellaneous diff --git a/horde_model_reference/model_reference_records.py b/horde_model_reference/model_reference_records.py index 72ef8ff..43ed171 100644 --- a/horde_model_reference/model_reference_records.py +++ b/horde_model_reference/model_reference_records.py @@ -232,6 +232,24 @@ def models_names(self) -> set[str]: """Return a list of all the model names.""" return set(self.root.keys()) + def get_model_baseline(self, model_name: str) -> STABLE_DIFFUSION_BASELINE_CATEGORY | str | None: + """Return the baseline for a given model name.""" + model: StableDiffusion_ModelRecord | None = self.root.get(model_name) + + return model.baseline if model else None + + def get_model_style(self, model_name: str) -> MODEL_STYLE | str | None: + """Return the style for a given model name.""" + model: StableDiffusion_ModelRecord | None = self.root.get(model_name) + + return model.style if model else None + + def get_model_tags(self, model_name: str) -> list[str] | None: + """Return the tags for a given model name.""" + model: StableDiffusion_ModelRecord | None = self.root.get(model_name) + + return model.tags if model else None + class CLIP_ModelReference(Generic_ModelReference): root: Mapping[str, CLIP_ModelRecord] diff --git a/tests/test_convert_legacy_database.py b/tests/test_convert_legacy_database.py index 16ea5fb..e495f87 100644 --- a/tests/test_convert_legacy_database.py +++ b/tests/test_convert_legacy_database.py @@ -75,6 +75,10 @@ def test_validate_converted_stable_diffusion_database(base_path_for_tests) -> No assert model_reference.root is not None assert len(model_reference.root) >= 100 + assert model_reference.get_model_baseline("BAD_MODEL_KEY") is None + assert model_reference.get_model_style("BAD_MODEL_KEY") is None + assert model_reference.get_model_tags("BAD_MODEL_KEY") is None + assert model_reference.root["stable_diffusion"] is not None assert model_reference.root["stable_diffusion"].name == "stable_diffusion" assert model_reference.root["stable_diffusion"].showcases is not None @@ -112,3 +116,7 @@ def test_validate_converted_stable_diffusion_database(base_path_for_tests) -> No if model_info.trigger is not None: for trigger_record in model_info.trigger: assert trigger_record != "" + + assert model_reference.get_model_baseline(model_key) is not None + assert model_reference.get_model_style(model_key) is not None + assert isinstance(model_reference.get_model_tags(model_key), list)