diff --git a/.gitignore b/.gitignore index 4b07ab38..16e57384 100644 --- a/.gitignore +++ b/.gitignore @@ -113,6 +113,7 @@ venv/ ENV/ env.bak/ venv.bak/ +.env_docker # Spyder project settings .spyderproject diff --git a/horde/apis/models/v2.py b/horde/apis/models/v2.py index 863c03bb..b7607553 100644 --- a/horde/apis/models/v2.py +++ b/horde/apis/models/v2.py @@ -2,10 +2,11 @@ # # SPDX-License-Identifier: AGPL-3.0-or-later -from flask_restx import fields, reqparse +from flask_restx import Namespace, fields, reqparse from horde.enums import WarningMessage from horde.exceptions import KNOWN_RC +from horde.model_reference import KnownImageModelRef, KnownTextModelRef from horde.vars import horde_noun, horde_title @@ -260,7 +261,7 @@ def __init__(self): class Models: - def __init__(self, api): + def __init__(self, api: Namespace): self.response_model_wp_status_lite = api.model( "RequestStatusCheck", { @@ -406,7 +407,7 @@ def __init__(self, api): min=0, ), "untrusted": fields.Integer( - description=("How many waiting requests were skipped because they demanded a trusted worker which this worker is not."), + description="How many waiting requests were skipped because they demanded a trusted worker which this worker is not.", min=0, ), "models": fields.Integer( @@ -618,11 +619,11 @@ def __init__(self, api): "forms": fields.List(fields.String(description="Which forms this worker if offering.")), "team": fields.Nested( self.response_model_team_details_lite, - "The Team to which this worker is dedicated.", + description="The Team to which this worker is dedicated.", ), "contact": fields.String( example="email@example.com", - description=("(Privileged) Contact details for the horde admins to reach the owner of this worker in emergencies."), + description="(Privileged) Contact details for the horde admins to reach the owner of this worker in emergencies.", min_length=5, max_length=500, ), @@ -1053,7 +1054,7 @@ def __init__(self, api): max=10, ), "worker_invited": 
fields.Integer( - description=("Set to the amount of workers this user is allowed to join to the horde when in worker invite-only mode."), + description="Set to the amount of workers this user is allowed to join to the horde when in worker invite-only mode.", ), "moderator": fields.Boolean( example=False, @@ -1278,6 +1279,232 @@ def __init__(self, api): ), }, ) + self.response_model_known_model_md = api.model( + "KnownModelMetadata", + { + "name": fields.String(description="The name of this model."), + "description": fields.String(description="The description of this model."), + "version": fields.String(description="The version of this model."), + "style": fields.String(description="The style of this model."), + "nsfw": fields.Boolean(description="Whether this model can generate NSFW content."), + "baseline": fields.String(description="The baseline model used for this model."), + }, + ) + + settings = api.model( + "KnownTextModelSettings", + { + "n": fields.Integer(example=1, min=1, max=20), + "frmtadsnsp": fields.Boolean( + example=False, + description=( + "Input formatting option. When enabled, adds a leading space to your input " + "if there is no trailing whitespace at the end of the previous action." + ), + ), + "frmtrmblln": fields.Boolean( + example=False, + description=( + "Output formatting option. When enabled, replaces all occurrences of two or more consecutive newlines " + "in the output with one newline." + ), + ), + "frmtrmspch": fields.Boolean( + example=False, + description=r"Output formatting option. When enabled, removes #/@%}{+=~|\^<> from the output.", + ), + "frmttriminc": fields.Boolean( + example=False, + description=( + "Output formatting option. When enabled, removes some characters from the end of the output such " + "that the output doesn't end in the middle of a sentence. " + "If the output is less than one sentence long, does nothing." 
+ ), + ), + "max_context_length": fields.Integer( + min=80, + example=1024, + max=32000, + description="Maximum number of tokens to send to the model.", + ), + "max_length": fields.Integer( + min=16, + max=1024, + example=80, + description="Number of tokens to generate.", + ), + "rep_pen": fields.Float(description="Base repetition penalty value.", min=1, max=3), + "rep_pen_range": fields.Integer(description="Repetition penalty range.", min=0, max=4096), + "rep_pen_slope": fields.Float(description="Repetition penalty slope.", min=0, max=10), + "singleline": fields.Boolean( + example=False, + description=( + "Output formatting option. When enabled, removes everything after the first line of the output, " + "including the newline." + ), + ), + "temperature": fields.Float(description="Temperature value.", min=0, max=5.0), + "tfs": fields.Float(description="Tail free sampling value.", min=0.0, max=1.0), + "top_a": fields.Float(description="Top-a sampling value.", min=0.0, max=1.0), + "top_k": fields.Integer(description="Top-k sampling value.", min=0, max=100), + "top_p": fields.Float(description="Top-p sampling value.", min=0.001, max=1.0), + "typical": fields.Float(description="Typical sampling value.", min=0.0, max=1.0), + "sampler_order": fields.List( + fields.Integer(description="Array of integers representing the sampler order to be used."), + ), + "use_default_badwordsids": fields.Boolean( + example=True, + description="When True, uses the default KoboldAI bad word IDs.", + ), + "stop_sequence": fields.List( + fields.String( + description=( + "An array of string sequences whereby the model will stop generating further tokens. " + "The returned text WILL contain the stop sequence." 
+ ), + ), + ), + "min_p": fields.Float(description="Min-p sampling value.", min=0.0, example=0.0, max=1.0), + "smoothing_factor": fields.Float( + description="Quadratic sampling value.", + min=0.0, + example=0.0, + max=10.0, + ), + "dynatemp_range": fields.Float( + description="Dynamic temperature range value.", + min=0.0, + example=0.0, + max=5.0, + ), + "dynatemp_exponent": fields.Float( + description="Dynamic temperature exponent value.", + min=0.0, + example=1.0, + max=5.0, + ), + }, + ) + + self.response_model_known_text_model_md = api.inherit( + "KnownTextModelMetadata", + self.response_model_known_model_md, + { + "parameters": fields.Integer(description="The number of parameters in this model."), + "display_name": fields.String(description="The display name of this model."), + "homepage": fields.String(description="The homepage of the model.", attribute="url"), + "tags": fields.List(fields.String(description="The tags of this model.")), + "instruct_format": fields.String(description="Instruct format template to use for this model."), + "settings": fields.Nested( + settings, + description="Recommended settings for this model.", + allow_null=False, + skip_none=True, + ), + }, + ) + + requirements = api.model( + "KnownImageModelRequirements", + { + "clip_skip": fields.Integer( + description="The number of steps to skip in CLIP.", + min=1, + example=1, + ), + "min_steps": fields.Integer( + description="The minimum number of steps to take.", + min=1, + example=30, + ), + "max_steps": fields.Integer( + description="The maximum number of steps to take.", + min=1, + example=30, + ), + "cfg_scale": fields.Float( + description="Classifier-free guidance scale.", + min=0.0, + example=7.5, + ), + "min_cfg_scale": fields.Float( + description="Minimum classifier-free guidance scale.", + min=0.0, + example=7.5, + ), + "max_cfg_scale": fields.Float( + description="Maximum classifier-free guidance scale.", + min=0.0, + example=7.5, + ), + "samplers": fields.List( + 
fields.String( + description="The samplers to use for this model.", + example="k_euler_a", + ), + ), + }, + ) + + download = api.model( + "KnownImageModelDownload", + { + "file_name": fields.String(description="The filename of the file to download."), + "file_path": fields.String(description="The path to the file to download."), + "file_url": fields.String(description="The URL to download the file from."), + }, + ) + + config = api.model( + "KnownImageModelConfig", + { + "files": fields.List(fields.Nested(api.model("KnownImageModelFile", {"path": fields.String, "sha256sum": fields.String}))), + "download": fields.List(fields.Nested(download)), + }, + ) + self.response_model_known_image_model_md = api.inherit( + "KnownImageModelMetadata", + self.response_model_known_model_md, + { + "homepage": fields.String(description="The URL of the model's page."), + "weight_type": fields.String(description="Storage format of the model weights.", attribute="type"), + "inpainting": fields.Boolean(description="Whether this model can generate inpainting content."), + "requirements": fields.Nested( + requirements, + description="Generation settings requirements for this model.", + allow_null=False, + skip_none=True, + ), + "config": fields.Nested( + config, + description="The configuration of the model.", + allow_null=False, + skip_none=True, + ), + "features_not_supported": fields.List(fields.String(description="The features not supported by the model.")), + "size_on_disk_bytes": fields.Integer(description="The size of the model on disk in bytes."), + }, + ) + + self.response_model_known_model = api.model( + "KnownModel", + { + "name": fields.String(description="The name of this model."), + "type": fields.String( + description="Model type (text or image).", + enum=["text", "image"], + ), + "metadata": fields.Polymorph( + { + KnownImageModelRef: self.response_model_known_image_model_md, + KnownTextModelRef: self.response_model_known_text_model_md, + }, + description="The metadata of 
the model.", + skip_none=True, + ), + }, + ) + self.response_model_active_model = api.inherit( "ActiveModel", self.response_model_active_model_lite, @@ -1291,8 +1518,17 @@ def __init__(self, api): description="The model type (text or image).", enum=["image", "text"], ), + "metadata": fields.Polymorph( + { + KnownImageModelRef: self.response_model_known_image_model_md, + KnownTextModelRef: self.response_model_known_text_model_md, + }, + description="The metadata of the model.", + skip_none=True, + ), }, ) + self.response_model_deleted_worker = api.model( "DeletedWorker", { diff --git a/horde/apis/v2/__init__.py b/horde/apis/v2/__init__.py index f9a4679b..a872df0f 100644 --- a/horde/apis/v2/__init__.py +++ b/horde/apis/v2/__init__.py @@ -53,3 +53,5 @@ api.add_resource(base.DocsTerms, "/documents/terms") api.add_resource(base.DocsPrivacy, "/documents/privacy") api.add_resource(base.DocsSponsors, "/documents/sponsors") +api.add_resource(base.KnownModels, "/knownmodels") +api.add_resource(base.KnownModelSingle, "/knownmodels/<string:model_name>") diff --git a/horde/apis/v2/base.py b/horde/apis/v2/base.py index d4029c64..b6e86951 100644 --- a/horde/apis/v2/base.py +++ b/horde/apis/v2/base.py @@ -5,6 +5,7 @@ import json import os import time +from dataclasses import dataclass from datetime import datetime, timedelta import regex as re @@ -37,6 +38,7 @@ from horde.limiter import limiter from horde.logger import logger from horde.metrics import waitress_metrics +from horde.model_reference import KnownModelRef, model_reference from horde.patreon import patrons from horde.r2 import upload_prompt from horde.suspicions import Suspicions @@ -444,7 +446,7 @@ def post(self): self.prioritized_wp.append(wp) ## End prioritize by bridge request ## for wp in self.get_sorted_wp(): - if wp.id not in [wp.id for wp in self.prioritized_wp]: + if wp.id not in [pwp.id for pwp in self.prioritized_wp]: self.prioritized_wp.append(wp) # logger.warning(datetime.utcnow()) while len(self.prioritized_wp) > 0: @@ 
-1598,6 +1600,8 @@ def get(self): class Models(Resource): + MODEL_STATES = ["known", "custom", "all"] + get_parser = reqparse.RequestParser() get_parser.add_argument( "Client-Agent", @@ -1637,6 +1641,7 @@ class Models(Resource): required=False, default="all", type=str, + choices=MODEL_STATES, help=( "If 'known', only show stats for known models in the model reference. " "If 'custom' only show stats for custom models. " @@ -1644,6 +1649,14 @@ class Models(Resource): ), location="args", ) + get_parser.add_argument( + "metadata", + required=False, + default=False, + type=bool, + help="Include the model reference metadata in the response.", + location="args", + ) @cache.cached(timeout=2, query_string=True) @api.expect(get_parser) @@ -1657,15 +1670,24 @@ class Models(Resource): def get(self): """Returns a list of models active currently in this horde""" self.args = self.get_parser.parse_args() - if self.args.model_state not in ["known", "custom", "all"]: - raise e.BadRequest("'model_state' needs to be one of ['known', 'custom', 'all']") + if self.args.model_state not in self.MODEL_STATES: + raise e.BadRequest(f"'model_state' needs to be one of {self.MODEL_STATES}") models_ret = database.retrieve_available_models( model_type=self.args.type, min_count=self.args.min_count, max_count=self.args.max_count, model_state=self.args.model_state, ) - return (models_ret, 200) + + # here, augment with the model reference data in "metadata" key + # if args.type ever becomes properly optional, this can be done differently + if self.args.metadata: + if self.args.type == "image": + ref = model_reference.reference + else: + ref = model_reference.text_reference + models_ret = [{**model, "metadata": ref.get(model["name"], {})} for model in models_ret] + return models_ret, 200 class ModelSingle(Resource): @@ -3128,3 +3150,103 @@ def get(self): if self.args.format == "markdown": return {"markdown": markdownify(html_template).strip("\n")}, 200 return {"html": html_template}, 200 + + 
+@dataclass +class KnownModelsReturn: + name: str + type: str + metadata: KnownModelRef + + +def known_models(model_type: str = None, filter_model_name: str = None, exact: bool = None) -> list[KnownModelsReturn]: + """ + Filter known models from the model reference + + :param model_type: The model type to filter by. + :param filter_model_name: The model name to filter by. + :param exact: If the model name should be an exact match. + + :return: A list of known models. + """ + model_references = { + "image": model_reference.reference, + "text": model_reference.text_reference, + } + + if filter_model_name and not exact: + filter_model_name = filter_model_name.lower() + + matched_models = [] + for slug, source in model_references.items(): + if model_type and model_type != slug: + continue + for model_name, metadata in source.items(): + if filter_model_name: + if exact and model_name != filter_model_name: + continue + if not exact and filter_model_name not in model_name.lower(): + continue + + matched_models.append( + KnownModelsReturn( + name=model_name, + type=slug, + metadata=metadata, + ), + ) + return matched_models + + +class KnownModels(Resource): + args = None + parser = reqparse.RequestParser() + parser.add_argument( + "type", + type=str, + choices=["text", "image"], + required=False, + help="The model type to filter by.", + location="args", + ) + parser.add_argument( + "name", + type=str, + required=False, + help="The model name to filter by.", + location="args", + ) + + @cache.cached(timeout=2, query_string=True) + @api.expect(parser) + @api.response(400, "Validation Error", models.response_model_error) + @api.marshal_with( + models.response_model_known_model, + code=200, + description="List all known models", + as_list=True, + ) + def get(self): + """List all known models and their metadata""" + self.args = self.parser.parse_args() + models_ret = known_models(model_type=self.args.type, filter_model_name=self.args.name) + return models_ret, 200 + + +class 
KnownModelSingle(Resource): + @cache.cached(timeout=2, query_string=True) + @api.response(400, "Validation Error", models.response_model_error) + @api.response(404, "Model Not Found", models.response_model_error) + @api.marshal_with( + models.response_model_known_model, + code=200, + description="Get a known model", + ) + def get(self, model_name): + """Get the metadata for a known model""" + models_ret = known_models(filter_model_name=model_name, exact=True) + if not models_ret: + raise e.ThingNotFound("Model", model_name) + if len(models_ret) > 1: + raise e.BadRequest("More than one model found with the same name.") + return models_ret[0], 200 diff --git a/horde/database/functions.py b/horde/database/functions.py index 423770d2..c345136c 100644 --- a/horde/database/functions.py +++ b/horde/database/functions.py @@ -265,7 +265,7 @@ def worker_exists(worker_id): return wc -def get_available_models(filter_model_name: str = None): +def get_available_models(filter_model_name: str = None) -> list[dict]: models_dict = {} available_worker_models = None diff --git a/horde/model_reference.py b/horde/model_reference.py index 1f0a4279..680591da 100644 --- a/horde/model_reference.py +++ b/horde/model_reference.py @@ -1,9 +1,10 @@ # SPDX-FileCopyrightText: 2022 Konstantinos Thoukydidis +# SPDX-FileCopyrightText: 2024 ceruleandeep # # SPDX-License-Identifier: AGPL-3.0-or-later import os -from datetime import datetime +from datetime import datetime, timezone import requests @@ -11,85 +12,97 @@ from horde.threads import PrimaryTimedFunction +class KnownModelRef(dict): + """ + Base class for a known model reference entry. + + Known model references need to be typed for RESTX to work properly, + but they need to be dicts for everywhere else in the code. 
+ """ + + +class KnownTextModelRef(KnownModelRef): + """ + A known text model reference entry + """ + + +class KnownImageModelRef(KnownModelRef): + """ + A known image model reference entry + """ + + +DEFAULT_HORDE_IMAGE_COMPVIS_REFERENCE = ( + "https://raw.githubusercontent.com/Haidra-Org/AI-Horde-image-model-reference/main/stable_diffusion.json" +) +DEFAULT_HORDE_IMAGE_LLM_REFERENCE = "https://raw.githubusercontent.com/db0/AI-Horde-text-model-reference/main/db.json" +DEFAULT_HORDE_IMAGE_DIFFUSERS_REFERENCE = "https://raw.githubusercontent.com/Haidra-Org/AI-Horde-image-model-reference/main/diffusers.json" + +SD_BASELINES = { + "stable diffusion 1", + "stable diffusion 2", + "stable diffusion 2 512", + "stable_diffusion_xl", + "stable_cascade", + "flux_1", +} + + class ModelReference(PrimaryTimedFunction): quorum = None - reference = None - text_reference = None - stable_diffusion_names = set() - text_model_names = set() - nsfw_models = set() - controlnet_models = set() + reference: dict[str, KnownImageModelRef] = None + text_reference: dict[str, KnownTextModelRef] = None + stable_diffusion_names: set[str] = set() + text_model_names: set[str] = set() + nsfw_models: set[str] = set() + controlnet_models: set[str] = set() + # Workaround because users lacking customizer role are getting models not in the reference stripped away. 
# However due to a racing or caching issue, this causes them to still pick jobs using those models # Need to investigate more to remove this workaround testing_models = {} def call_function(self): - """Retrieves to image and text model reference and stores in it a var""" - # If it's running in SQLITE_MODE, it means it's a test and we never want to grab the quorum - # We don't want to report on any random model name a client might request + """ + Retrieves image and text model references + """ for _riter in range(10): try: - ref_json = "https://raw.githubusercontent.com/Haidra-Org/AI-Horde-image-model-reference/main/stable_diffusion.json" - if datetime.utcnow() <= datetime(2024, 9, 30): # Flux Beta - ref_json = ( - "https://raw.githubusercontent.com/Haidra-Org/AI-Horde-image-model-reference/refs/heads/flux/stable_diffusion.json" - ) - logger.debug("Using flux beta model reference...") - self.reference = requests.get( - os.getenv( - "HORDE_IMAGE_COMPVIS_REFERENCE", - ref_json, - ), - timeout=2, - ).json() - diffusers = requests.get( - os.getenv( - "HORDE_IMAGE_DIFFUSERS_REFERENCE", - "https://raw.githubusercontent.com/Haidra-Org/AI-Horde-image-model-reference/main/diffusers.json", - ), - timeout=2, - ).json() - self.reference.update(diffusers) - # logger.debug(self.reference) - self.stable_diffusion_names = set() - for model in self.reference: - if self.reference[model].get("baseline") in { - "stable diffusion 1", - "stable diffusion 2", - "stable diffusion 2 512", - "stable_diffusion_xl", - "stable_cascade", - "flux_1", - }: - self.stable_diffusion_names.add(model) - if self.reference[model].get("nsfw"): - self.nsfw_models.add(model) - if self.reference[model].get("type") == "controlnet": - self.controlnet_models.add(model) - + self._load_image_models() break except Exception as e: - logger.error(f"Error when downloading nataili models list: {e}") + logger.error(f"Error when downloading image models list: {e}") for _riter in range(10): try: - self.text_reference = 
requests.get( - os.getenv( - "HORDE_IMAGE_LLM_REFERENCE", - "https://raw.githubusercontent.com/db0/AI-Horde-text-model-reference/main/db.json", - ), - timeout=2, - ).json() - # logger.debug(self.reference) - self.text_model_names = set() - for model in self.text_reference: - self.text_model_names.add(model) - if self.text_reference[model].get("nsfw"): - self.nsfw_models.add(model) + self._load_text_models() break - except Exception as err: - logger.error(f"Error when downloading known models list: {err}") + except Exception as e: + logger.error(f"Error when downloading text models list: {e}") + + def _load_text_models(self): + text_ref_data = requests.get(self._llm_ref_url, timeout=2).json() + self.text_reference = {name: KnownTextModelRef(text_ref_data[name]) for name in text_ref_data} + self.text_model_names = set() + for model in self.text_reference: + self.text_model_names.add(model) + if self.text_reference[model].get("nsfw"): + self.nsfw_models.add(model) + + def _load_image_models(self): + sd_ref_data = requests.get(self._compvis_ref_url, timeout=2).json() + diffuser_ref_data = requests.get(self._diffusers_ref_url, timeout=2).json() + self.reference = {name: KnownImageModelRef(sd_ref_data[name]) for name in sd_ref_data} + self.reference.update({name: KnownImageModelRef(diffuser_ref_data[name]) for name in diffuser_ref_data}) + self.stable_diffusion_names = set() + for model in self.reference: + if self.reference[model].get("baseline") in SD_BASELINES: + self.stable_diffusion_names.add(model) + if self.reference[model].get("nsfw"): + self.nsfw_models.add(model) + if self.reference[model].get("type") == "controlnet": + self.controlnet_models.add(model) def get_image_model_names(self): return set(self.reference.keys()) @@ -124,7 +137,7 @@ def get_text_model_multiplier(self, model_name): if not self.text_reference.get(model_name): return 1 multiplier = int(self.text_reference[model_name]["parameters"]) / 1000000000 - logger.debug(f"{model_name} param multiplier: 
{multiplier}") + # logger.debug(f"{model_name} param multiplier: {multiplier}") return multiplier def has_inpainting_models(self, model_names): @@ -169,6 +182,26 @@ def has_nsfw_models(self, model_names): # return True return False + @property + def _compvis_ref_url(self): + ref_json = DEFAULT_HORDE_IMAGE_COMPVIS_REFERENCE + if datetime.now(timezone.utc) <= datetime(2024, 9, 30, tzinfo=timezone.utc): + # Flux Beta + # I don't understand how this hack works, but perhaps HORDE_IMAGE_COMPVIS_REFERENCE is unset in prod + ref_json = "https://raw.githubusercontent.com/Haidra-Org/AI-Horde-image-model-reference/refs/heads/flux/stable_diffusion.json" + logger.debug("Using flux beta model reference...") + return os.getenv("HORDE_IMAGE_COMPVIS_REFERENCE", ref_json) + + @property + def _llm_ref_url(self): + # it may not be necessary to constantly pull this from the environment + # but the original code does that so I'm keeping it + return os.getenv("HORDE_IMAGE_LLM_REFERENCE", DEFAULT_HORDE_IMAGE_LLM_REFERENCE) + + @property + def _diffusers_ref_url(self): + return os.getenv("HORDE_IMAGE_DIFFUSERS_REFERENCE", DEFAULT_HORDE_IMAGE_DIFFUSERS_REFERENCE) + model_reference = ModelReference(3600, None) model_reference.call_function()