Skip to content

Commit

Permalink
fix model_version_id_missing (#242)
Browse files Browse the repository at this point in the history
  • Loading branch information
ccssu authored Nov 26, 2024
1 parent 8c06f53 commit 32623e7
Showing 1 changed file with 128 additions and 2 deletions.
130 changes: 128 additions & 2 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import comfy

from bizyair import BizyAirBaseNode, BizyAirNodeIO, create_node_data, data_types
from bizyair.configs.conf import config_manager
from bizyair.path_utils import path_manager as folder_paths

LOGO = "☁️"
Expand Down Expand Up @@ -242,6 +243,85 @@ def decode(self, vae, samples):


class BizyAir_LoraLoader(BizyAirBaseNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": (data_types.MODEL,),
"clip": (data_types.CLIP,),
"lora_name": (
[
"to choose",
],
),
"strength_model": (
"FLOAT",
{"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01},
),
"strength_clip": (
"FLOAT",
{"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01},
),
"model_version_id": (
"STRING",
{
"default": "",
},
),
}
}

RETURN_TYPES = (data_types.MODEL, data_types.CLIP)
RETURN_NAMES = ("MODEL", "CLIP")

FUNCTION = "load_lora"
CATEGORY = f"{PREFIX}/loaders"

def load_lora(
self,
model,
clip,
lora_name,
strength_model,
strength_clip,
model_version_id: str = None,
):
assigned_id = self.assigned_id
new_model: BizyAirNodeIO = model.copy(assigned_id)
new_clip: BizyAirNodeIO = clip.copy(assigned_id)
instances: List[BizyAirNodeIO] = [new_model, new_clip]

if model_version_id is not None and model_version_id != "":
# use model version id as lora name
lora_name = (
f"{config_manager.get_model_version_id_prefix()}{model_version_id}"
)

for slot_index, ins in zip(range(2), instances):
ins.add_node_data(
class_type="LoraLoader",
inputs={
"model": model,
"clip": clip,
"lora_name": lora_name,
"strength_model": strength_model,
"strength_clip": strength_clip,
},
outputs={"slot_index": slot_index},
)
return (
new_model,
new_clip,
)

@classmethod
def VALIDATE_INPUTS(cls, lora_name):
if lora_name == "" or lora_name is None:
return False
return True


class BizyAir_LoraLoader_Legacy(BizyAirBaseNode):
@classmethod
def INPUT_TYPES(s):
return {
Expand Down Expand Up @@ -346,6 +426,52 @@ def encode(self, vae, pixels, mask, grow_mask_by=6):


class BizyAir_ControlNetLoader(BizyAirBaseNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"control_net_name": (
[
"to choose",
],
),
"model_version_id": ("STRING", {"default": "", "multiline": False}),
}
}

RETURN_TYPES = (data_types.CONTROL_NET,)
RETURN_NAMES = ("CONTROL_NET",)
FUNCTION = "load_controlnet"

CATEGORY = f"{PREFIX}/loaders"

@classmethod
def VALIDATE_INPUTS(cls, control_net_name, model_version_id):
if control_net_name == "to choose":
return False
if model_version_id is not None and model_version_id != "":
return True
return True

def load_controlnet(self, control_net_name, model_version_id):
if model_version_id is not None and model_version_id != "":
control_net_name = (
f"{config_manager.get_model_version_id_prefix()}{model_version_id}"
)

node_data = create_node_data(
class_type="ControlNetLoader",
inputs={
"control_net_name": control_net_name,
},
outputs={"slot_index": 0},
)
assigned_id = self.assigned_id
node = BizyAirNodeIO(assigned_id, {assigned_id: node_data})
return (node,)


class BizyAir_ControlNetLoader_Legacy(BizyAirBaseNode):
@classmethod
def INPUT_TYPES(s):
return {
Expand Down Expand Up @@ -789,7 +915,7 @@ def INPUT_TYPES(s):
return ret


class SharedLoraLoader(BizyAir_LoraLoader):
class SharedLoraLoader(BizyAir_LoraLoader_Legacy):
@classmethod
def INPUT_TYPES(s):
return {
Expand Down Expand Up @@ -1010,7 +1136,7 @@ def INPUT_TYPES(s):
CATEGORY = "advanced/conditioning"


class SharedControlNetLoader(BizyAir_ControlNetLoader):
class SharedControlNetLoader(BizyAir_ControlNetLoader_Legacy):
@classmethod
def INPUT_TYPES(s):
ret = super().INPUT_TYPES()
Expand Down

0 comments on commit 32623e7

Please sign in to comment.