From afa4c7b26089c6d3c47db9da5c56908e902dae4a Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sun, 21 Apr 2024 21:55:03 -0700 Subject: [PATCH] Refactor `map_node_over_list` function --- execution.py | 50 ++++++++++++++++---------------------------------- 1 file changed, 16 insertions(+), 34 deletions(-) diff --git a/execution.py b/execution.py index 1e35da8db1a..5895e1031cc 100644 --- a/execution.py +++ b/execution.py @@ -128,59 +128,41 @@ def mark_missing(): def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): # check if node wants the lists - input_is_list = False - if hasattr(obj, "INPUT_IS_LIST"): - input_is_list = obj.INPUT_IS_LIST + input_is_list = getattr(obj, "INPUT_IS_LIST", False) if len(input_data_all) == 0: max_len_input = 0 else: - max_len_input = max([len(x) for x in input_data_all.values()]) + max_len_input = max(len(x) for x in input_data_all.values()) # get a slice of inputs, repeat last input when list isn't long enough def slice_dict(d, i): - d_new = dict() - for k,v in d.items(): - d_new[k] = v[i if len(v) > i else -1] - return d_new + return {k: v[i if len(v) > i else -1] for k, v in d.items()} results = [] - if input_is_list: + def process_inputs(inputs, index=None): if allow_interrupt: nodes.before_node_execution() execution_block = None - for k, v in input_data_all.items(): - for input in v: - if isinstance(v, ExecutionBlocker): - execution_block = execution_block_cb(v) if execution_block_cb is not None else v - break - + for k, v in inputs.items(): + if isinstance(v, ExecutionBlocker): + execution_block = execution_block_cb(v) if execution_block_cb else v + break if execution_block is None: - if pre_execute_cb is not None: - pre_execute_cb(0) - results.append(getattr(obj, func)(**input_data_all)) + if pre_execute_cb is not None and index is not None: + pre_execute_cb(index) + results.append(getattr(obj, func)(**inputs)) else: results.append(execution_block) + + if input_is_list: + process_inputs(input_data_all, 0) elif max_len_input == 0: - if allow_interrupt: - nodes.before_node_execution() - results.append(getattr(obj, func)()) + process_inputs({}) else: for i in range(max_len_input): - if allow_interrupt: - nodes.before_node_execution() input_dict = slice_dict(input_data_all, i) - execution_block = None - for k, v in input_dict.items(): - if isinstance(v, ExecutionBlocker): - execution_block = execution_block_cb(v) if execution_block_cb is not None else v - break - if execution_block is None: - if pre_execute_cb is not None: - pre_execute_cb(i) - results.append(getattr(obj, func)(**input_dict)) - else: - results.append(execution_block) + process_inputs(input_dict, i) return results def merge_result_data(results, obj):