diff --git a/comfy/caching.py b/comfy/caching.py index 060d53d5584..abcf68ae452 100644 --- a/comfy/caching.py +++ b/comfy/caching.py @@ -122,13 +122,6 @@ def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_map order_mapping[ancestor_id] = len(ancestors) - 1 self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping) -class CacheKeySetInputSignatureWithID(CacheKeySetInputSignature): - def __init__(self, dynprompt, node_ids, is_changed_cache): - super().__init__(dynprompt, node_ids, is_changed_cache) - - def include_node_id_in_input(self): - return True - class BasicCache: def __init__(self, key_class): self.key_class = key_class @@ -151,10 +144,8 @@ def all_node_ids(self): node_ids = node_ids.union(subcache.all_node_ids()) return node_ids - def clean_unused(self): - assert self.initialized + def _clean_cache(self): preserve_keys = set(self.cache_key_set.get_used_keys()) - preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys()) to_remove = [] for key in self.cache: if key not in preserve_keys: @@ -162,6 +153,9 @@ def clean_unused(self): for key in to_remove: del self.cache[key] + def _clean_subcaches(self): + preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys()) + to_remove = [] for key in self.subcaches: if key not in preserve_subcaches: @@ -169,6 +163,11 @@ def clean_unused(self): for key in to_remove: del self.subcaches[key] + def clean_unused(self): + assert self.initialized + self._clean_cache() + self._clean_subcaches() + def _set_immediate(self, node_id, value): assert self.initialized cache_key = self.cache_key_set.get_data_key(node_id) @@ -246,15 +245,6 @@ def ensure_subcache_for(self, node_id, children_ids): assert cache is not None return cache._ensure_subcache(node_id, children_ids) - def all_active_values(self): - active_nodes = self.all_node_ids() - result = [] - for node_id in active_nodes: - value = self.get(node_id) - if value is not None: - result.append(value) - return result - class LRUCache(BasicCache): def __init__(self, key_class, max_size=100): super().__init__(key_class) @@ -279,6 +269,7 @@ def clean_unused(self): del self.used_generation[key] if key in self.children: del self.children[key] + self._clean_subcaches() def get(self, node_id): self._mark_used(node_id) @@ -294,6 +285,9 @@ def set(self, node_id, value): return self._set_immediate(node_id, value) def ensure_subcache_for(self, node_id, children_ids): + # Just uses subcaches for tracking 'live' nodes + super()._ensure_subcache(node_id, children_ids) + self.cache_key_set.add_keys(children_ids) self._mark_used(node_id) cache_key = self.cache_key_set.get_data_key(node_id) @@ -303,15 +297,3 @@ def ensure_subcache_for(self, node_id, children_ids): self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) return self - def all_active_values(self): - explored = set() - to_explore = set(self.cache_key_set.get_used_keys()) - while len(to_explore) > 0: - cache_key = to_explore.pop() - if cache_key not in explored: - self.used_generation[cache_key] = self.generation - explored.add(cache_key) - if cache_key in self.children: - to_explore.update(self.children[cache_key]) - return [self.cache[key] for key in explored if key in self.cache] - diff --git a/execution.py b/execution.py index 5895e1031cc..b5dd94f5e13 100644 --- a/execution.py +++ b/execution.py @@ -15,7 +15,7 @@ import comfy.graph_utils from comfy.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker from comfy.graph_utils import is_link, GraphBuilder -from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetInputSignatureWithID, CacheKeySetID +from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID from comfy.cli_args import args class ExecutionResult(Enum): @@ -69,13 +69,13 @@ def __init__(self, lru_size=None): # blowing away the cache every time def init_lru_cache(self, cache_size): self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) - self.ui = LRUCache(CacheKeySetInputSignatureWithID, max_size=cache_size) + self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) self.objects = HierarchicalCache(CacheKeySetID) # Performs like the old cache -- dump data ASAP def init_classic_cache(self): self.outputs = HierarchicalCache(CacheKeySetInputSignature) - self.ui = HierarchicalCache(CacheKeySetInputSignatureWithID) + self.ui = HierarchicalCache(CacheKeySetInputSignature) self.objects = HierarchicalCache(CacheKeySetID) def recursive_debug_dump(self): @@ -486,10 +486,12 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): ui_outputs = {} meta_outputs = {} - for ui_info in self.caches.ui.all_active_values(): - node_id = ui_info["meta"]["node_id"] - ui_outputs[node_id] = ui_info["output"] - meta_outputs[node_id] = ui_info["meta"] + all_node_ids = self.caches.ui.all_node_ids() + for node_id in all_node_ids: + ui_info = self.caches.ui.get(node_id) + if ui_info is not None: + ui_outputs[node_id] = ui_info["output"] + meta_outputs[node_id] = ui_info["meta"] self.history_result = { "outputs": ui_outputs, "meta": meta_outputs, diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index 1fb58d5e0fe..e9f93797622 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -117,16 +117,26 @@ class TestExecution: # # Initialize server and client # - @fixture(scope="class", autouse=True) - def _server(self, args_pytest): + @fixture(scope="class", autouse=True, params=[ + # (use_lru, lru_size) + (False, 0), + (True, 0), + (True, 100), + ]) + def _server(self, args_pytest, request): # Start server - p = subprocess.Popen([ - 'python','main.py', - '--output-directory', args_pytest["output_dir"], - '--listen', args_pytest["listen"], - '--port', str(args_pytest["port"]), - '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', - ]) + pargs = [ + 'python','main.py', + '--output-directory', args_pytest["output_dir"], + '--listen', args_pytest["listen"], + '--port', str(args_pytest["port"]), + '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', + ] + use_lru, lru_size = request.param + if use_lru: + pargs += ['--cache-lru', str(lru_size)] + print("Running server with args:", pargs) + p = subprocess.Popen(pargs) yield p.kill() torch.cuda.empty_cache() @@ -159,15 +169,9 @@ def client(self, shared_client, request): shared_client.set_test_name(f"execution[{request.node.name}]") yield shared_client - def clear_cache(self, client: ComfyClient): - g = GraphBuilder(prefix="foo") - random = g.node("StubImage", content="NOISE", height=1, width=1, batch_size=1) - g.node("PreviewImage", images=random.out(0)) - client.run(g) - @fixture - def builder(self): - yield GraphBuilder(prefix="") + def builder(self, request): + yield GraphBuilder(prefix=request.node.name) def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder): g = builder @@ -187,7 +191,6 @@ def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder): assert result.did_run(lazy_mix) def test_full_cache(self, client: ComfyClient, builder: GraphBuilder): - self.clear_cache(client) g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) @@ -196,14 +199,12 @@ def test_full_cache(self, client: ComfyClient, builder: GraphBuilder): lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) g.node("SaveImage", images=lazy_mix.out(0)) - result1 = client.run(g) + client.run(g) result2 = client.run(g) for node_id, node in g.nodes.items(): - assert result1.did_run(node), f"Node {node_id} didn't run" assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder): - self.clear_cache(client) g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) @@ -212,15 +213,11 @@ def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder): lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) g.node("SaveImage", images=lazy_mix.out(0)) - result1 = client.run(g) + client.run(g) mask.inputs['value'] = 0.4 result2 = client.run(g) - for node_id, node in g.nodes.items(): - assert result1.did_run(node), f"Node {node_id} didn't run" assert not result2.did_run(input1), "Input1 should have been cached" assert not result2.did_run(input2), "Input2 should have been cached" - assert result2.did_run(mask), "Mask should have been re-run" - assert result2.did_run(lazy_mix), "Lazy mix should have been re-run" def test_error(self, client: ComfyClient, builder: GraphBuilder): g = builder @@ -365,7 +362,6 @@ def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): assert result4.did_run(is_changed), "is_changed should not have been cached" def test_undeclared_inputs(self, client: ComfyClient, builder: GraphBuilder): - self.clear_cache(client) g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) @@ -378,8 +374,6 @@ def test_undeclared_inputs(self, client: ComfyClient, builder: GraphBuilder): result_image = result.get_images(output)[0] expected = 255 // 4 assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey" - assert result.did_run(input1) - assert result.did_run(input2) def test_for_loop(self, client: ComfyClient, builder: GraphBuilder): g = builder @@ -418,3 +412,17 @@ def test_mixed_expansion_returns(self, client: ComfyClient, builder: GraphBuilde assert len(images_literal) == 3, "Should have 2 images" for i in range(3): assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white" + + def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + output1 = g.node("PreviewImage", images=input1.out(0)) + output2 = g.node("PreviewImage", images=input1.out(0)) + + result = client.run(g) + images1 = result.get_images(output1) + images2 = result.get_images(output2) + assert len(images1) == 1, "Should have 1 image" + assert len(images2) == 1, "Should have 1 image" +