diff --git a/execution.py b/execution.py index 5f5d6c73834..18005376798 100644 --- a/execution.py +++ b/execution.py @@ -384,7 +384,7 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): -def validate_inputs(prompt, item, validated): +def validate_inputs(prompt, item, validated, stack=[]): unique_id = item if unique_id in validated: return validated[unique_id] @@ -399,6 +399,22 @@ def validate_inputs(prompt, item, validated): errors = [] valid = True + if unique_id in stack: + error = { + "type": "infinite_loop", + "message": "loop detected in workflow validation", + "details": f"detected at {unique_id}", + "extra_info": {"stack": f"{stack}"}, + } + errors.append(error) + ret = (False, errors, unique_id) + validated[unique_id] = ret + # don't continue, because we're already here further up the stack + return ret + +# add this node to the stack + stack.append(unique_id) + for x in required_inputs: if x not in inputs: error = { @@ -450,7 +466,7 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue try: - r = validate_inputs(prompt, o_id, validated) + r = validate_inputs(prompt, o_id, validated, stack) if r[0] is False: # `r` will be set in `validated[o_id]` already valid = False @@ -577,13 +593,18 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue +# pop the node back off the stack: + stack.pop() + if len(errors) > 0 or valid is not True: ret = (False, errors, unique_id) else: ret = (True, [], unique_id) - validated[unique_id] = ret - return ret +# if we had a loop, unique_id will have been marked invalid further down the tree + if unique_id not in validated: + validated[unique_id] = ret + return validated[unique_id] def full_type_name(klass): module = klass.__module__ @@ -615,7 +636,7 @@ def validate_prompt(prompt): valid = False reasons = [] try: - m = validate_inputs(prompt, o, validated) + m = validate_inputs(prompt, o, validated, []) valid = m[0] reasons = m[1] except Exception as ex: @@ -659,6 +680,8 @@ def validate_prompt(prompt): print(f"* {class_type} {node_id}:") for reason in reasons: print(f" - {reason['message']}: {reason['details']}") + if 'extra_info' in reason: + print(f" - {reason['extra_info']}") node_errors[node_id]["dependent_outputs"].append(o) print("Output will be ignored")