Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: attr addressable model references; add some convenience methods #119

Merged
merged 3 commits into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions horde_model_reference/meta_consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,25 @@ class STABLE_DIFFUSION_BASELINE_CATEGORY(StrEnum):
MODEL_REFERENCE_CATEGORY.stable_diffusion: MODEL_PURPOSE.image_generation,
MODEL_REFERENCE_CATEGORY.miscellaneous: MODEL_PURPOSE.miscellaneous,
}

STABLE_DIFFUSION_BASELINE_NATIVE_RESOLUTION_LOOKUP: dict[STABLE_DIFFUSION_BASELINE_CATEGORY, int] = {
STABLE_DIFFUSION_BASELINE_CATEGORY.stable_diffusion_1: 512,
STABLE_DIFFUSION_BASELINE_CATEGORY.stable_diffusion_2_768: 768,
STABLE_DIFFUSION_BASELINE_CATEGORY.stable_diffusion_2_512: 512,
STABLE_DIFFUSION_BASELINE_CATEGORY.stable_diffusion_xl: 1024,
STABLE_DIFFUSION_BASELINE_CATEGORY.stable_cascade: 1024,
}
"""The single-side preferred resolution for each known stable diffusion baseline."""


def get_baseline_native_resolution(baseline: STABLE_DIFFUSION_BASELINE_CATEGORY) -> int:
"""
Get the native resolution of a stable diffusion baseline.

Args:
baseline: The stable diffusion baseline.

Returns:
The native resolution of the baseline.
"""
return STABLE_DIFFUSION_BASELINE_NATIVE_RESOLUTION_LOOKUP[baseline]
139 changes: 139 additions & 0 deletions horde_model_reference/model_reference_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from horde_model_reference.legacy.download_live_legacy_dbs import LegacyReferenceDownloadManager
from horde_model_reference.model_reference_records import (
MODEL_REFERENCE_TYPE_LOOKUP,
CLIP_ModelReference,
ControlNet_ModelReference,
Generic_ModelReference,
StableDiffusion_ModelReference,
)
from horde_model_reference.path_consts import MODEL_REFERENCE_CATEGORY

Expand Down Expand Up @@ -128,3 +131,139 @@ def get_all_model_references(
return_dict[reference_type] = reference.model_copy(deep=True)

return return_dict

@property
def blip(self) -> Generic_ModelReference:
"""
Get the BLIP model reference.

Returns:
The BLIP model reference.
"""
blip = self.all_model_references[MODEL_REFERENCE_CATEGORY.blip]
if blip is None:
raise ValueError("BLIP model reference not found.")

return blip

@property
def clip(self) -> CLIP_ModelReference:
"""
Get the CLIP model reference.

Returns:
The CLIP model reference.
"""
clip = self.all_model_references[MODEL_REFERENCE_CATEGORY.clip]
if clip is None:
raise ValueError("CLIP model reference not found.")

if not isinstance(clip, CLIP_ModelReference):
raise TypeError("CLIP model reference is not of the correct type.")

return clip

@property
def codeformer(self) -> Generic_ModelReference:
"""
Get the codeformer model reference.

Returns:
The codeformer model reference.
"""
codeformer = self.all_model_references[MODEL_REFERENCE_CATEGORY.codeformer]
if codeformer is None:
raise ValueError("Codeformer model reference not found.")

return codeformer

@property
def controlnet(self) -> ControlNet_ModelReference:
"""
Get the controlnet model reference.

Returns:
The controlnet model reference.
"""
controlnet = self.all_model_references[MODEL_REFERENCE_CATEGORY.controlnet]
if controlnet is None:
raise ValueError("ControlNet model reference not found.")

if not isinstance(controlnet, ControlNet_ModelReference):
raise TypeError("ControlNet model reference is not of the correct type.")

return controlnet

@property
def esrgan(self) -> Generic_ModelReference:
"""
Get the ESRGAN model reference.

Returns:
The ESRGAN model reference.
"""
esrgan = self.all_model_references[MODEL_REFERENCE_CATEGORY.esrgan]
if esrgan is None:
raise ValueError("ESRGAN model reference not found.")

return esrgan

@property
def gfpgan(self) -> Generic_ModelReference:
"""
Get the GfPGAN model reference.

Returns:
The GfPGAN model reference.
"""
gfpgan = self.all_model_references[MODEL_REFERENCE_CATEGORY.gfpgan]
if gfpgan is None:
raise ValueError("GfPGAN model reference not found.")

return gfpgan

@property
def safety_checker(self) -> Generic_ModelReference:
"""
Get the safety checker model reference.

Returns:
The safety checker model reference.
"""
safety_checker = self.all_model_references[MODEL_REFERENCE_CATEGORY.safety_checker]
if safety_checker is None:
raise ValueError("Safety checker model reference not found.")

return safety_checker

@property
def stable_diffusion(self) -> StableDiffusion_ModelReference:
"""
Get the stable diffusion model reference.

Returns:
The stable diffusion model reference.
"""
stable_diffusion = self.all_model_references[MODEL_REFERENCE_CATEGORY.stable_diffusion]

if stable_diffusion is None:
raise ValueError("Stable diffusion model reference not found.")

if not isinstance(stable_diffusion, StableDiffusion_ModelReference):
raise TypeError("Stable diffusion model reference is not of the correct type.")

return stable_diffusion

@property
def miscellaneous(self) -> Generic_ModelReference:
"""
Get the miscellaneous model reference.

Returns:
The miscellaneous model reference.
"""
miscellaneous = self.all_model_references[MODEL_REFERENCE_CATEGORY.miscellaneous]
if miscellaneous is None:
raise ValueError("Miscellaneous model reference not found.")

return miscellaneous
18 changes: 18 additions & 0 deletions horde_model_reference/model_reference_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,24 @@ def models_names(self) -> set[str]:
"""Return a list of all the model names."""
return set(self.root.keys())

def get_model_baseline(self, model_name: str) -> STABLE_DIFFUSION_BASELINE_CATEGORY | str | None:
"""Return the baseline for a given model name."""
model: StableDiffusion_ModelRecord | None = self.root.get(model_name)

return model.baseline if model else None

def get_model_style(self, model_name: str) -> MODEL_STYLE | str | None:
"""Return the style for a given model name."""
model: StableDiffusion_ModelRecord | None = self.root.get(model_name)

return model.style if model else None

def get_model_tags(self, model_name: str) -> list[str] | None:
"""Return the tags for a given model name."""
model: StableDiffusion_ModelRecord | None = self.root.get(model_name)

return model.tags if model else None


class CLIP_ModelReference(Generic_ModelReference):
root: Mapping[str, CLIP_ModelRecord]
Expand Down
8 changes: 8 additions & 0 deletions tests/test_convert_legacy_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ def test_validate_converted_stable_diffusion_database(base_path_for_tests) -> No
assert model_reference.root is not None
assert len(model_reference.root) >= 100

assert model_reference.get_model_baseline("BAD_MODEL_KEY") is None
assert model_reference.get_model_style("BAD_MODEL_KEY") is None
assert model_reference.get_model_tags("BAD_MODEL_KEY") is None

assert model_reference.root["stable_diffusion"] is not None
assert model_reference.root["stable_diffusion"].name == "stable_diffusion"
assert model_reference.root["stable_diffusion"].showcases is not None
Expand Down Expand Up @@ -112,3 +116,7 @@ def test_validate_converted_stable_diffusion_database(base_path_for_tests) -> No
if model_info.trigger is not None:
for trigger_record in model_info.trigger:
assert trigger_record != ""

assert model_reference.get_model_baseline(model_key) is not None
assert model_reference.get_model_style(model_key) is not None
assert isinstance(model_reference.get_model_tags(model_key), list)
Loading