Skip to content

Commit

Permalink
Make custom VALIDATE_INPUTS skip normal validation
Browse files Browse the repository at this point in the history
Additionally, if `VALIDATE_INPUTS` takes an argument named `input_types`,
that variable will be a dictionary of the socket type of all incoming
connections. If that argument exists, normal socket type validation will
not occur. This removes the last hurdle for enabling variant types
entirely from custom nodes, so I've removed that command-line option.

I've added appropriate unit tests for these changes.
  • Loading branch information
guill committed Feb 25, 2024
1 parent 5ab1565 commit 6d09dd7
Show file tree
Hide file tree
Showing 8 changed files with 305 additions and 39 deletions.
1 change: 0 additions & 1 deletion comfy/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ class LatentPreviewMethod(enum.Enum):
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")

parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
parser.add_argument("--enable-variants", action="store_true", help="Enables '*' type nodes.")

if comfy.options.args_parsing:
args = parser.parse_args()
Expand Down
61 changes: 33 additions & 28 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynpro
cached_output = outputs.get(input_unique_id)
if cached_output is None:
continue
if output_index >= len(cached_output):
continue
obj = cached_output[output_index]
input_data_all[x] = obj
elif input_category is not None:
Expand Down Expand Up @@ -514,6 +516,7 @@ def validate_inputs(prompt, item, validated):
validate_function_inputs = []
if hasattr(obj_class, "VALIDATE_INPUTS"):
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
received_types = {}

for x in valid_inputs:
type_input, input_category, extra_info = get_input_info(obj_class, x)
Expand Down Expand Up @@ -551,9 +554,9 @@ def validate_inputs(prompt, item, validated):
o_id = val[0]
o_class_type = prompt[o_id]['class_type']
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
is_variant = args.enable_variants and (r[val[1]] == "*" or type_input == "*")
if r[val[1]] != type_input and not is_variant:
received_type = r[val[1]]
received_type = r[val[1]]
received_types[x] = received_type
if 'input_types' not in validate_function_inputs and received_type != type_input:
details = f"{x}, {received_type} != {type_input}"
error = {
"type": "return_type_mismatch",
Expand Down Expand Up @@ -622,34 +625,34 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue

if "min" in extra_info and val < extra_info["min"]:
error = {
"type": "value_smaller_than_min",
"message": "Value {} smaller than min of {}".format(val, extra_info["min"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
if x not in validate_function_inputs:
if "min" in extra_info and val < extra_info["min"]:
error = {
"type": "value_smaller_than_min",
"message": "Value {} smaller than min of {}".format(val, extra_info["min"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
}
errors.append(error)
continue
if "max" in extra_info and val > extra_info["max"]:
error = {
"type": "value_bigger_than_max",
"message": "Value {} bigger than max of {}".format(val, extra_info["max"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
errors.append(error)
continue
if "max" in extra_info and val > extra_info["max"]:
error = {
"type": "value_bigger_than_max",
"message": "Value {} bigger than max of {}".format(val, extra_info["max"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
}
errors.append(error)
continue
errors.append(error)
continue

if x not in validate_function_inputs:
if isinstance(type_input, list):
if val not in type_input:
input_config = info
Expand Down Expand Up @@ -682,6 +685,8 @@ def validate_inputs(prompt, item, validated):
for x in input_data_all:
if x in validate_function_inputs:
input_filtered[x] = input_data_all[x]
if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types]

#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
Expand Down
63 changes: 62 additions & 1 deletion tests/inference/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import uuid
import urllib.request
import urllib.parse
import urllib.error
from comfy.graph_utils import GraphBuilder, Node

class RunResult:
Expand Down Expand Up @@ -125,7 +126,6 @@ def _server(self, args_pytest):
'--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]),
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
'--enable-variants',
])
yield
p.kill()
Expand Down Expand Up @@ -237,6 +237,67 @@ def test_error(self, client: ComfyClient, builder: GraphBuilder):
except Exception as e:
assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"

@pytest.mark.parametrize("test_value, expect_error", [
(5, True),
("foo", True),
(5.0, False),
])
def test_validation_error_literal(self, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
g = builder
validation1 = g.node("TestCustomValidation1", input1=test_value, input2=3.0)
g.node("SaveImage", images=validation1.out(0))

if expect_error:
with pytest.raises(urllib.error.HTTPError):
client.run(g)
else:
client.run(g)

@pytest.mark.parametrize("test_type, test_value", [
("StubInt", 5),
("StubFloat", 5.0)
])
def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder):
g = builder
stub = g.node(test_type, value=test_value)
validation1 = g.node("TestCustomValidation1", input1=stub.out(0), input2=3.0)
g.node("SaveImage", images=validation1.out(0))

with pytest.raises(urllib.error.HTTPError):
client.run(g)

@pytest.mark.parametrize("test_type, test_value, expect_error", [
("StubInt", 5, True),
("StubFloat", 5.0, False)
])
def test_validation_error_edge2(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
g = builder
stub = g.node(test_type, value=test_value)
validation2 = g.node("TestCustomValidation2", input1=stub.out(0), input2=3.0)
g.node("SaveImage", images=validation2.out(0))

if expect_error:
with pytest.raises(urllib.error.HTTPError):
client.run(g)
else:
client.run(g)

@pytest.mark.parametrize("test_type, test_value, expect_error", [
("StubInt", 5, True),
("StubFloat", 5.0, False)
])
def test_validation_error_edge3(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
g = builder
stub = g.node(test_type, value=test_value)
validation3 = g.node("TestCustomValidation3", input1=stub.out(0), input2=3.0)
g.node("SaveImage", images=validation3.out(0))

if expect_error:
with pytest.raises(urllib.error.HTTPError):
client.run(g)
else:
client.run(g)

def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
g = builder
# Creating the nodes in this specific order previously caused a bug
Expand Down
4 changes: 4 additions & 0 deletions tests/inference/testing_nodes/testing-pack/flow_control.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from comfy.graph_utils import GraphBuilder, is_link
from comfy.graph import ExecutionBlocker
from .tools import VariantSupport

NUM_FLOW_SOCKETS = 5
@VariantSupport()
class TestWhileLoopOpen:
def __init__(self):
pass
Expand Down Expand Up @@ -31,6 +33,7 @@ def while_loop_open(self, condition, **kwargs):
values.append(kwargs.get("initial_value%d" % i, None))
return tuple(["stub"] + values)

@VariantSupport()
class TestWhileLoopClose:
def __init__(self):
pass
Expand Down Expand Up @@ -131,6 +134,7 @@ def while_loop_close(self, flow_control, condition, dynprompt=None, unique_id=No
"expand": graph.finalize(),
}

@VariantSupport()
class TestExecutionBlockerNode:
def __init__(self):
pass
Expand Down
112 changes: 103 additions & 9 deletions tests/inference/testing_nodes/testing-pack/specific_tests.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import torch
from .tools import VariantSupport

class TestLazyMixImages:
def __init__(self):
pass

@classmethod
def INPUT_TYPES(cls):
return {
Expand Down Expand Up @@ -50,9 +48,6 @@ def mix(self, mask, image1 = None, image2 = None):
return (result[0],)

class TestVariadicAverage:
def __init__(self):
pass

@classmethod
def INPUT_TYPES(cls):
return {
Expand All @@ -74,9 +69,6 @@ def variadic_average(self, input1, **kwargs):


class TestCustomIsChanged:
def __init__(self):
pass

@classmethod
def INPUT_TYPES(cls):
return {
Expand All @@ -103,14 +95,116 @@ def IS_CHANGED(cls, should_change=False, *args, **kwargs):
else:
return False

class TestCustomValidation1:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": ("IMAGE,FLOAT",),
"input2": ("IMAGE,FLOAT",),
},
}

RETURN_TYPES = ("IMAGE",)
FUNCTION = "custom_validation1"

CATEGORY = "Testing/Nodes"

def custom_validation1(self, input1, input2):
if isinstance(input1, float) and isinstance(input2, float):
result = torch.ones([1, 512, 512, 3]) * input1 * input2
else:
result = input1 * input2
return (result,)

@classmethod
def VALIDATE_INPUTS(cls, input1=None, input2=None):
if input1 is not None:
if not isinstance(input1, (torch.Tensor, float)):
return f"Invalid type of input1: {type(input1)}"
if input2 is not None:
if not isinstance(input2, (torch.Tensor, float)):
return f"Invalid type of input2: {type(input2)}"

return True

class TestCustomValidation2:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": ("IMAGE,FLOAT",),
"input2": ("IMAGE,FLOAT",),
},
}

RETURN_TYPES = ("IMAGE",)
FUNCTION = "custom_validation2"

CATEGORY = "Testing/Nodes"

def custom_validation2(self, input1, input2):
if isinstance(input1, float) and isinstance(input2, float):
result = torch.ones([1, 512, 512, 3]) * input1 * input2
else:
result = input1 * input2
return (result,)

@classmethod
def VALIDATE_INPUTS(cls, input_types, input1=None, input2=None):
if input1 is not None:
if not isinstance(input1, (torch.Tensor, float)):
return f"Invalid type of input1: {type(input1)}"
if input2 is not None:
if not isinstance(input2, (torch.Tensor, float)):
return f"Invalid type of input2: {type(input2)}"

if 'input1' in input_types:
if input_types['input1'] not in ["IMAGE", "FLOAT"]:
return f"Invalid type of input1: {input_types['input1']}"
if 'input2' in input_types:
if input_types['input2'] not in ["IMAGE", "FLOAT"]:
return f"Invalid type of input2: {input_types['input2']}"

return True

@VariantSupport()
class TestCustomValidation3:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": ("IMAGE,FLOAT",),
"input2": ("IMAGE,FLOAT",),
},
}

RETURN_TYPES = ("IMAGE",)
FUNCTION = "custom_validation3"

CATEGORY = "Testing/Nodes"

def custom_validation3(self, input1, input2):
if isinstance(input1, float) and isinstance(input2, float):
result = torch.ones([1, 512, 512, 3]) * input1 * input2
else:
result = input1 * input2
return (result,)

TEST_NODE_CLASS_MAPPINGS = {
"TestLazyMixImages": TestLazyMixImages,
"TestVariadicAverage": TestVariadicAverage,
"TestCustomIsChanged": TestCustomIsChanged,
"TestCustomValidation1": TestCustomValidation1,
"TestCustomValidation2": TestCustomValidation2,
"TestCustomValidation3": TestCustomValidation3,
}

TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestLazyMixImages": "Lazy Mix Images",
"TestVariadicAverage": "Variadic Average",
"TestCustomIsChanged": "Custom IsChanged",
"TestCustomValidation1": "Custom Validation 1",
"TestCustomValidation2": "Custom Validation 2",
"TestCustomValidation3": "Custom Validation 3",
}
Loading

0 comments on commit 6d09dd7

Please sign in to comment.