FIX recursive_will_execute performance (simple ~300x performance increase} #2852

Merged on Feb 22, 2024 (4 commits)

ricklove
Contributor

This solves a major performance problem that makes large graphs impossible to use in ComfyUI. It also speeds up even medium size graphs that have a long chain of dependent nodes.

Performance improvement example:

  • sorting nodes len(to_execute): 40
  • recursive_will_execute Execution time: 46.01077870000154 secs
  • recursive_will_execute_len Execution time: 0.1593280000379309 secs
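
To illustrate why memoization makes such a difference, here is a minimal sketch (not the PR's code) that counts recursive visits on a hypothetical ladder-shaped graph in which every node reads two outputs of the previous node. The helpers `make_ladder` and `count_calls` are assumptions made up for this example; only the prompt layout (`inputs` mapping a name to `[node_id, output_index]`) follows the functions in this PR.

```python
def make_ladder(depth):
    """Hypothetical graph: node i reads two outputs of node i-1."""
    prompt = {"0": {"inputs": {}}}
    for i in range(1, depth + 1):
        prev = str(i - 1)
        prompt[str(i)] = {"inputs": {"a": [prev, 0], "b": [prev, 1]}}
    return prompt


def count_calls(prompt, outputs, node_id, memo=None):
    """Count recursive visits; pass a dict as memo to enable memoization."""
    calls = 0

    def visit(unique_id):
        nonlocal calls
        calls += 1
        if memo is not None and unique_id in memo:
            return memo[unique_id]
        if unique_id in outputs:
            return 0
        total = 0
        for input_data in prompt[unique_id]["inputs"].values():
            if isinstance(input_data, list) and input_data[0] not in outputs:
                total += visit(input_data[0])
        total += 1
        if memo is not None:
            memo[unique_id] = total
        return total

    visit(node_id)
    return calls


prompt = make_ladder(15)
naive_calls = count_calls(prompt, {}, "15")          # every path is re-walked
memo_calls = count_calls(prompt, {}, "15", memo={})  # each node is expanded once
print(naive_calls, memo_calls)  # prints: 65535 31
```

Without a memo, the number of visits doubles with every extra level of the ladder; with a memo it grows linearly with the depth.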

Measurement code (in execution.py):

                #always execute the output that depends on the least amount of unexecuted nodes first
                print(f'sorting nodes len(to_execute):{len(to_execute)}')

                start_time = time.perf_counter()
                to_execute_old = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
                end_time = time.perf_counter()
                execution_time = end_time - start_time
                print(f'recursive_will_execute Execution time: {execution_time} secs')

                start_time = time.perf_counter()
                memo = {}
                to_execute_old = sorted(list(map(lambda a: (len(recursive_will_execute_memo(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute)))
                end_time = time.perf_counter()
                execution_time = end_time - start_time
                print(f'recursive_will_execute_memo Execution time: {execution_time} secs')

                start_time = time.perf_counter()
                memo = {}
                to_execute = sorted(list(map(lambda a: ((recursive_will_execute_len(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute)))
                end_time = time.perf_counter()
                execution_time = end_time - start_time
                print(f'recursive_will_execute_len Execution time: {execution_time} secs')

functions:


def recursive_will_execute(prompt, outputs, current_item):
    unique_id = current_item
    # print(f'recursive_will_execute {unique_id}')

    inputs = prompt[unique_id]['inputs']
    will_execute = []
    if unique_id in outputs:
        return []

    for x in inputs:
        input_data = inputs[x]
        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            output_index = input_data[1]
            if input_unique_id not in outputs:
                will_execute += recursive_will_execute(prompt, outputs, input_unique_id)

    return will_execute + [unique_id]

def recursive_will_execute_memo(prompt, outputs, current_item, memo):
    unique_id = current_item
    # print(f'recursive_will_execute_memo {unique_id}')

    if unique_id in memo:
        return memo[unique_id]

    inputs = prompt[unique_id]['inputs']
    will_execute = []
    if unique_id in outputs:
        return []

    for x in inputs:
        input_data = inputs[x]
        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            output_index = input_data[1]
            if input_unique_id not in outputs:
                will_execute += recursive_will_execute_memo(prompt, outputs, input_unique_id, memo)

    memo[unique_id] = will_execute + [unique_id]
    return memo[unique_id]

def recursive_will_execute_len(prompt, outputs, current_item, memo):
    unique_id = current_item
    # print(f'recursive_will_execute_len {unique_id}')
    
    if unique_id in memo:
        return memo[unique_id]

    inputs = prompt[unique_id]['inputs']
    will_execute = 0
    if unique_id in outputs:
        return 0

    for x in inputs:
        input_data = inputs[x]
        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            output_index = input_data[1]
            if input_unique_id not in outputs:
                will_execute += recursive_will_execute_len(prompt, outputs, input_unique_id, memo)

    memo[unique_id] = will_execute + 1
    return memo[unique_id]
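
As a quick sanity check, the sketch below compares the list-building variant with the counting variant on a small hypothetical diamond graph. Both functions are compact restatements of the ones above (the names `will_execute_list` and `will_execute_len` and the sample graph are assumptions for this example). Note that a shared ancestor is counted once per path in both variants, so the sort keys stay the same.

```python
def will_execute_list(prompt, outputs, unique_id):
    # Compact restatement of the original list-building variant.
    if unique_id in outputs:
        return []
    will_execute = []
    for input_data in prompt[unique_id]["inputs"].values():
        if isinstance(input_data, list) and input_data[0] not in outputs:
            will_execute += will_execute_list(prompt, outputs, input_data[0])
    return will_execute + [unique_id]


def will_execute_len(prompt, outputs, unique_id, memo):
    # Compact restatement of the memoized counting variant.
    if unique_id in memo:
        return memo[unique_id]
    if unique_id in outputs:
        return 0
    count = 0
    for input_data in prompt[unique_id]["inputs"].values():
        if isinstance(input_data, list) and input_data[0] not in outputs:
            count += will_execute_len(prompt, outputs, input_data[0], memo)
    memo[unique_id] = count + 1
    return memo[unique_id]


# Hypothetical diamond: "4" depends on "2" and "3", which both depend on "1".
prompt = {
    "1": {"inputs": {"seed": 42}},
    "2": {"inputs": {"x": ["1", 0]}},
    "3": {"inputs": {"x": ["1", 0]}},
    "4": {"inputs": {"a": ["2", 0], "b": ["3", 0]}},
}
outputs = {}

as_list = will_execute_list(prompt, outputs, "4")
as_len = will_execute_len(prompt, outputs, "4", {})
print(as_list, as_len)  # prints: ['1', '2', '1', '3', '4'] 5
```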

@ricklove
Contributor Author

#2666 related

@ricklove
Contributor Author

ricklove commented Feb 20, 2024

Here is a simpler version for quick patching after an update:

def recursive_will_execute(prompt, outputs, current_item, memo={}):
    unique_id = current_item

    if unique_id in memo:
        return memo[unique_id]

    inputs = prompt[unique_id]['inputs']
    will_execute = []
    if unique_id in outputs:
        return []

    for x in inputs:
        input_data = inputs[x]
        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            output_index = input_data[1]
            if input_unique_id not in outputs:
                will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo)

    memo[unique_id] = will_execute + [unique_id]
    return memo[unique_id]

Edit: for best performance, this also needs a memo = {} created at the call site. This PR now uses this version.
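
One caveat about the `memo={}` default: in Python, a default argument value is evaluated once, when the function is defined, so calls that omit `memo` all share one dictionary. The sketch below reproduces this with a compact copy of the function and a hypothetical three-node chain, then shows a `memo=None` variant. That variant is an assumption for illustration, not what this PR merged; it keeps the same call signature.

```python
def will_execute_shared(prompt, outputs, unique_id, memo={}):
    # Compact copy of the quick patch: the default dict is shared across calls.
    if unique_id in memo:
        return memo[unique_id]
    if unique_id in outputs:
        return []
    will_execute = []
    for input_data in prompt[unique_id]["inputs"].values():
        if isinstance(input_data, list) and input_data[0] not in outputs:
            will_execute += will_execute_shared(prompt, outputs, input_data[0], memo)
    memo[unique_id] = will_execute + [unique_id]
    return memo[unique_id]


def will_execute_fresh(prompt, outputs, unique_id, memo=None):
    # Hypothetical variant: build a new memo when the caller passes none.
    if memo is None:
        memo = {}
    if unique_id in memo:
        return memo[unique_id]
    if unique_id in outputs:
        return []
    will_execute = []
    for input_data in prompt[unique_id]["inputs"].values():
        if isinstance(input_data, list) and input_data[0] not in outputs:
            will_execute += will_execute_fresh(prompt, outputs, input_data[0], memo)
    memo[unique_id] = will_execute + [unique_id]
    return memo[unique_id]


# Hypothetical chain: "1" -> "2" -> "3".
prompt = {
    "1": {"inputs": {}},
    "2": {"inputs": {"x": ["1", 0]}},
    "3": {"inputs": {"x": ["2", 0]}},
}

first = will_execute_shared(prompt, {}, "3")        # fills the shared default memo
executed = {"1": "done", "2": "done"}               # pretend "1" and "2" have run
stale = will_execute_shared(prompt, executed, "3")  # answered from the old memo
fresh = will_execute_fresh(prompt, executed, "3")   # recomputed against new outputs
print(stale, fresh)  # prints: ['1', '2', '3'] ['3']
```

A call site that always passes its own `memo`, as described for this PR, does not run into the shared default.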

@ricklove
Contributor Author

I switched to the version that has minimal code changes so that any external calls to this function will still work as expected

Contributor Author

@ricklove ricklove left a comment

notes

@@ -194,8 +194,12 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute

return (True, None, None)

-def recursive_will_execute(prompt, outputs, current_item):
+def recursive_will_execute(prompt, outputs, current_item, memo={}):
Contributor Author

This should be backwards compatible with any external callers of this function.

@@ -377,7 +382,8 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):

while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first
-to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
+memo = {}
Contributor Author

The memo should be created outside the lambda so that it is reused for the entire sort.
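
A minimal sketch of that call-site pattern, assuming `to_execute` is a list of `(count, node_id)` tuples as in the snippet above. The helper names `will_execute` and `sort_by_pending` and the sample graph are hypothetical, and `will_execute` is only a compact stand-in for `recursive_will_execute`.

```python
def will_execute(prompt, outputs, unique_id, memo):
    # Compact stand-in for recursive_will_execute with an explicit memo.
    if unique_id in memo:
        return memo[unique_id]
    if unique_id in outputs:
        return []
    pending = []
    for input_data in prompt[unique_id]["inputs"].values():
        if isinstance(input_data, list) and input_data[0] not in outputs:
            pending += will_execute(prompt, outputs, input_data[0], memo)
    memo[unique_id] = pending + [unique_id]
    return memo[unique_id]


def sort_by_pending(prompt, outputs, to_execute):
    # One memo per sort: created outside the key computation so that every
    # item in to_execute reuses the results already computed for the others.
    memo = {}
    return sorted(
        (len(will_execute(prompt, outputs, item[-1], memo)), item[-1])
        for item in to_execute
    )


# Hypothetical graph: "3" depends on "2" -> "1"; "4" depends on "1" only.
prompt = {
    "1": {"inputs": {}},
    "2": {"inputs": {"x": ["1", 0]}},
    "3": {"inputs": {"x": ["2", 0]}},
    "4": {"inputs": {"x": ["1", 0]}},
}

ordered = sort_by_pending(prompt, {}, [(0, "3"), (0, "4")])
print(ordered)  # prints: [(2, '4'), (3, '3')]
```

Creating the dictionary inside the lambda instead would give each item an empty memo, so nothing would be shared between items of the same sort.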

@guill
Contributor

guill commented Feb 21, 2024

@ricklove Could you share the workflow and/or code you've used to perf test this? I believe I addressed the combinatorial explosion in #2666 as well (in a more invasive way which likely makes this PR worth merging in the meantime), but I'd like to make sure I'm comparing apples to apples.

@ricklove
Contributor Author

ricklove commented Feb 21, 2024

> @ricklove Could you share the workflow and/or code you've used to perf test this? I believe I addressed the combinatorial explosion in #2666 as well (in a more invasive way which likely makes this PR worth merging in the meantime), but I'd like to make sure I'm comparing apples to apples.

Sure! I'll create a minimal workflow that uses only basic nodes and has this problem.

By the way, I'm excited to see your PR: nested workflows/components are the killer feature (we just have to solve performance).

@ricklove
Contributor Author

This is a workflow that demonstrates the problem. You can also easily change the complexity level by using Ctrl+B to make chunks of the graph pass through.

performance-sort-complexity-01.json

@comfyanonymous comfyanonymous merged commit f81dbe2 into comfyanonymous:master Feb 22, 2024
1 check passed
@rgthree
Contributor

rgthree commented Feb 24, 2024

It's nice to see this [finally] getting some attention.

This PR doesn't quite solve the entire issue here, and still suffers from millions of wasted cycles and hundreds of seconds, especially on re-execution, for complex workflows. The PR I sent months ago does, though it is a bit more complex; perhaps it need not be (#1503 for issue #1502).

Without combing through the differences in the memoization itself, the broad differences are:

  1. Apply the same to recursive_output_delete_if_changed
  2. Move the memo init above the while, since the prompt doesn't change at all.

@ricklove Do you think you could do this (since outside PRs seem to fall on deaf ears)? Tens of thousands of folks have already been using my PR for months, as it's integrated from the outside in rgthree-comfy. I don't know if you intended/checked that it didn't break all those workflows, but it didn't, so thanks for that :)

xingren23 pushed a commit to xingren23/ComfyUI that referenced this pull request Mar 8, 2024
…ease} (comfyanonymous#2852)

* FIX recursive_will_execute performance

* Minimize code changes

* memo must be created outside lambda