From 8a92ac2120109d08611148c00703a0305e48cbc6 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Mon, 22 Jan 2024 18:56:43 +0000 Subject: [PATCH 01/64] Ability to hide menu Responsive setting screen Touch events for zooming/context menu --- web/extensions/core/simpleTouchSupport.js | 102 +++++++++++++ web/scripts/ui.js | 34 ++++- web/scripts/ui/settings.js | 12 +- web/style.css | 166 ++++++++++++++-------- 4 files changed, 252 insertions(+), 62 deletions(-) create mode 100644 web/extensions/core/simpleTouchSupport.js diff --git a/web/extensions/core/simpleTouchSupport.js b/web/extensions/core/simpleTouchSupport.js new file mode 100644 index 00000000000..041fc2c4ca9 --- /dev/null +++ b/web/extensions/core/simpleTouchSupport.js @@ -0,0 +1,102 @@ +import { app } from "../../scripts/app.js"; + +let touchZooming; +let touchCount = 0; + +app.registerExtension({ + name: "Comfy.SimpleTouchSupport", + setup() { + let zoomPos; + let touchTime; + let lastTouch; + + function getMultiTouchPos(e) { + return Math.hypot(e.touches[0].clientX - e.touches[1].clientX, e.touches[0].clientY - e.touches[1].clientY); + } + + app.canvasEl.addEventListener( + "touchstart", + (e) => { + touchCount++; + lastTouch = null; + if (e.touches?.length === 1) { + // Store start time for press+hold for context menu + touchTime = new Date(); + lastTouch = e.touches[0]; + } else { + touchTime = null; + if (e.touches?.length === 2) { + // Store center pos for zoom + zoomPos = getMultiTouchPos(e); + app.canvas.pointer_is_down = false; + } + } + }, + true + ); + + app.canvasEl.addEventListener("touchend", (e) => { + touchZooming = false; + touchCount = e.touches?.length ?? touchCount - 1; + if (touchTime && !e.touches?.length) { + if (new Date() - touchTime > 600) { + try { + // hack to get litegraph to use this event + e.constructor = CustomEvent; + } catch (error) {} + e.clientX = lastTouch.clientX; + e.clientY = lastTouch.clientY; + + app.canvas.pointer_is_down = true; + app.canvas._mousedown_callback(e); + } + touchTime = null; + } + }); + + app.canvasEl.addEventListener( + "touchmove", + (e) => { + touchTime = null; + if (e.touches?.length === 2) { + app.canvas.pointer_is_down = false; + touchZooming = true; + LiteGraph.closeAllContextMenus(); + app.canvas.search_box?.close(); + const newZoomPos = getMultiTouchPos(e); + + const midX = (e.touches[0].clientX + e.touches[1].clientX) / 2; + const midY = (e.touches[0].clientY + e.touches[1].clientY) / 2; + + let scale = app.canvas.ds.scale; + const diff = zoomPos - newZoomPos; + if (diff > 0.5) { + scale *= 1 / 1.07; + } else if (diff < -0.5) { + scale *= 1.07; + } + app.canvas.ds.changeScale(scale, [midX, midY]); + app.canvas.setDirty(true, true); + zoomPos = newZoomPos; + } + }, + true + ); + }, +}); + +const processMouseDown = LGraphCanvas.prototype.processMouseDown; +LGraphCanvas.prototype.processMouseDown = function (e) { + if (touchZooming || touchCount) { + return; + } + return processMouseDown.apply(this, arguments); +}; + +const processMouseMove = LGraphCanvas.prototype.processMouseMove; +LGraphCanvas.prototype.processMouseMove = function (e) { + if (touchZooming || touchCount > 1) { + return; + } + return processMouseMove.apply(this, arguments); +}; diff --git a/web/scripts/ui.js b/web/scripts/ui.js index d4835c6e445..5ca6214ebca 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -394,18 +394,42 @@ export class ComfyUI { } }); - this.menuContainer = $el("div.comfy-menu", {parent: document.body}, [ - 
$el("div.drag-handle", { + this.menuHamburger = $el( + "div.comfy-menu-hamburger", + { + parent: document.body, + onclick: () => { + this.menuContainer.style.display = "block"; + this.menuHamburger.style.display = "none"; + }, + }, + [$el("div"), $el("div"), $el("div")] + ); + + this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [ + $el("div.drag-handle.comfy-menu-header", { style: { overflow: "hidden", position: "relative", width: "100%", cursor: "default" } - }, [ + }, [ $el("span.drag-handle"), - $el("span", {$: (q) => (this.queueSize = q)}), - $el("button.comfy-settings-btn", {textContent: "⚙️", onclick: () => this.settings.show()}), + $el("span.comfy-menu-queue-size", { $: (q) => (this.queueSize = q) }), + $el("div.comfy-menu-actions", [ + $el("button.comfy-settings-btn", { + textContent: "⚙️", + onclick: () => this.settings.show(), + }), + $el("button.comfy-close-menu-btn", { + textContent: "\u00d7", + onclick: () => { + this.menuContainer.style.display = "none"; + this.menuHamburger.style.display = "flex"; + }, + }), + ]), ]), $el("button.comfy-queue-btn", { id: "queue-button", diff --git a/web/scripts/ui/settings.js b/web/scripts/ui/settings.js index 1cdba5cfe40..9e9d13af00b 100644 --- a/web/scripts/ui/settings.js +++ b/web/scripts/ui/settings.js @@ -16,7 +16,17 @@ export class ComfySettingsDialog extends ComfyDialog { }, [ $el("table.comfy-modal-content.comfy-table", [ - $el("caption", { textContent: "Settings" }), + $el( + "caption", + { textContent: "Settings" }, + $el("button.comfy-btn", { + type: "button", + textContent: "\u00d7", + onclick: () => { + this.element.close(); + }, + }) + ), $el("tbody", { $: (tbody) => (this.textElement = tbody) }), $el("button", { type: "button", diff --git a/web/style.css b/web/style.css index 44ee6019885..863840b2866 100644 --- a/web/style.css +++ b/web/style.css @@ -82,6 +82,24 @@ body { margin: 3px 3px 3px 4px; } +.comfy-menu-hamburger { + position: fixed; + top: 10px; + z-index: 9999; + right: 10px; + width: 30px; + display: none; + gap: 8px; + flex-direction: column; + cursor: pointer; +} +.comfy-menu-hamburger div { + height: 3px; + width: 100%; + border-radius: 20px; + background-color: white; +} + .comfy-menu { font-size: 15px; position: absolute; @@ -101,6 +119,44 @@ body { box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4); } +.comfy-menu-header { + display: flex; +} + +.comfy-menu-actions { + display: flex; + gap: 3px; + align-items: center; + height: 20px; + position: relative; + top: -1px; + font-size: 22px; +} + +.comfy-menu .comfy-menu-actions button { + background-color: rgba(0, 0, 0, 0); + padding: 0; + border: none; + cursor: pointer; + font-size: inherit; +} + +.comfy-menu .comfy-menu-actions .comfy-settings-btn { + font-size: 0.6em; +} + +button.comfy-close-menu-btn { + font-size: 1em; + line-height: 12px; + color: #ccc; + position: relative; + top: -1px; +} + +.comfy-menu-queue-size { + flex: auto; +} + .comfy-menu button, .comfy-modal button { font-size: 20px; @@ -121,7 +177,6 @@ body { width: 100%; } -.comfy-toggle-switch, .comfy-btn, .comfy-menu > button, .comfy-menu-btns button, @@ -140,17 +195,11 @@ body { .comfy-menu-btns button:hover, .comfy-menu .comfy-list button:hover, .comfy-modal button:hover, -.comfy-settings-btn:hover { +.comfy-menu-actions button:hover { filter: brightness(1.2); cursor: pointer; } -.comfy-menu span.drag-handle { - position: absolute; - top: 0; - left: 0; -} - span.drag-handle { width: 10px; height: 20px; @@ -215,15 +264,6 @@ span.drag-handle::after { font-size: 12px; } 
-button.comfy-settings-btn { - background-color: rgba(0, 0, 0, 0); - font-size: 12px; - padding: 0; - position: absolute; - right: 0; - border: none; -} - button.comfy-queue-btn { margin: 6px 0 !important; } @@ -269,7 +309,19 @@ button.comfy-queue-btn { } .comfy-menu span.drag-handle { - visibility: hidden + display: none; + } + + .comfy-menu-queue-size { + flex: unset; + } + + .comfy-menu-header { + justify-content: space-between; + } + .comfy-menu-actions { + gap: 10px; + font-size: 28px; } } @@ -320,7 +372,7 @@ dialog::backdrop { text-align: right; } -#comfy-settings-dialog button { +#comfy-settings-dialog tbody button, #comfy-settings-dialog table > button { background-color: var(--bg-color); border: 1px var(--border-color) solid; border-radius: 0; @@ -343,12 +395,33 @@ dialog::backdrop { } .comfy-table caption { + position: sticky; + top: 0; background-color: var(--bg-color); color: var(--input-text); font-size: 1rem; font-weight: bold; padding: 8px; text-align: center; + border-bottom: 1px solid var(--border-color); +} + +.comfy-table caption .comfy-btn { + position: absolute; + top: -2px; + right: 0; + bottom: 0; + cursor: pointer; + border: none; + height: 100%; + border-radius: 0; + aspect-ratio: 1/1; + user-select: none; + font-size: 20px; +} + +.comfy-table caption .comfy-btn:focus { + outline: none; } .comfy-table tr:nth-child(even) { @@ -435,43 +508,6 @@ dialog::backdrop { margin-left: 5px; } -.comfy-toggle-switch { - border-width: 2px; - display: flex; - background-color: var(--comfy-input-bg); - margin: 2px 0; - white-space: nowrap; -} - -.comfy-toggle-switch label { - padding: 2px 0px 3px 6px; - flex: auto; - border-radius: 8px; - align-items: center; - display: flex; - justify-content: center; -} - -.comfy-toggle-switch label:first-child { - border-top-left-radius: 8px; - border-bottom-left-radius: 8px; -} -.comfy-toggle-switch label:last-child { - border-top-right-radius: 8px; - border-bottom-right-radius: 8px; -} - -.comfy-toggle-switch .comfy-toggle-selected { - background-color: var(--comfy-menu-bg); -} - -#extraOptions { - padding: 4px; - background-color: var(--bg-color); - margin-bottom: 4px; - border-radius: 4px; -} - /* Search box */ .litegraph.litesearchbox { @@ -498,3 +534,21 @@ dialog::backdrop { color: var(--input-text); filter: brightness(50%); } + +@media only screen and (max-width: 450px) { + #comfy-settings-dialog .comfy-table tbody { + display: grid; + } + #comfy-settings-dialog .comfy-table tr { + display: grid; + } + #comfy-settings-dialog tr > td:first-child { + text-align: center; + border-bottom: none; + padding-bottom: 0; + } + #comfy-settings-dialog tr > td:not(:first-child) { + text-align: center; + border-top: none; + } +} \ No newline at end of file From 9321198da6c191fab0ab3ea39e626c9e3d9053d3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 29 Jan 2024 00:24:53 -0500 Subject: [PATCH 02/64] Add node to set only the conditioning area strength. 
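The node copies each conditioning entry and overrides only its 'strength'
key, leaving any area settings untouched. Rough usage sketch (Python, using
the names defined in the diff below; conditioning is the usual list of
[cond_tensor, options_dict] pairs):

    node = ConditioningSetAreaStrength()
    (c,) = node.append(conditioning, strength=0.4)
    # each entry keeps its cond tensor; its options dict now has 'strength': 0.4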
--- nodes.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/nodes.py b/nodes.py index 4ad35f79b5a..fe38be9dfef 100644 --- a/nodes.py +++ b/nodes.py @@ -184,6 +184,26 @@ def append(self, conditioning, width, height, x, y, strength): c.append(n) return (c, ) +class ConditioningSetAreaStrength: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning": ("CONDITIONING", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "append" + + CATEGORY = "conditioning" + + def append(self, conditioning, strength): + c = [] + for t in conditioning: + n = [t[0], t[1].copy()] + n[1]['strength'] = strength + c.append(n) + return (c, ) + + class ConditioningSetMask: @classmethod def INPUT_TYPES(s): @@ -1754,6 +1774,7 @@ def expand_image(self, image, left, top, right, bottom, feathering): "ConditioningConcat": ConditioningConcat, "ConditioningSetArea": ConditioningSetArea, "ConditioningSetAreaPercentage": ConditioningSetAreaPercentage, + "ConditioningSetAreaStrength": ConditioningSetAreaStrength, "ConditioningSetMask": ConditioningSetMask, "KSamplerAdvanced": KSamplerAdvanced, "SetLatentNoiseMask": SetLatentNoiseMask, From ed2fa105ae29af6621232dd8ef622ff1e3346b3f Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Mon, 29 Jan 2024 18:43:59 +0000 Subject: [PATCH 03/64] Make auto saved workflow stored per tab --- web/scripts/api.js | 4 +++- web/scripts/app.js | 25 ++++++++++++++++++------- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/web/scripts/api.js b/web/scripts/api.js index 3a9bcc87a4e..8c8155be66c 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -5,6 +5,7 @@ class ComfyApi extends EventTarget { super(); this.api_host = location.host; this.api_base = location.pathname.split('/').slice(0, -1).join('/'); + this.initialClientId = sessionStorage.getItem("clientId"); } apiURL(route) { @@ -118,7 +119,8 @@ class ComfyApi extends EventTarget { case "status": if (msg.data.sid) { this.clientId = msg.data.sid; - window.name = this.clientId; + window.name = this.clientId; // use window name so it isnt reused when duplicating tabs + sessionStorage.setItem("clientId", this.clientId); // store in session storage so duplicate tab can load correct workflow } this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status })); break; diff --git a/web/scripts/app.js b/web/scripts/app.js index 6df393ba60d..b3a84899331 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1499,12 +1499,17 @@ export class ComfyApp { // Load previous workflow let restored = false; try { - const json = localStorage.getItem("workflow"); - if (json) { - const workflow = JSON.parse(json); - await this.loadGraphData(workflow); - restored = true; - } + const loadWorkflow = async (json) => { + if (json) { + const workflow = JSON.parse(json); + await this.loadGraphData(workflow); + return true; + } + }; + const clientId = api.initialClientId ?? 
api.clientId; + restored = + (clientId && (await loadWorkflow(sessionStorage.getItem(`workflow:${clientId}`)))) || + (await loadWorkflow(localStorage.getItem("workflow"))); } catch (err) { console.error("Error loading previous workflow", err); } @@ -1515,7 +1520,13 @@ export class ComfyApp { } // Save current workflow automatically - setInterval(() => localStorage.setItem("workflow", JSON.stringify(this.graph.serialize())), 1000); + setInterval(() => { + const workflow = JSON.stringify(this.graph.serialize()); + localStorage.setItem("workflow", workflow); + if (api.clientId) { + sessionStorage.setItem(`workflow:${api.clientId}`, workflow); + } + }, 1000); this.#addDrawNodeHandler(); this.#addDrawGroupsHandler(); From 364ef19354c70fd8d0b072b4a69cd9f6271155cf Mon Sep 17 00:00:00 2001 From: Meowu <474384902@qq.com> Date: Tue, 30 Jan 2024 14:23:01 +0800 Subject: [PATCH 04/64] fix: inpaint on mask editor bottom area --- web/extensions/core/maskeditor.js | 3 +++ 1 file changed, 3 insertions(+) diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index bb2f16d42b5..f6b035bdce1 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -110,6 +110,7 @@ class MaskEditorDialog extends ComfyDialog { createButton(name, callback) { var button = document.createElement("button"); + button.style.pointerEvents = "auto"; button.innerText = name; button.addEventListener("click", callback); return button; @@ -146,6 +147,7 @@ class MaskEditorDialog extends ComfyDialog { divElement.style.display = "flex"; divElement.style.position = "relative"; divElement.style.top = "2px"; + divElement.style.pointerEvents = "auto"; self.brush_slider_input = document.createElement('input'); self.brush_slider_input.setAttribute('type', 'range'); self.brush_slider_input.setAttribute('min', '1'); @@ -173,6 +175,7 @@ class MaskEditorDialog extends ComfyDialog { bottom_panel.style.left = "20px"; bottom_panel.style.right = "20px"; bottom_panel.style.height = "50px"; + bottom_panel.style.pointerEvents = "none"; var brush = document.createElement("div"); brush.id = "brush"; From da7a8df0d2582c8dc91e5afafe51300899c91392 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 30 Jan 2024 02:24:38 -0500 Subject: [PATCH 05/64] Put VAE key name in model config. 
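Model configs can now declare where their VAE weights live in a checkpoint
through a vae_key_prefix list instead of the hardcoded "first_stage_model."
prefix. A hypothetical config (not part of this diff) whose checkpoint
stores the VAE under "vae." would only need to override the attribute:

    class SomeModel(supported_models_base.BASE):
        # default inherited value stays ["first_stage_model."]
        vae_key_prefix = ["vae."]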
--- comfy/sd.py | 2 +- comfy/supported_models_base.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index 9ca9d1d1209..c15d73fed5e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -462,7 +462,7 @@ class WeightsLoader(torch.nn.Module): model.load_model_weights(sd, "model.diffusion_model.") if output_vae: - vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) + vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) vae_sd = model_config.process_vae_state_dict(vae_sd) vae = VAE(sd=vae_sd) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 5baf4bca6c6..58535a9fbf8 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -21,6 +21,7 @@ class BASE: noise_aug_config = None sampling_settings = {} latent_format = latent_formats.LatentFormat + vae_key_prefix = ["first_stage_model."] manual_cast_dtype = None From 29558fb3acc984979609d671d458128f69ccc1fc Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 30 Jan 2024 17:59:47 +0000 Subject: [PATCH 06/64] Fix crash when no widgets on customized group node --- web/extensions/core/groupNode.js | 3 +++ 1 file changed, 3 insertions(+) diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js index 0f041fcd2f9..0b0763d1d49 100644 --- a/web/extensions/core/groupNode.js +++ b/web/extensions/core/groupNode.js @@ -910,6 +910,9 @@ export class GroupNodeHandler { const self = this; const onNodeCreated = this.node.onNodeCreated; this.node.onNodeCreated = function () { + if (!this.widgets) { + return; + } const config = self.groupData.nodeData.config; if (config) { for (const n in config) { From af6165ab691210188d1792369d8b07a8ed6f2228 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 30 Jan 2024 18:00:01 +0000 Subject: [PATCH 07/64] Fix scrolling with lots of nodes --- web/extensions/core/groupNodeManage.css | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/extensions/core/groupNodeManage.css b/web/extensions/core/groupNodeManage.css index 5ac89aee31b..5470ecb5e67 100644 --- a/web/extensions/core/groupNodeManage.css +++ b/web/extensions/core/groupNodeManage.css @@ -48,7 +48,7 @@ list-style: none; } .comfy-group-manage-list-items { - max-height: 70vh; + max-height: calc(100% - 40px); overflow-y: scroll; overflow-x: hidden; } From 6565c9ad4dcd64238a86e16e5605fed446069952 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 31 Jan 2024 02:26:27 -0500 Subject: [PATCH 08/64] Litegraph node search improvements. 
See: https://github.com/comfyanonymous/litegraph.js/pull/5 --- web/lib/litegraph.core.js | 15 +++++++++++++-- web/lib/litegraph.css | 13 +++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 080e0ef47da..4aae889ef4e 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -11910,7 +11910,7 @@ LGraphNode.prototype.executeAction = function(action) var ctor = LiteGraph.registered_node_types[ type ]; if(filter && ctor.filter != filter ) return false; - if ((!options.show_all_if_empty || str) && type.toLowerCase().indexOf(str) === -1) + if ((!options.show_all_if_empty || str) && type.toLowerCase().indexOf(str) === -1 && (!ctor.title || ctor.title.toLowerCase().indexOf(str) === -1)) return false; // filter by slot IN, OUT types @@ -11964,7 +11964,18 @@ LGraphNode.prototype.executeAction = function(action) if (!first) { first = type; } - help.innerText = type; + + const nodeType = LiteGraph.registered_node_types[type]; + if (nodeType?.title) { + help.innerText = nodeType?.title; + const typeEl = document.createElement("span"); + typeEl.className = "litegraph lite-search-item-type"; + typeEl.textContent = type; + help.append(typeEl); + } else { + help.innerText = type; + } + help.dataset["type"] = escape(type); help.className = "litegraph lite-search-item"; if (className) { diff --git a/web/lib/litegraph.css b/web/lib/litegraph.css index 918858f415d..5524e24bacb 100644 --- a/web/lib/litegraph.css +++ b/web/lib/litegraph.css @@ -184,6 +184,7 @@ color: white; padding-left: 10px; margin-right: 5px; + max-width: 300px; } .litegraph.litesearchbox .name { @@ -227,6 +228,18 @@ color: black; } +.litegraph.lite-search-item-type { + display: inline-block; + background: rgba(0,0,0,0.2); + margin-left: 5px; + font-size: 14px; + padding: 2px 5px; + position: relative; + top: -2px; + opacity: 0.8; + border-radius: 4px; + } + /* DIALOGs ******/ .litegraph .dialog { From c5a369a33ddb622827552716d9b0119035a2e666 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 31 Jan 2024 02:27:12 -0500 Subject: [PATCH 09/64] Update readme for new pytorch 2.2 release. --- README.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 0938739215d..ff3ab64204e 100644 --- a/README.md +++ b/README.md @@ -95,16 +95,15 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints Put your VAE in: models/vae -Note: pytorch stable does not support python 3.12 yet. If you have python 3.12 you will have to use the nightly version of pytorch. If you run into issues you should try python 3.11 instead. 
### AMD GPUs (Linux only) AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: -```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6``` +```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7``` -This is the command to install the nightly with ROCm 5.7 which has a python 3.12 package and might have some performance improvements: +This is the command to install the nightly with ROCm 6.0 which might have some performance improvements: -```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7``` +```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.0``` ### NVIDIA @@ -112,7 +111,7 @@ Nvidia users should install stable pytorch using this command: ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121``` -This is the command to install pytorch nightly instead which has a python 3.12 package and might have performance improvements: +This is the command to install pytorch nightly instead which might have performance improvements: ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121``` From 6ab42054229cfea7fbab63e6b51ae295c2e3fc49 Mon Sep 17 00:00:00 2001 From: "Lt.Dr.Data" Date: Wed, 31 Jan 2024 18:28:36 +0900 Subject: [PATCH 10/64] feat: better pen support for mask editor - alt-drag: erase - shift-drag(up/down): zoom in/out --- web/extensions/core/maskeditor.js | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index f6b035bdce1..cd7d904e18e 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -521,6 +521,19 @@ class MaskEditorDialog extends ComfyDialog { event.preventDefault(); self.pan_move(self, event); } + + let left_button_down = window.TouchEvent && event instanceof TouchEvent || event.buttons == 1; + + if(event.shiftKey && left_button_down) { + self.drawing_mode = false; + + const y = event.clientY; + let delta = (self.zoom_lasty - y)*0.005; + self.zoom_ratio = Math.max(Math.min(10.0, self.last_zoom_ratio - delta), 0.2); + + this.invalidatePanZoom(); + return; + } } pan_move(self, event) { @@ -538,7 +551,7 @@ class MaskEditorDialog extends ComfyDialog { } draw_move(self, event) { - if(event.ctrlKey) { + if(event.ctrlKey || event.shiftKey) { return; } @@ -549,7 +562,10 @@ class MaskEditorDialog extends ComfyDialog { self.updateBrushPreview(self); - if (window.TouchEvent && event instanceof TouchEvent || event.buttons == 1) { + let left_button_down = window.TouchEvent && event instanceof TouchEvent || event.buttons == 1; + let right_button_down = [2, 5, 32].includes(event.buttons); + + if (!event.altKey && left_button_down) { var diff = performance.now() - self.lasttime; const maskRect = self.maskCanvas.getBoundingClientRect(); @@ -616,7 +632,7 @@ class MaskEditorDialog extends ComfyDialog { self.lasttime = performance.now(); } - else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) { + else if((event.altKey && left_button_down) || right_button_down) { const maskRect = self.maskCanvas.getBoundingClientRect(); const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio; const y = (event.offsetY || 
event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio; @@ -690,12 +706,19 @@ class MaskEditorDialog extends ComfyDialog { self.drawing_mode = true; event.preventDefault(); + + if(event.shiftKey) { + self.zoom_lasty = event.clientY; + self.last_zoom_ratio = self.zoom_ratio; + return; + } + const maskRect = self.maskCanvas.getBoundingClientRect(); const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio; const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio; self.maskCtx.beginPath(); - if (event.button == 0) { + if (!event.altKey && event.button == 0) { self.maskCtx.fillStyle = "rgb(0,0,0)"; self.maskCtx.globalCompositeOperation = "source-over"; } else { From 53a22e1ab9df4385aae07d65d7cd2fc157e989c9 Mon Sep 17 00:00:00 2001 From: pksebben Date: Wed, 31 Jan 2024 16:14:50 -0800 Subject: [PATCH 11/64] add increment-wrap as option to ValueControlWidget when isCombo, which loops back to 0 when at end of list --- web/scripts/widgets.js | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 0529b1d80b5..678b1b8ec7a 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -81,6 +81,9 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando const isCombo = targetWidget.type === "combo"; let comboFilter; + if (isCombo) { + valueControl.options.values.push("increment-wrap"); + } if (isCombo && options.addFilterList !== false) { comboFilter = node.addWidget( "string", @@ -128,6 +131,12 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando case "increment": current_index += 1; break; + case "increment-wrap": + current_index += 1; + if ( current_index >= current_length ) { + current_index = 0; + } + break; case "decrement": current_index -= 1; break; @@ -295,7 +304,7 @@ export const ComfyWidgets = { let disable_rounding = app.ui.settings.getSettingValue("Comfy.DisableFloatRounding") if (precision == 0) precision = undefined; const { val, config } = getNumberDefaults(inputData, 0.5, precision, !disable_rounding); - return { widget: node.addWidget(widgetType, inputName, val, + return { widget: node.addWidget(widgetType, inputName, val, function (v) { if (config.round) { this.value = Math.round(v/config.round)*config.round; From 951a2064a34e3e2ab468942663fa59df6a212af3 Mon Sep 17 00:00:00 2001 From: Chaoses-Ib Date: Fri, 2 Feb 2024 13:27:03 +0800 Subject: [PATCH 12/64] Fix frontend webp prompt handling --- web/scripts/app.js | 2 ++ 1 file changed, 2 insertions(+) diff --git a/web/scripts/app.js b/web/scripts/app.js index b3a84899331..c1461d259e9 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -2107,6 +2107,8 @@ export class ComfyApp { this.loadGraphData(JSON.parse(pngInfo.Workflow)); // Support loading workflows from that webp custom node. } else if (pngInfo.prompt) { this.loadApiJson(JSON.parse(pngInfo.prompt)); + } else if (pngInfo.Prompt) { + this.loadApiJson(JSON.parse(pngInfo.Prompt)); // Support loading prompts from that webp custom node. 
} } } else if (file.type === "application/json" || file.name?.endsWith(".json")) { From f2bae7463e506048600093e6f0adf90cf89edc86 Mon Sep 17 00:00:00 2001 From: FizzleDorf <1fizzledorf@gmail.com> Date: Fri, 2 Feb 2024 18:31:35 +0900 Subject: [PATCH 13/64] changed default of LatentBatchSeedBehavior to fixed --- comfy_extras/nodes_latent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index b7fd8cd687f..eabae088516 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -126,7 +126,7 @@ class LatentBatchSeedBehavior: @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), - "seed_behavior": (["random", "fixed"],),}} + "seed_behavior": (["random", "fixed"],{"default": "fixed"}),}} RETURN_TYPES = ("LATENT",) FUNCTION = "op" From 4b0239066daa0529bc18a1c932d4e8cd148b5ab5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 2 Feb 2024 10:02:49 -0500 Subject: [PATCH 14/64] Always use fp16 for the text encoders. --- comfy/model_management.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index e12146d11b8..cbaa8087419 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -546,10 +546,8 @@ def text_encoder_dtype(device=None): if is_device_cpu(device): return torch.float16 - if should_use_fp16(device, prioritize_performance=False): - return torch.float16 - else: - return torch.float32 + return torch.float16 + def intermediate_device(): if args.gpu_only: From 5f3dbede5855c239709d2774f93af9aad3f7b18d Mon Sep 17 00:00:00 2001 From: ultimabear Date: Sat, 3 Feb 2024 10:29:44 +0300 Subject: [PATCH 15/64] Mask editor: semitransparent brush, brush color modes --- web/extensions/core/maskeditor.js | 121 +++++++++++++++++++++++++++--- 1 file changed, 109 insertions(+), 12 deletions(-) diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index cd7d904e18e..4f69ac7607c 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -62,7 +62,7 @@ async function uploadMask(filepath, formData) { ClipspaceDialog.invalidatePreview(); } -function prepare_mask(image, maskCanvas, maskCtx) { +function prepare_mask(image, maskCanvas, maskCtx, maskColor) { // paste mask data into alpha channel maskCtx.drawImage(image, 0, 0, maskCanvas.width, maskCanvas.height); const maskData = maskCtx.getImageData(0, 0, maskCanvas.width, maskCanvas.height); @@ -74,9 +74,9 @@ function prepare_mask(image, maskCanvas, maskCtx) { else maskData.data[i+3] = 255; - maskData.data[i] = 0; - maskData.data[i+1] = 0; - maskData.data[i+2] = 0; + maskData.data[i] = maskColor.r; + maskData.data[i+1] = maskColor.g; + maskData.data[i+2] = maskColor.b; } maskCtx.globalCompositeOperation = 'source-over'; @@ -194,14 +194,29 @@ class MaskEditorDialog extends ComfyDialog { this.element.appendChild(bottom_panel); document.body.appendChild(brush); + var clearButton = this.createLeftButton("Clear", () => { + self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height); + }); + this.brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => { self.brush_size = event.target.value; self.updateBrushPreview(self, null, null); }); - var clearButton = this.createLeftButton("Clear", - () => { - self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height); - }); + + this.colorButton = this.createLeftButton(this.getColorButtonText(), () => { + if 
(self.brush_color_mode === "black") { + self.brush_color_mode = "white"; + } + else if (self.brush_color_mode === "white") { + self.brush_color_mode = "negative"; + } + else { + self.brush_color_mode = "black"; + } + + self.updateWhenBrushColorModeChanged(); + }); + var cancelButton = this.createRightButton("Cancel", () => { document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); @@ -222,6 +237,7 @@ class MaskEditorDialog extends ComfyDialog { bottom_panel.appendChild(this.saveButton); bottom_panel.appendChild(cancelButton); bottom_panel.appendChild(this.brush_size_slider); + bottom_panel.appendChild(this.colorButton); imgCanvas.style.position = "absolute"; maskCanvas.style.position = "absolute"; @@ -231,6 +247,10 @@ class MaskEditorDialog extends ComfyDialog { maskCanvas.style.top = imgCanvas.style.top; maskCanvas.style.left = imgCanvas.style.left; + + const maskCanvasStyle = this.getMaskCanvasStyle(); + maskCanvas.style.mixBlendMode = maskCanvasStyle.mixBlendMode; + maskCanvas.style.opacity = maskCanvasStyle.opacity; } async show() { @@ -316,7 +336,7 @@ class MaskEditorDialog extends ComfyDialog { let maskCtx = this.maskCanvas.getContext('2d', {willReadFrequently: true }); imgCtx.drawImage(orig_image, 0, 0, orig_image.width, orig_image.height); - prepare_mask(mask_image, this.maskCanvas, maskCtx); + prepare_mask(mask_image, this.maskCanvas, maskCtx, this.getMaskColor()); } async setImages(imgCanvas) { @@ -442,7 +462,84 @@ class MaskEditorDialog extends ComfyDialog { } } + getMaskCanvasStyle() { + if (this.brush_color_mode === "negative") { + return { + mixBlendMode: "difference", + opacity: "1", + }; + } + else { + return { + mixBlendMode: "initial", + opacity: "0.7", + }; + } + } + + getMaskColor() { + if (this.brush_color_mode === "black") { + return { r: 0, g: 0, b: 0 }; + } + if (this.brush_color_mode === "white") { + return { r: 255, g: 255, b: 255 }; + } + if (this.brush_color_mode === "negative") { + // negative effect only works with white color + return { r: 255, g: 255, b: 255 }; + } + + return { r: 0, g: 0, b: 0 }; + } + + getMaskFillStyle() { + const maskColor = this.getMaskColor(); + + return "rgb(" + maskColor.r + "," + maskColor.g + "," + maskColor.b + ")"; + } + + getColorButtonText() { + let colorCaption = "unknown"; + + if (this.brush_color_mode === "black") { + colorCaption = "black"; + } + else if (this.brush_color_mode === "white") { + colorCaption = "white"; + } + else if (this.brush_color_mode === "negative") { + colorCaption = "negative"; + } + + return "Color: " + colorCaption; + } + + updateWhenBrushColorModeChanged() { + this.colorButton.innerText = this.getColorButtonText(); + + // update mask canvas css styles + + const maskCanvasStyle = this.getMaskCanvasStyle(); + this.maskCanvas.style.mixBlendMode = maskCanvasStyle.mixBlendMode; + this.maskCanvas.style.opacity = maskCanvasStyle.opacity; + + // update mask canvas rgb colors + + const maskColor = this.getMaskColor(); + + const maskData = this.maskCtx.getImageData(0, 0, this.maskCanvas.width, this.maskCanvas.height); + + for (let i = 0; i < maskData.data.length; i += 4) { + maskData.data[i] = maskColor.r; + maskData.data[i+1] = maskColor.g; + maskData.data[i+2] = maskColor.b; + } + + this.maskCtx.putImageData(maskData, 0, 0); + } + brush_size = 10; + brush_color_mode = "black"; drawing_mode = false; lastx = -1; lasty = -1; @@ -600,7 +697,7 @@ class MaskEditorDialog extends ComfyDialog { if(diff > 20 && 
!this.drawing_mode) requestAnimationFrame(() => { self.maskCtx.beginPath(); - self.maskCtx.fillStyle = "rgb(0,0,0)"; + self.maskCtx.fillStyle = this.getMaskFillStyle(); self.maskCtx.globalCompositeOperation = "source-over"; self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false); self.maskCtx.fill(); @@ -610,7 +707,7 @@ class MaskEditorDialog extends ComfyDialog { else requestAnimationFrame(() => { self.maskCtx.beginPath(); - self.maskCtx.fillStyle = "rgb(0,0,0)"; + self.maskCtx.fillStyle = this.getMaskFillStyle(); self.maskCtx.globalCompositeOperation = "source-over"; var dx = x - self.lastx; @@ -719,7 +816,7 @@ class MaskEditorDialog extends ComfyDialog { self.maskCtx.beginPath(); if (!event.altKey && event.button == 0) { - self.maskCtx.fillStyle = "rgb(0,0,0)"; + self.maskCtx.fillStyle = this.getMaskFillStyle(); self.maskCtx.globalCompositeOperation = "source-over"; } else { self.maskCtx.globalCompositeOperation = "destination-out"; From 24129d78e6ac5349389ca99349242a13cdedf1d2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 4 Feb 2024 13:23:43 -0500 Subject: [PATCH 16/64] Speed up SDXL on 16xx series with fp16 weights and manual cast. --- comfy/model_management.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index cbaa8087419..aa40c502af5 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -496,7 +496,7 @@ def unet_dtype(device=None, model_params=0): return torch.float8_e4m3fn if args.fp8_e5m2_unet: return torch.float8_e5m2 - if should_use_fp16(device=device, model_params=model_params): + if should_use_fp16(device=device, model_params=model_params, manual_cast=True): return torch.float16 return torch.float32 @@ -696,7 +696,7 @@ def is_device_mps(device): return True return False -def should_use_fp16(device=None, model_params=0, prioritize_performance=True): +def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): global directml_enabled if device is not None: @@ -738,7 +738,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True): if x in props.name.lower(): fp16_works = True - if fp16_works: + if fp16_works or manual_cast: free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory()) if (not prioritize_performance) or model_params * 4 > free_model_memory: return True From 66e28ef45c02437c1ca6a31afbe5f399eda15256 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 4 Feb 2024 20:53:35 -0500 Subject: [PATCH 17/64] Don't use is_bf16_supported to check for fp16 support. --- comfy/model_management.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index aa40c502af5..a8dc91b9ecf 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -722,10 +722,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma if is_intel_xpu(): return True - if torch.cuda.is_bf16_supported(): + if torch.version.hip: return True props = torch.cuda.get_device_properties("cuda") + if props.major >= 8: + return True + if props.major < 6: return False From 74b7233f57301bb08c2b29fb420eeacf8757d41c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 4 Feb 2024 23:15:49 -0500 Subject: [PATCH 18/64] Document IS_CHANGED in the example custom node. 
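For reference, the LoadImage pattern that the new comments point to looks
roughly like this in nodes.py (sketch; relies on the hashlib and
folder_paths modules that nodes.py already imports):

    @classmethod
    def IS_CHANGED(s, image):
        image_path = folder_paths.get_annotated_filepath(image)
        m = hashlib.sha256()
        with open(image_path, 'rb') as f:
            m.update(f.read())
        return m.digest().hex()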
--- custom_nodes/example_node.py.example | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py.example index 733014f3c7d..7ce271ec617 100644 --- a/custom_nodes/example_node.py.example +++ b/custom_nodes/example_node.py.example @@ -6,6 +6,8 @@ class Example: ------------- INPUT_TYPES (dict): Tell the main program input parameters of nodes. + IS_CHANGED: + optional method to control when the node is re executed. Attributes ---------- @@ -89,6 +91,17 @@ class Example: image = 1.0 - image return (image,) + """ + The node will always be re executed if any of the inputs change but + this method can be used to force the node to execute again even when the inputs don't change. + You can make this node return a number or a string. This value will be compared to the one returned the last time the node was + executed, if it is different the node will be executed again. + This method is used in the core repo for the LoadImage node where they return the image hash as a string, if the image hash + changes between executions the LoadImage node is executed again. + """ + #@classmethod + #def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen): + # return "" # A dictionary that contains all nodes you want to export with their names # NOTE: names should be globally unique From 236bda26830d719843ba9b5703894297f67f6704 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 5 Feb 2024 01:29:26 -0500 Subject: [PATCH 19/64] Make minimum tile size the size of the overlap. --- comfy/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/utils.py b/comfy/utils.py index f8026ddab9d..1113bf0f52f 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -413,6 +413,8 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device) for y in range(0, s.shape[2], tile_y - overlap): for x in range(0, s.shape[3], tile_x - overlap): + x = max(0, min(s.shape[-1] - overlap, x)) + y = max(0, min(s.shape[-2] - overlap, y)) s_in = s[:,:,y:y+tile_y,x:x+tile_x] ps = function(s_in).to(output_device) From d2e7f1b04b729aa5c5a5633ce1130bb47828b4b4 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 6 Feb 2024 16:55:55 +0000 Subject: [PATCH 20/64] Support linking converted inputs from api json --- web/extensions/core/widgetInputs.js | 6 ++++++ web/scripts/app.js | 13 +++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 3f1c1f8c126..b12ad968f4f 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -260,6 +260,12 @@ app.registerExtension({ async beforeRegisterNodeDef(nodeType, nodeData, app) { // Add menu options to conver to/from widgets const origGetExtraMenuOptions = nodeType.prototype.getExtraMenuOptions; + nodeType.prototype.convertWidgetToInput = function (widget) { + const config = getConfig.call(this, widget.name) ?? [widget.type, widget.options || {}]; + if (!isConvertableWidget(widget, config)) return false; + convertToInput(this, widget, config); + return true; + }; nodeType.prototype.getExtraMenuOptions = function (_, options) { const r = origGetExtraMenuOptions ? 
origGetExtraMenuOptions.apply(this, arguments) : undefined;
diff --git a/web/scripts/app.js b/web/scripts/app.js
index c1461d259e9..77f29b8e5b1 100644
--- a/web/scripts/app.js
+++ b/web/scripts/app.js
@@ -2162,8 +2162,17 @@ export class ComfyApp {
 				if (value instanceof Array) {
 					const [fromId, fromSlot] = value;
 					const fromNode = app.graph.getNodeById(fromId);
-					const toSlot = node.inputs?.findIndex((inp) => inp.name === input);
-					if (toSlot !== -1) {
+					let toSlot = node.inputs?.findIndex((inp) => inp.name === input);
+					if (toSlot == null || toSlot === -1) {
+						try {
+							// Target has no matching input, most likely a converted widget
+							const widget = node.widgets?.find((w) => w.name === input);
+							if (widget && node.convertWidgetToInput?.(widget)) {
+								toSlot = node.inputs?.length - 1;
+							}
+						} catch (error) {}
+					}
+					if (toSlot != null && toSlot !== -1) {
 						fromNode.connect(fromSlot, node, toSlot);
 					}
 				} else {

From 7daad468ec945aabfbf3f502c6c059bfc818014d Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 6 Feb 2024 12:43:06 -0500
Subject: [PATCH 21/64] Sync litegraph to repo.

https://github.com/comfyanonymous/litegraph.js/pull/6
---
 web/lib/litegraph.core.js | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js
index 4aae889ef4e..4ff05ae8130 100644
--- a/web/lib/litegraph.core.js
+++ b/web/lib/litegraph.core.js
@@ -11549,7 +11549,7 @@ LGraphNode.prototype.executeAction = function(action)
                         dialog.close();
                     } else if (e.keyCode == 13) {
                         if (selected) {
-                            select(selected.innerHTML);
+                            select(unescape(selected.dataset["type"]));
                         } else if (first) {
                             select(first);
                         } else {

From c661a8b118d727f1841b2df75f343d5e40d52728 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 7 Feb 2024 18:52:51 -0500
Subject: [PATCH 22/64] Don't use numpy for calculating sigmas.

---
 comfy/ldm/modules/diffusionmodules/util.py | 4 ++--
 comfy/model_sampling.py | 8 +++-----
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py
index ac7e27173bd..5a6aa7d77d1 100644
--- a/comfy/ldm/modules/diffusionmodules/util.py
+++ b/comfy/ldm/modules/diffusionmodules/util.py
@@ -98,7 +98,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
         alphas = torch.cos(alphas).pow(2)
         alphas = alphas / alphas[0]
         betas = 1 - alphas[1:] / alphas[:-1]
-        betas = np.clip(betas, a_min=0, a_max=0.999)
+        betas = torch.clamp(betas, min=0, max=0.999)
 
     elif schedule == "squaredcos_cap_v2": # used for karlo prior
         # return early
@@ -113,7 +113,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
         betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
     else:
         raise ValueError(f"schedule '{schedule}' unknown.")
-    return betas.numpy()
+    return betas
 
 
 def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py
index cc8745c1064..d5870027b9b 100644
--- a/comfy/model_sampling.py
+++ b/comfy/model_sampling.py
@@ -1,5 +1,4 @@
 import torch
-import numpy as np
 from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
 import math
 
@@ -42,8 +41,7 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps
         else:
             betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
         alphas = 1. 
- betas - alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32) - # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + alphas_cumprod = torch.cumprod(alphas, dim=0) timesteps, = betas.shape self.num_timesteps = int(timesteps) @@ -58,8 +56,8 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps self.set_sigmas(sigmas) def set_sigmas(self, sigmas): - self.register_buffer('sigmas', sigmas) - self.register_buffer('log_sigmas', sigmas.log()) + self.register_buffer('sigmas', sigmas.float()) + self.register_buffer('log_sigmas', sigmas.log().float()) @property def sigma_min(self): From a352c021eca1d39c3fc128a6ee65c0527f5c6f3c Mon Sep 17 00:00:00 2001 From: blepping Date: Thu, 8 Feb 2024 02:24:23 -0700 Subject: [PATCH 23/64] Allow custom samplers to request discard penultimate sigma --- comfy/samplers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index f4c3e268f73..5dd72f3faed 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -652,6 +652,7 @@ def sampler_object(name): class KSampler: SCHEDULERS = SCHEDULER_NAMES SAMPLERS = SAMPLER_NAMES + DISCARD_PENULTIMATE_SIGMA_SAMPLERS = set(('dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2')) def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model @@ -670,7 +671,7 @@ def calculate_sigmas(self, steps): sigmas = None discard_penultimate_sigma = False - if self.sampler in ['dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2']: + if self.sampler in self.DISCARD_PENULTIMATE_SIGMA_SAMPLERS: steps += 1 discard_penultimate_sigma = True From 2ccc0be28f64474191b46aa9dd5bb27070a9c1df Mon Sep 17 00:00:00 2001 From: Imran Azeez Date: Thu, 8 Feb 2024 22:01:56 +1000 Subject: [PATCH 24/64] Add batch number to filename with %batch_num% Allow configurable addition of batch number to output file name. --- nodes.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index fe38be9dfef..0cce5f51a8a 100644 --- a/nodes.py +++ b/nodes.py @@ -1434,7 +1434,7 @@ def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pngi filename_prefix += self.prefix_append full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) results = list() - for image in images: + for (batch_number, image) in enumerate(images): i = 255. * image.cpu().numpy() img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) metadata = None @@ -1446,7 +1446,8 @@ def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pngi for x in extra_pnginfo: metadata.add_text(x, json.dumps(extra_pnginfo[x])) - file = f"{filename}_{counter:05}_.png" + filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) + file = f"{filename_with_batch_num}_{counter:05}_.png" img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level) results.append({ "filename": file, From 25a4805e519ea97110651ae5bb1d7c0e6644b26f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 9 Feb 2024 14:13:31 -0500 Subject: [PATCH 25/64] Add a way to set different conditioning for the controlnet. 
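Conditioning entries can now carry a separate prompt for the controlnet via
the 'cross_attn_controlnet' key; when it is absent the controlnet falls
back to the regular cross attention cond. Sketch of how node code can
populate the new key (controlnet_cond here stands for a second CLIP
encoding; a node doing exactly this is added later in this series):

    c = []
    for t in conditioning:
        n = [t[0], t[1].copy()]
        n[1]['cross_attn_controlnet'] = controlnet_cond
        c.append(n)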
--- comfy/controlnet.py | 2 +- comfy/model_base.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 82170431ef2..d9d990a7166 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -166,7 +166,7 @@ def get_control(self, x_noisy, t, cond, batched_number): if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) - context = cond['c_crossattn'] + context = cond.get('crossattn_controlnet', cond['c_crossattn']) y = cond.get('y', None) if y is not None: y = y.to(dtype) diff --git a/comfy/model_base.py b/comfy/model_base.py index 8a843a98c39..aafb88e05c6 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -153,6 +153,10 @@ def blank_inpaint_image_like(latent_image): if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn) + cross_attn_cnet = kwargs.get("cross_attn_controlnet", None) + if cross_attn_cnet is not None: + out['crossattn_controlnet'] = comfy.conds.CONDCrossAttn(cross_attn_cnet) + return out def load_model_weights(self, sd, unet_prefix=""): From f44225fd5f433daf78484b9c21b9b777bea04220 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 9 Feb 2024 17:11:34 -0600 Subject: [PATCH 26/64] Fix infinite while loop being possible in ddim_scheduler --- comfy/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index f4c3e268f73..f2ac3c5dbb1 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -295,7 +295,7 @@ def simple_scheduler(model, steps): def ddim_scheduler(model, steps): s = model.model_sampling sigs = [] - ss = len(s.sigmas) // steps + ss = max(len(s.sigmas) // steps, 1) x = 1 while x < len(s.sigmas): sigs += [float(s.sigmas[x])] From 20e3da6b313feaac07c34a4cc746e5da931f7c76 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 10 Feb 2024 08:27:05 -0500 Subject: [PATCH 27/64] Add a node to give the controlnet a prompt different from the unet. --- comfy_extras/nodes_cond.py | 25 +++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 26 insertions(+) create mode 100644 comfy_extras/nodes_cond.py diff --git a/comfy_extras/nodes_cond.py b/comfy_extras/nodes_cond.py new file mode 100644 index 00000000000..646fefa1746 --- /dev/null +++ b/comfy_extras/nodes_cond.py @@ -0,0 +1,25 @@ + + +class CLIPTextEncodeControlnet: + @classmethod + def INPUT_TYPES(s): + return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True})}} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "_for_testing/conditioning" + + def encode(self, clip, conditioning, text): + tokens = clip.tokenize(text) + cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) + c = [] + for t in conditioning: + n = [t[0], t[1].copy()] + n[1]['cross_attn_controlnet'] = cond + n[1]['pooled_output_controlnet'] = pooled + c.append(n) + return (c, ) + +NODE_CLASS_MAPPINGS = { + "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet +} diff --git a/nodes.py b/nodes.py index fe38be9dfef..d9bc4884aef 100644 --- a/nodes.py +++ b/nodes.py @@ -1965,6 +1965,7 @@ def init_custom_nodes(): "nodes_stable3d.py", "nodes_sdupscale.py", "nodes_photomaker.py", + "nodes_cond.py", ] for node_file in extras_files: From 02409c30d9ea5314e5103d03f7c9933fa1012659 Mon Sep 17 00:00:00 2001 From: Steven Lu Date: Mon, 12 Feb 2024 01:44:53 +0700 Subject: [PATCH 28/64] Safari: Draws certain elements on CPU. 
In case of search popup, can cause 10 seconds+ main thread lock due to painting. (#2763) * lets toggle this setting first. * also makes it easier for debug. I'll be honest this is generally preferred behavior as well for me but I ain't no power user shrug. * attempting trick to put the work for filter: brightness on GPU as a first attempt before falling back to not using filter for large lists! * revert litegraph.core.js changes from branch * oops --- web/style.css | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/web/style.css b/web/style.css index 863840b2866..cf7a8b9ea2d 100644 --- a/web/style.css +++ b/web/style.css @@ -197,6 +197,7 @@ button.comfy-close-menu-btn { .comfy-modal button:hover, .comfy-menu-actions button:hover { filter: brightness(1.2); + will-change: transform; cursor: pointer; } @@ -462,11 +463,13 @@ dialog::backdrop { z-index: 9999 !important; background-color: var(--comfy-menu-bg) !important; filter: brightness(95%); + will-change: transform; } .litegraph.litecontextmenu .litemenu-entry:hover:not(.disabled):not(.separator) { background-color: var(--comfy-menu-bg) !important; filter: brightness(155%); + will-change: transform; color: var(--input-text); } @@ -527,12 +530,14 @@ dialog::backdrop { color: var(--input-text); background-color: var(--comfy-input-bg); filter: brightness(80%); + will-change: transform; padding-left: 0.2em; } .litegraph.lite-search-item.generic_type { color: var(--input-text); filter: brightness(50%); + will-change: transform; } @media only screen and (max-width: 450px) { @@ -551,4 +556,4 @@ dialog::backdrop { text-align: center; border-top: none; } -} \ No newline at end of file +} From cf4910a3a451ad9e2e5261749a5a44acdcf7bbec Mon Sep 17 00:00:00 2001 From: chrisgoringe Date: Mon, 12 Feb 2024 08:59:25 +1100 Subject: [PATCH 29/64] Prevent hideWidget being called twice for same widget Fix for #2766 --- web/extensions/core/widgetInputs.js | 1 + 1 file changed, 1 insertion(+) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index b12ad968f4f..23f51d812b4 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -22,6 +22,7 @@ function isConvertableWidget(widget, config) { } function hideWidget(node, widget, suffix = "") { + if (widget.type?.startsWith(CONVERTED_TYPE)) return; widget.origType = widget.type; widget.origComputeSize = widget.computeSize; widget.origSerializeValue = widget.serializeValue; From 0c9bc19768683c9e2772bd75e7bf823f976ccfba Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 12 Feb 2024 12:46:15 -0500 Subject: [PATCH 30/64] Add ImageFromBatch. 
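The node slices a sub-batch out of an image batch; batch_index and length
are clamped so out-of-range requests shrink to fit instead of raising.
Rough usage sketch:

    node = ImageFromBatch()
    (single,) = node.frombatch(images, batch_index=2, length=1)
    # same as images[2:3].clone(), with the range clamped to the batch size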
--- comfy_extras/nodes_images.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index aa80f5269a3..8f638bf8fc1 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -48,6 +48,25 @@ def repeat(self, image, amount): s = image.repeat((amount, 1,1,1)) return (s,) +class ImageFromBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), + "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), + "length": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "frombatch" + + CATEGORY = "image/batch" + + def frombatch(self, image, batch_index, length): + s_in = image + batch_index = min(s_in.shape[0] - 1, batch_index) + length = min(s_in.shape[0] - batch_index, length) + s = s_in[batch_index:batch_index + length].clone() + return (s,) + class SaveAnimatedWEBP: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -170,6 +189,7 @@ def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", pr NODE_CLASS_MAPPINGS = { "ImageCrop": ImageCrop, "RepeatImageBatch": RepeatImageBatch, + "ImageFromBatch": ImageFromBatch, "SaveAnimatedWEBP": SaveAnimatedWEBP, "SaveAnimatedPNG": SaveAnimatedPNG, } From 38b7ac6e269e6ecc5bdd6fefdfb2fb1185b09c9d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 13 Feb 2024 00:01:08 -0500 Subject: [PATCH 31/64] Don't init the CLIP model when the checkpoint has no CLIP weights. --- comfy/sd.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index c15d73fed5e..5b22d1178fc 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -470,10 +470,13 @@ class WeightsLoader(torch.nn.Module): w = WeightsLoader() clip_target = model_config.clip_target() if clip_target is not None: - clip = CLIP(clip_target, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model sd = model_config.process_clip_state_dict(sd) - load_model_weights(w, sd) + if any(k.startswith('cond_stage_model.') for k in sd): + clip = CLIP(clip_target, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model + load_model_weights(w, sd) + else: + print("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") left_over = sd.keys() if len(left_over) > 0: From 7f89cb48bf8200254cde1306ba60d10ca019264d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 14 Feb 2024 02:59:53 -0500 Subject: [PATCH 32/64] Add a disabled SaveImageWebsocket custom node. This node can be used to efficiently get images without saving them to disk when using ComfyUI as a backend. 
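A minimal client sketch for consuming these messages (assumes the
websocket-client package, the default 127.0.0.1:8188 address, and that both
header fields are big-endian uint32s like the server's other binary
messages; verify the exact values against server.py):

    import struct
    import websocket  # pip install websocket-client

    ws = websocket.WebSocket()
    ws.connect("ws://127.0.0.1:8188/ws?clientId=example")
    while True:  # a prompt using SaveImageWebsocket must be queued separately
        msg = ws.recv()
        if isinstance(msg, bytes):
            event_type, image_format = struct.unpack(">II", msg[:8])
            image_bytes = msg[8:]  # full-size image sent by SaveImageWebsocket
            break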
--- custom_nodes/websocket_image_save.py.disabled | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 custom_nodes/websocket_image_save.py.disabled diff --git a/custom_nodes/websocket_image_save.py.disabled b/custom_nodes/websocket_image_save.py.disabled new file mode 100644 index 00000000000..b85a5de8be0 --- /dev/null +++ b/custom_nodes/websocket_image_save.py.disabled @@ -0,0 +1,49 @@ +from PIL import Image, ImageOps +from io import BytesIO +import numpy as np +import struct +import comfy.utils +import time + +#You can use this node to save full size images through the websocket, the +#images will be sent in exactly the same format as the image previews: as +#binary images on the websocket with an 8 byte header indicating the type +#of binary message (first 4 bytes) and the image format (next 4 bytes). + +#The reason this node is disabled by default is that there is a small +#issue when using it with the default ComfyUI web interface: when generating +#batches, only the last image will be shown in the UI. + +#Note that no metadata will be put in the images saved with this node. + +class SaveImageWebsocket: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"images": ("IMAGE", ),} + } + + RETURN_TYPES = () + FUNCTION = "save_images" + + OUTPUT_NODE = True + + CATEGORY = "image" + + def save_images(self, images): + pbar = comfy.utils.ProgressBar(images.shape[0]) + step = 0 + for image in images: + i = 255. * image.cpu().numpy() + img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) + pbar.update_absolute(step, images.shape[0], ("PNG", img, None)) + step += 1 + + return {} + + def IS_CHANGED(s, images): + return time.time() + +NODE_CLASS_MAPPINGS = { + "SaveImageWebsocket": SaveImageWebsocket, +} From aeaeca10bd7cf6e40d6e71f1089c594e0fab5a99 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 15 Feb 2024 21:10:10 -0500 Subject: [PATCH 33/64] Small refactor of is_device_* functions. --- comfy/model_management.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index a8dc91b9ecf..0b3f6ead634 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -684,17 +684,20 @@ def mps_mode(): global cpu_state return cpu_state == CPUState.MPS -def is_device_cpu(device): +def is_device_type(device, type): if hasattr(device, 'type'): - if (device.type == 'cpu'): + if (device.type == type): return True return False +def is_device_cpu(device): + return is_device_type(device, 'cpu') + def is_device_mps(device): - if hasattr(device, 'type'): - if (device.type == 'mps'): - return True - return False + return is_device_type(device, 'mps') + +def is_device_cuda(device): + return is_device_type(device, 'cuda') def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): global directml_enabled From 5e06baf112c0ccccbf51f0249abb3e121147f2d3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 16 Feb 2024 06:30:39 -0500 Subject: [PATCH 34/64] Stable Cascade Stage A. --- comfy/ldm/cascade/stage_a.py | 254 +++++++++++++++++++++++++++++++++++ comfy/sd.py | 24 +++- 2 files changed, 272 insertions(+), 6 deletions(-) create mode 100644 comfy/ldm/cascade/stage_a.py diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py new file mode 100644 index 00000000000..55fdbf17dfc --- /dev/null +++ b/comfy/ldm/cascade/stage_a.py @@ -0,0 +1,254 @@ +""" + This file is part of ComfyUI.
+ Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import torch +from torch import nn +from torch.autograd import Function + +class vector_quantize(Function): + @staticmethod + def forward(ctx, x, codebook): + with torch.no_grad(): + codebook_sqr = torch.sum(codebook ** 2, dim=1) + x_sqr = torch.sum(x ** 2, dim=1, keepdim=True) + + dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0) + _, indices = dist.min(dim=1) + + ctx.save_for_backward(indices, codebook) + ctx.mark_non_differentiable(indices) + + nn = torch.index_select(codebook, 0, indices) + return nn, indices + + @staticmethod + def backward(ctx, grad_output, grad_indices): + grad_inputs, grad_codebook = None, None + + if ctx.needs_input_grad[0]: + grad_inputs = grad_output.clone() + if ctx.needs_input_grad[1]: + # Gradient wrt. the codebook + indices, codebook = ctx.saved_tensors + + grad_codebook = torch.zeros_like(codebook) + grad_codebook.index_add_(0, indices, grad_output) + + return (grad_inputs, grad_codebook) + + +class VectorQuantize(nn.Module): + def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False): + """ + Takes an input of variable size (as long as the last dimension matches the embedding size). 
Returns one tensor containing the nearest neighbour embeddings to each of the inputs, + with the same size as the input, vq and commitment components for the loss as a tuple + in the second output and the indices of the quantized vectors in the third: + quantized, (vq_loss, commit_loss), indices + """ + super(VectorQuantize, self).__init__() + + self.codebook = nn.Embedding(k, embedding_size) + self.codebook.weight.data.uniform_(-1./k, 1./k) + self.vq = vector_quantize.apply + + self.ema_decay = ema_decay + self.ema_loss = ema_loss + if ema_loss: + self.register_buffer('ema_element_count', torch.ones(k)) + self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight)) + + def _laplace_smoothing(self, x, epsilon): + n = torch.sum(x) + return ((x + epsilon) / (n + x.size(0) * epsilon) * n) + + def _updateEMA(self, z_e_x, indices): + mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float() + elem_count = mask.sum(dim=0) + weight_sum = torch.mm(mask.t(), z_e_x) + + self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count) + self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5) + self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum) + + self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1) + + def idx2vq(self, idx, dim=-1): + q_idx = self.codebook(idx) + if dim != -1: + q_idx = q_idx.movedim(-1, dim) + return q_idx + + def forward(self, x, get_losses=True, dim=-1): + if dim != -1: + x = x.movedim(dim, -1) + z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x + z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach()) + vq_loss, commit_loss = None, None + if self.ema_loss and self.training: + self._updateEMA(z_e_x.detach(), indices.detach()) + # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss + z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices) + if get_losses: + vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean() + commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean() + + z_q_x = z_q_x.view(x.shape) + if dim != -1: + z_q_x = z_q_x.movedim(-1, dim) + return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1]) + + +class ResBlock(nn.Module): + def __init__(self, c, c_hidden): + super().__init__() + # depthwise/attention + self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.depthwise = nn.Sequential( + nn.ReplicationPad2d(1), + nn.Conv2d(c, c, kernel_size=3, groups=c) + ) + + # channelwise + self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c, c_hidden), + nn.GELU(), + nn.Linear(c_hidden, c), + ) + + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) + + # Init weights + def _basic_init(module): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + def _norm(self, x, norm): + return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + def forward(self, x): + mods = self.gammas + + x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] + x = x + self.depthwise(x_temp) * mods[2] + + x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] + x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] + + return x + + +class 
StageA(nn.Module): + def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, + scale_factor=0.43): # 0.3764 + super().__init__() + self.c_latent = c_latent + self.scale_factor = scale_factor + c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))] + + # Encoder blocks + self.in_block = nn.Sequential( + nn.PixelUnshuffle(2), + nn.Conv2d(3 * 4, c_levels[0], kernel_size=1) + ) + down_blocks = [] + for i in range(levels): + if i > 0: + down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) + block = ResBlock(c_levels[i], c_levels[i] * 4) + down_blocks.append(block) + down_blocks.append(nn.Sequential( + nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 + )) + self.down_blocks = nn.Sequential(*down_blocks) + self.down_blocks[0] + + self.codebook_size = codebook_size + self.vquantizer = VectorQuantize(c_latent, k=codebook_size) + + # Decoder blocks + up_blocks = [nn.Sequential( + nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) + )] + for i in range(levels): + for j in range(bottleneck_blocks if i == 0 else 1): + block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) + up_blocks.append(block) + if i < levels - 1: + up_blocks.append( + nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, + padding=1)) + self.up_blocks = nn.Sequential(*up_blocks) + self.out_block = nn.Sequential( + nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1), + nn.PixelShuffle(2), + ) + + def encode(self, x, quantize=False): + x = self.in_block(x) + x = self.down_blocks(x) + if quantize: + qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) + return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 + else: + return x / self.scale_factor + + def decode(self, x): + x = x * self.scale_factor + x = self.up_blocks(x) + x = self.out_block(x) + return x + + def forward(self, x, quantize=False): + qe, x, _, vq_loss = self.encode(x, quantize) + x = self.decode(qe) + return x, vq_loss + + +class Discriminator(nn.Module): + def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6): + super().__init__() + d = max(depth - 3, 3) + layers = [ + nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), + nn.LeakyReLU(0.2), + ] + for i in range(depth - 1): + c_in = c_hidden // (2 ** max((d - i), 0)) + c_out = c_hidden // (2 ** max((d - 1 - i), 0)) + layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) + layers.append(nn.InstanceNorm2d(c_out)) + layers.append(nn.LeakyReLU(0.2)) + self.encoder = nn.Sequential(*layers) + self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) + self.logits = nn.Sigmoid() + + def forward(self, x, cond=None): + x = self.encoder(x) + if cond is not None: + cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1)) + x = torch.cat([x, cond], dim=1) + x = self.shuffle(x) + x = self.logits(x) + return x diff --git a/comfy/sd.py b/comfy/sd.py index 5b22d1178fc..5e37cff915a 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -2,6 +2,8 @@ from comfy import model_management from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine +from .ldm.cascade.stage_a import StageA + import yaml import comfy.utils @@ -156,6 +158,8 @@ def __init__(self, sd=None, device=None, config=None, 
dtype=None): self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) self.downscale_ratio = 8 self.latent_channels = 4 + self.process_input = lambda image: image * 2.0 - 1.0 + self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) if config is None: if "decoder.mid.block_1.mix_factor" in sd: @@ -168,6 +172,14 @@ def __init__(self, sd=None, device=None, config=None, dtype=None): decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config}) elif "taesd_decoder.1.weight" in sd: self.first_stage_model = comfy.taesd.taesd.TAESD() + elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade + self.first_stage_model = StageA() + self.downscale_ratio = 4 + #TODO + #self.memory_used_encode + #self.memory_used_decode + self.process_input = lambda image: image + self.process_output = lambda image: image else: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} @@ -206,12 +218,12 @@ def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = comfy.utils.ProgressBar(steps) - decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float() - output = torch.clamp(( + decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() + output = self.process_output( (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) + comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar)) - / 3.0) / 2.0, min=0.0, max=1.0) + / 3.0) return output def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): @@ -220,7 +232,7 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = comfy.utils.ProgressBar(steps) - encode_fn = lambda a: self.first_stage_model.encode((2. 
* a - 1.).to(self.vae_dtype).to(self.device)).float() + encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) @@ -238,7 +250,7 @@ def decode(self, samples_in): pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device) for x in range(0, samples_in.shape[0], batch_number): samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) - pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0) + pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float()) except model_management.OOM_EXCEPTION as e: print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") pixel_samples = self.decode_tiled_(samples_in) @@ -261,7 +273,7 @@ def encode(self, pixel_samples): batch_number = max(1, batch_number) samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device) for x in range(0, pixel_samples.shape[0], batch_number): - pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device) + pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device) samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float() except model_management.OOM_EXCEPTION as e: From f83109f09bec04f39f028c275b4eb1231adba00a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 16 Feb 2024 10:55:08 -0500 Subject: [PATCH 35/64] Stable Cascade Stage C. 
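Stage C is the Stable Cascade prior; alongside the model itself this adds a StableCascadeSampling schedule in comfy/model_sampling.py. For reference, that schedule maps a normalized timestep t in [0, 1] to a sigma through a cosine alpha-cumprod curve; a standalone restatement of the formula used there (stable_cascade_sigma is a hypothetical name, but the constants and math mirror the StableCascadeSampling code in this patch):

    import torch

    def stable_cascade_sigma(t, cosine_s=8e-3):
        # alpha_cumprod(t) = cos^2(((t + s) / (1 + s)) * pi / 2), normalized by its value at t = 0
        s = torch.tensor(cosine_s)
        init_alpha_cumprod = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
        alpha_cumprod = (torch.cos((t + s) / (1 + s) * torch.pi * 0.5) ** 2 / init_alpha_cumprod).clamp(0.0001, 0.9999)
        # convert the cumulative alpha to the noise level used by the sampler
        return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5

    sigmas = stable_cascade_sigma(torch.linspace(0, 1, 1000))  # full 1000-step table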
--- comfy/controlnet.py | 17 ++- comfy/latent_formats.py | 6 + comfy/ldm/cascade/common.py | 161 ++++++++++++++++++++ comfy/ldm/cascade/stage_c.py | 271 +++++++++++++++++++++++++++++++++ comfy/model_base.py | 40 ++++- comfy/model_detection.py | 30 +++- comfy/model_management.py | 34 ++++- comfy/model_sampling.py | 30 ++++ comfy/sd.py | 20 +-- comfy/supported_models.py | 35 ++++- comfy/supported_models_base.py | 6 +- 11 files changed, 619 insertions(+), 31 deletions(-) create mode 100644 comfy/ldm/cascade/common.py create mode 100644 comfy/ldm/cascade/stage_c.py diff --git a/comfy/controlnet.py b/comfy/controlnet.py index d9d990a7166..416197586a1 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -318,9 +318,10 @@ def load_controlnet(ckpt_path, model=None): return ControlLora(controlnet_data) controlnet_config = None + supported_inference_dtypes = None + if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format - unet_dtype = comfy.model_management.unet_dtype() - controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype) + controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data) diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config) diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" @@ -380,12 +381,20 @@ def load_controlnet(ckpt_path, model=None): return net if controlnet_config is None: - unet_dtype = comfy.model_management.unet_dtype() - controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config + model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True) + supported_inference_dtypes = model_config.supported_inference_dtypes + controlnet_config = model_config.unet_config + load_device = comfy.model_management.get_torch_device() + if supported_inference_dtypes is None: + unet_dtype = comfy.model_management.unet_dtype() + else: + unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes) + manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) if manual_cast_dtype is not None: controlnet_config["operations"] = comfy.ops.manual_cast + controlnet_config["dtype"] = unet_dtype controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 2252a075ed5..8ba767372b6 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -37,3 +37,9 @@ def __init__(self): class SD_X4(LatentFormat): def __init__(self): self.scale_factor = 0.08333 + +class SC_Prior(LatentFormat): + def __init__(self): + self.scale_factor = 1.0 + + diff --git a/comfy/ldm/cascade/common.py b/comfy/ldm/cascade/common.py new file mode 100644 index 00000000000..3c5bf99771e --- /dev/null +++ b/comfy/ldm/cascade/common.py @@ -0,0 +1,161 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. 
+ + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import torch +import torch.nn as nn +from comfy.ldm.modules.attention import optimized_attention + +class Linear(torch.nn.Linear): + def reset_parameters(self): + return None + +class Conv2d(torch.nn.Conv2d): + def reset_parameters(self): + return None + +class OptimizedAttention(nn.Module): + def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.heads = nhead + + self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + + self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + + def forward(self, q, k, v): + q = self.to_q(q) + k = self.to_k(k) + v = self.to_v(v) + + out = optimized_attention(q, k, v, self.heads) + + return self.out_proj(out) + +class Attention2D(nn.Module): + def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations) + # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device) + + def forward(self, x, kv, self_attn=False): + orig_shape = x.shape + x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 + if self_attn: + kv = torch.cat([x, kv], dim=1) + # x = self.attn(x, kv, kv, need_weights=False)[0] + x = self.attn(x, kv, kv) + x = x.permute(0, 2, 1).view(*orig_shape) + return x + + +def LayerNorm2d_op(operations): + class LayerNorm2d(operations.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return LayerNorm2d + +class GlobalResponseNorm(nn.Module): + "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" + def __init__(self, dim, dtype=None, device=None): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +class ResBlock(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None): # , num_heads=4, expansion=2): + super().__init__() + self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device) + # self.depthwise = SAMBlock(c, num_heads, expansion) + self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.channelwise = nn.Sequential( + operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device), + nn.GELU(), + GlobalResponseNorm(c * 4, dtype=dtype, device=device), + nn.Dropout(dropout), + operations.Linear(c * 4, c, dtype=dtype, device=device) + ) + + 
def forward(self, x, x_skip=None): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res + + +class AttnBlock(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations) + self.kv_mapper = nn.Sequential( + nn.SiLU(), + operations.Linear(c_cond, c, dtype=dtype, device=device) + ) + + def forward(self, x, kv): + kv = self.kv_mapper(kv) + x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) + return x + + +class FeedForwardBlock(nn.Module): + def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.channelwise = nn.Sequential( + operations.Linear(c, c * 4, dtype=dtype, device=device), + nn.GELU(), + GlobalResponseNorm(c * 4, dtype=dtype, device=device), + nn.Dropout(dropout), + operations.Linear(c * 4, c, dtype=dtype, device=device) + ) + + def forward(self, x): + x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + + +class TimestepBlock(nn.Module): + def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None): + super().__init__() + self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device) + self.conds = conds + for cname in conds: + setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)) + + def forward(self, x, t): + t = t.chunk(len(self.conds) + 1, dim=1) + a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) + for i, c in enumerate(self.conds): + ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) + a, b = a + ac, b + bc + return x * (1 + a) + b diff --git a/comfy/ldm/cascade/stage_c.py b/comfy/ldm/cascade/stage_c.py new file mode 100644 index 00000000000..2e0f47068b4 --- /dev/null +++ b/comfy/ldm/cascade/stage_c.py @@ -0,0 +1,271 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . 
+""" + +import torch +from torch import nn +import numpy as np +import math +from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock +# from .controlnet import ControlNetDeliverer + +class UpDownBlock2d(nn.Module): + def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None): + super().__init__() + assert mode in ['up', 'down'] + interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear', + align_corners=True) if enabled else nn.Identity() + mapping = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device) + self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation]) + + def forward(self, x): + for block in self.blocks: + x = block(x) + return x + + +class StageC(nn.Module): + def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32], + blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'], + c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3, + dropout=[0.1, 0.1], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None, + dtype=None, device=None, operations=None): + super().__init__() + self.dtype = dtype + self.c_r = c_r + self.t_conds = t_conds + self.c_clip_seq = c_clip_seq + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + if not isinstance(self_attn, list): + self_attn = [self_attn] * len(c_hidden) + + # CONDITIONING + self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device) + self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device) + self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device) + self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6) + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == 'C': + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'A': + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'F': + return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'T': + return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations) + else: + raise Exception(f'Block type {block_type} not supported') + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(c_hidden)): + if i > 0: + self.down_downscalers.append(nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations) + )) + else: + self.down_downscalers.append(nn.Identity()) + down_block = nn.ModuleList() + for _ in 
range(blocks[0][i]): + for block_type in level_config[i]: + block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) + down_block.append(block) + self.down_blocks.append(down_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[0][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + if i > 0: + self.up_upscalers.append(nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations) + )) + else: + self.up_upscalers.append(nn.Identity()) + up_block = nn.ModuleList() + for j in range(blocks[1][::-1][i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], + self_attn=self_attn[i]) + up_block.append(block) + self.up_blocks.append(up_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[1][::-1][i] - 1): + block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6), + operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + # self.apply(self._init_weights) # General init + # nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings + # nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings + # nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings + # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + # nn.init.constant_(self.clf[1].weight, 0) # outputs + # + # # blocks + # for level_block in self.down_blocks + self.up_blocks: + # for block in level_block: + # if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + # block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + # elif isinstance(block, TimestepBlock): + # for layer in block.modules(): + # if isinstance(layer, nn.Linear): + # nn.init.constant_(layer.weight, 0) + # + # def _init_weights(self, m): + # if isinstance(m, (nn.Conv2d, nn.Linear)): + # torch.nn.init.xavier_uniform_(m.weight) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode='constant') + return emb + + def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img): + clip_txt = self.clip_txt_mapper(clip_txt) + if len(clip_txt_pooled.shape) == 2: + clip_txt_pooled = clip_txt_pooled.unsqueeze(1) + if len(clip_img.shape) == 2: + clip_img = 
clip_img.unsqueeze(1) + clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1) + clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1) + clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) + clip = self.clip_norm(clip) + return clip + + def _down_encode(self, x, r_embed, clip, cnet=None): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + if cnet is not None: + next_cnet = cnet() + if next_cnet is not None: + x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip, cnet=None): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear', + align_corners=True) + if cnet is not None: + next_cnet = cnet() + if next_cnet is not None: + x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x, skip) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs): + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r).to(dtype=x.dtype) + for c in self.t_conds: + t_cond = kwargs.get(c, torch.zeros_like(r)) + r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1) + clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img) + + # Model Blocks + x = self.embedding(x) + if cnet is not None: + cnet = ControlNetDeliverer(cnet) + level_outputs = self._down_encode(x, r_embed, clip, cnet) + x = self._up_decode(level_outputs, r_embed, clip, cnet) + return self.clf(x) + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + 
self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) + for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) diff --git a/comfy/model_base.py b/comfy/model_base.py index aafb88e05c6..dde2a2f36e3 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1,5 +1,6 @@ import torch from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep +from comfy.ldm.cascade.stage_c import StageC from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation import comfy.model_management @@ -12,9 +13,10 @@ class ModelType(Enum): EPS = 1 V_PREDICTION = 2 V_PREDICTION_EDM = 3 + STABLE_CASCADE = 4 -from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM +from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling def model_sampling(model_config, model_type): @@ -27,6 +29,9 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.V_PREDICTION_EDM: c = V_PREDICTION s = ModelSamplingContinuousEDM + elif model_type == ModelType.STABLE_CASCADE: + c = EPS + s = StableCascadeSampling class ModelSampling(s, c): pass @@ -35,7 +40,7 @@ class ModelSampling(s, c): class BaseModel(torch.nn.Module): - def __init__(self, model_config, model_type=ModelType.EPS, device=None): + def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel): super().__init__() unet_config = model_config.unet_config @@ -48,7 +53,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None): operations = comfy.ops.manual_cast else: operations = comfy.ops.disable_weight_init - self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations) + self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) self.model_type = model_type self.model_sampling = model_sampling(model_config, model_type) @@ -427,3 +432,32 @@ def extra_conds(self, **kwargs): out['c_concat'] = comfy.conds.CONDNoiseShape(image) out['y'] = comfy.conds.CONDRegular(noise_level) return out + +class StableCascade_C(BaseModel): + def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): + super().__init__(model_config, model_type, device=device, unet_model=StageC) + self.diffusion_model.eval().requires_grad_(False) + + def extra_conds(self, **kwargs): + out = {} + clip_text_pooled = kwargs["pooled_output"] + if clip_text_pooled is not None: + out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) + + if "unclip_conditioning" in kwargs: + embeds = [] + for unclip_cond in kwargs["unclip_conditioning"]: + weight = unclip_cond["strength"] + embeds.append(unclip_cond["clip_vision_output"].image_embeds.unsqueeze(0) * weight) + clip_img = torch.cat(embeds, dim=1) + else: + clip_img = torch.zeros((1, 1, 768)) + out["clip_img"] = comfy.conds.CONDRegular(clip_img) + out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,))) + out["crp"] = comfy.conds.CONDRegular(torch.zeros((1,))) + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn) + return out + diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 
ea824c44ca1..8d4fb7b66f5 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -28,9 +28,26 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict): return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack return None -def detect_unet_config(state_dict, key_prefix, dtype): +def detect_unet_config(state_dict, key_prefix): state_dict_keys = list(state_dict.keys()) + if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade + unet_config = {} + text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix) + if text_mapper_name in state_dict_keys: + unet_config['stable_cascade_stage'] = 'c' + w = state_dict[text_mapper_name] + if w.shape[0] == 1536: #stage c lite + unet_config['c_cond'] = 1536 + unet_config['c_hidden'] = [1536, 1536] + unet_config['nhead'] = [24, 24] + unet_config['blocks'] = [[4, 12], [12, 4]] + elif w.shape[0] == 2048: #stage c full + unet_config['c_cond'] = 2048 + elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys: + unet_config['stable_cascade_stage'] = 'b' + return unet_config + unet_config = { "use_checkpoint": False, "image_size": 32, @@ -45,7 +62,6 @@ def detect_unet_config(state_dict, key_prefix, dtype): else: unet_config["adm_in_channels"] = None - unet_config["dtype"] = dtype model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0] in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1] @@ -159,8 +175,8 @@ def model_config_from_unet_config(unet_config): print("no match", unet_config) return None -def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False): - unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype) +def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False): + unet_config = detect_unet_config(state_dict, unet_key_prefix) model_config = model_config_from_unet_config(unet_config) if model_config is None and use_base_if_no_match: return comfy.supported_models_base.BASE(unet_config) @@ -206,7 +222,7 @@ def convert_config(unet_config): return new_config -def unet_config_from_diffusers_unet(state_dict, dtype): +def unet_config_from_diffusers_unet(state_dict, dtype=None): match = {} transformer_depth = [] @@ -313,8 +329,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype): return convert_config(unet_config) return None -def model_config_from_diffusers_unet(state_dict, dtype): - unet_config = unet_config_from_diffusers_unet(state_dict, dtype) +def model_config_from_diffusers_unet(state_dict): + unet_config = unet_config_from_diffusers_unet(state_dict) if unet_config is not None: return model_config_from_unet_config(unet_config) return None diff --git a/comfy/model_management.py b/comfy/model_management.py index 0b3f6ead634..eb7178b4442 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -487,7 +487,7 @@ def unet_inital_load_device(parameters, dtype): else: return cpu_dev -def unet_dtype(device=None, model_params=0): +def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): if args.bf16_unet: return torch.bfloat16 if args.fp16_unet: @@ -497,20 +497,31 @@ def unet_dtype(device=None, model_params=0): if args.fp8_e5m2_unet: return torch.float8_e5m2 if should_use_fp16(device=device, model_params=model_params, manual_cast=True): - return torch.float16 + if torch.float16 in supported_dtypes: + return torch.float16 + if should_use_bf16(device): + if torch.bfloat16 in 
supported_dtypes: + return torch.bfloat16 return torch.float32 # None means no manual cast -def unet_manual_cast(weight_dtype, inference_device): +def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): if weight_dtype == torch.float32: return None - fp16_supported = comfy.model_management.should_use_fp16(inference_device, prioritize_performance=False) + fp16_supported = should_use_fp16(inference_device, prioritize_performance=False) if fp16_supported and weight_dtype == torch.float16: return None - if fp16_supported: + bf16_supported = should_use_bf16(inference_device) + if bf16_supported and weight_dtype == torch.bfloat16: + return None + + if fp16_supported and torch.float16 in supported_dtypes: return torch.float16 + + elif bf16_supported and torch.bfloat16 in supported_dtypes: + return torch.bfloat16 else: return torch.float32 @@ -760,6 +771,19 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma return True +def should_use_bf16(device=None): + if is_intel_xpu(): + return True + + if device is None: + device = torch.device("cuda") + + props = torch.cuda.get_device_properties(device) + if props.major >= 8: + return True + + return False + def soft_empty_cache(force=False): global cpu_state if cpu_state == CPUState.MPS: diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index d5870027b9b..b1fbf3e2113 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -132,3 +132,33 @@ def percent_to_sigma(self, percent): log_sigma_min = math.log(self.sigma_min) return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min) + +class StableCascadeSampling(ModelSamplingDiscrete): + def __init__(self, model_config=None): + super().__init__() + self.num_timesteps = 1000 + cosine_s=8e-3 + self.cosine_s = torch.tensor([cosine_s]) + sigmas = torch.empty((self.num_timesteps), dtype=torch.float32) + self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 + for x in range(self.num_timesteps): + t = x / self.num_timesteps + sigmas[x] = self.sigma(t) + + self.set_sigmas(sigmas) + + def sigma(self, timestep): + alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod).clamp(0.0001, 0.9999) + return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5 + + def timestep(self, sigma): + return super().timestep(sigma) / 1000.0 + + def percent_to_sigma(self, percent): + if percent <= 0.0: + return 999999999.9 + if percent >= 1.0: + return 0.0 + + percent = 1.0 - percent + return self.sigma(torch.tensor(percent)) diff --git a/comfy/sd.py b/comfy/sd.py index 5e37cff915a..f3ec62b3a36 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -450,15 +450,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o clip_target = None parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.") - unet_dtype = model_management.unet_dtype(model_params=parameters) load_device = model_management.get_torch_device() - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) class WeightsLoader(torch.nn.Module): pass - model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype) - model_config.set_manual_cast(manual_cast_dtype) + model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.") + unet_dtype = model_management.unet_dtype(model_params=parameters, 
supported_dtypes=model_config.supported_inference_dtypes) + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) if model_config is None: raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) @@ -507,16 +507,15 @@ def load_unet_state_dict(sd): #load unet in diffusers format parameters = comfy.utils.calculate_parameters(sd) unet_dtype = model_management.unet_dtype(model_params=parameters) load_device = model_management.get_torch_device() - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) - if "input_blocks.0.0.weight" in sd: #ldm - model_config = model_detection.model_config_from_unet(sd, "", unet_dtype) + if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade + model_config = model_detection.model_config_from_unet(sd, "") if model_config is None: return None new_sd = sd else: #diffusers - model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype) + model_config = model_detection.model_config_from_diffusers_unet(sd) if model_config is None: return None @@ -528,8 +527,11 @@ def load_unet_state_dict(sd): #load unet in diffusers format new_sd[diffusers_keys[k]] = sd.pop(k) else: print(diffusers_keys[k], k) + offload_device = model_management.unet_offload_device() - model_config.set_manual_cast(manual_cast_dtype) + unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes) + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model = model_config.get_model(new_sd, "") model = model.to(offload_device) model.load_model_weights(new_sd, "") diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1d442d4dd9c..a8863e72b0b 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -306,5 +306,38 @@ def get_model(self, state_dict, prefix="", device=None): out = model_base.SD_X4Upscaler(self, device=device) return out -models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler] +class Stable_Cascade_C(supported_models_base.BASE): + unet_config = { + "stable_cascade_stage": 'c', + } + + unet_extra_config = {} + + latent_format = latent_formats.SC_Prior + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + def process_unet_state_dict(self, state_dict): + key_list = list(state_dict.keys()) + for y in ["weight", "bias"]: + suffix = "in_proj_{}".format(y) + keys = filter(lambda a: a.endswith(suffix), key_list) + for k_from in keys: + weights = state_dict.pop(k_from) + prefix = k_from[:-(len(suffix) + 1)] + shape_from = weights.shape[0] // 3 + for x in range(3): + p = ["to_q", "to_k", "to_v"] + k_to = "{}.{}.{}".format(prefix, p[x], y) + state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)] + return state_dict + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.StableCascade_C(self, device=device) + return out + + def clip_target(self): + return None + + +models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C] models += [SVD_img2vid] diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 58535a9fbf8..3bd4f9c6523 100644 --- 
a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -22,13 +22,14 @@ class BASE: sampling_settings = {} latent_format = latent_formats.LatentFormat vae_key_prefix = ["first_stage_model."] + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] manual_cast_dtype = None @classmethod def matches(s, unet_config): for k in s.unet_config: - if s.unet_config[k] != unet_config[k]: + if k not in unet_config or s.unet_config[k] != unet_config[k]: return False return True @@ -80,5 +81,6 @@ def process_vae_state_dict_for_saving(self, state_dict): replace_prefix = {"": "first_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) - def set_manual_cast(self, manual_cast_dtype): + def set_inference_dtype(self, dtype, manual_cast_dtype): + self.unet_config['dtype'] = dtype self.manual_cast_dtype = manual_cast_dtype From 667c92814e46e60eba82b86bb23b664f3157c9b9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 16 Feb 2024 12:56:11 -0500 Subject: [PATCH 36/64] Stable Cascade Stage B. --- comfy/latent_formats.py | 4 +- comfy/ldm/cascade/common.py | 2 +- comfy/ldm/cascade/stage_b.py | 257 +++++++++++++++++++++++++++ comfy/ldm/cascade/stage_c.py | 8 +- comfy/model_base.py | 25 +++ comfy/ops.py | 49 ++++- comfy/supported_models.py | 16 +- comfy/utils.py | 2 + comfy_extras/nodes_stable_cascade.py | 74 ++++++++ nodes.py | 1 + 10 files changed, 430 insertions(+), 8 deletions(-) create mode 100644 comfy/ldm/cascade/stage_b.py create mode 100644 comfy_extras/nodes_stable_cascade.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 8ba767372b6..68fd73d0b5d 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -42,4 +42,6 @@ class SC_Prior(LatentFormat): def __init__(self): self.scale_factor = 1.0 - +class SC_B(LatentFormat): + def __init__(self): + self.scale_factor = 1.0 diff --git a/comfy/ldm/cascade/common.py b/comfy/ldm/cascade/common.py index 3c5bf99771e..c2ef3ec4b94 100644 --- a/comfy/ldm/cascade/common.py +++ b/comfy/ldm/cascade/common.py @@ -84,7 +84,7 @@ def __init__(self, dim, dtype=None, device=None): def forward(self, x): Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) - return self.gamma * (x * Nx) + self.beta + x + return self.gamma.to(x.device) * (x * Nx) + self.beta.to(x.device) + x class ResBlock(nn.Module): diff --git a/comfy/ldm/cascade/stage_b.py b/comfy/ldm/cascade/stage_b.py new file mode 100644 index 00000000000..6d2c2223143 --- /dev/null +++ b/comfy/ldm/cascade/stage_b.py @@ -0,0 +1,257 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . 
+""" + +import math +import numpy as np +import torch +from torch import nn +from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock + +class StageB(nn.Module): + def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280], + nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], + block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280, + c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.0, 0.0], self_attn=True, + t_conds=['sca'], stable_cascade_stage=None, dtype=None, device=None, operations=None): + super().__init__() + self.dtype = dtype + self.c_r = c_r + self.t_conds = t_conds + self.c_clip_seq = c_clip_seq + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + if not isinstance(self_attn, list): + self_attn = [self_attn] * len(c_hidden) + + # CONDITIONING + self.effnet_mapper = nn.Sequential( + operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device), + nn.GELU(), + operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + ) + self.pixels_mapper = nn.Sequential( + operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device), + nn.GELU(), + operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + ) + self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device) + self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == 'C': + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'A': + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'F': + return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'T': + return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations) + else: + raise Exception(f'Block type {block_type} not supported') + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(c_hidden)): + if i > 0: + self.down_downscalers.append(nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device), + operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device), + )) + else: + self.down_downscalers.append(nn.Identity()) + down_block = nn.ModuleList() + for _ in range(blocks[0][i]): + for block_type in level_config[i]: + block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], 
self_attn=self_attn[i]) + down_block.append(block) + self.down_blocks.append(down_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[0][i] - 1): + block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + if i > 0: + self.up_upscalers.append(nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device), + operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device), + )) + else: + self.up_upscalers.append(nn.Identity()) + up_block = nn.ModuleList() + for j in range(blocks[1][::-1][i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], + self_attn=self_attn[i]) + up_block.append(block) + self.up_blocks.append(up_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[1][::-1][i] - 1): + block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device), + operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + # self.apply(self._init_weights) # General init + # nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings + # nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings + # nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings + # nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings + # nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings + # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + # nn.init.constant_(self.clf[1].weight, 0) # outputs + # + # # blocks + # for level_block in self.down_blocks + self.up_blocks: + # for block in level_block: + # if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + # block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + # elif isinstance(block, TimestepBlock): + # for layer in block.modules(): + # if isinstance(layer, nn.Linear): + # nn.init.constant_(layer.weight, 0) + # + # def _init_weights(self, m): + # if isinstance(m, (nn.Conv2d, nn.Linear)): + # torch.nn.init.xavier_uniform_(m.weight) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode='constant') + return emb + + def gen_c_embeddings(self, clip): + if len(clip.shape) == 2: + clip = clip.unsqueeze(1) + clip = self.clip_mapper(clip).view(clip.size(0), 
clip.size(1) * self.c_clip_seq, -1) + clip = self.clip_norm(clip) + return clip + + def _down_encode(self, x, r_embed, clip): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + x = block(x) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x, skip) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def forward(self, x, r, effnet, clip, pixels=None, **kwargs): + if pixels is None: + pixels = x.new_zeros(x.size(0), 3, 8, 8) + + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r).to(dtype=x.dtype) + for c in self.t_conds: + t_cond = kwargs.get(c, torch.zeros_like(r)) + r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1) + clip = self.gen_c_embeddings(clip) + + # Model Blocks + x = self.embedding(x) + x = x + self.effnet_mapper( + nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True)) + x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear', + align_corners=True) + level_outputs = self._down_encode(x, r_embed, clip) + x = self._up_decode(level_outputs, r_embed, clip) + return self.clf(x) + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) + for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) diff --git a/comfy/ldm/cascade/stage_c.py b/comfy/ldm/cascade/stage_c.py index 2e0f47068b4..08e33aded22 100644 --- a/comfy/ldm/cascade/stage_c.py +++ b/comfy/ldm/cascade/stage_c.py @@ -42,7 
+42,7 @@ class StageC(nn.Module): def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32], blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'], c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3, - dropout=[0.1, 0.1], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None, + dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None, dtype=None, device=None, operations=None): super().__init__() self.dtype = dtype @@ -100,7 +100,7 @@ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): if block_repeat is not None: block_repeat_mappers = nn.ModuleList() for _ in range(block_repeat[0][i] - 1): - block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device)) self.down_repeat_mappers.append(block_repeat_mappers) # -- up blocks @@ -126,12 +126,12 @@ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): if block_repeat is not None: block_repeat_mappers = nn.ModuleList() for _ in range(block_repeat[1][::-1][i] - 1): - block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) + block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device)) self.up_repeat_mappers.append(block_repeat_mappers) # OUTPUT self.clf = nn.Sequential( - LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device), operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device), nn.PixelShuffle(patch_size), ) diff --git a/comfy/model_base.py b/comfy/model_base.py index dde2a2f36e3..fefce76378c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1,6 +1,7 @@ import torch from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.cascade.stage_c import StageC +from comfy.ldm.cascade.stage_b import StageB from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation import comfy.model_management @@ -461,3 +462,27 @@ def extra_conds(self, **kwargs): out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn) return out + +class StableCascade_B(BaseModel): + def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): + super().__init__(model_config, model_type, device=device, unet_model=StageB) + self.diffusion_model.eval().requires_grad_(False) + + def extra_conds(self, **kwargs): + out = {} + noise = kwargs.get("noise", None) + + clip_text_pooled = kwargs["pooled_output"] + if clip_text_pooled is not None: + out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) + + #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched + prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device)) + + out["effnet"] = comfy.conds.CONDRegular(prior) + out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,))) + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn 
is not None: + out['clip'] = comfy.conds.CONDCrossAttn(cross_attn) + return out diff --git a/comfy/ops.py b/comfy/ops.py index f674b47f762..517688e8b92 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1,3 +1,21 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + import torch import comfy.model_management @@ -78,7 +96,11 @@ def reset_parameters(self): return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) + if self.weight is not None: + weight, bias = cast_bias_weight(self, input) + else: + weight = None + bias = None return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) def forward(self, *args, **kwargs): @@ -87,6 +109,28 @@ def forward(self, *args, **kwargs): else: return super().forward(*args, **kwargs) + class ConvTranspose2d(torch.nn.ConvTranspose2d): + comfy_cast_weights = False + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input, output_size=None): + num_spatial_dims = 2 + output_padding = self._output_padding( + input, output_size, self.stride, self.padding, self.kernel_size, + num_spatial_dims, self.dilation) + + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.conv_transpose2d( + input, weight, bias, self.stride, self.padding, + output_padding, self.groups, self.dilation) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + @classmethod def conv_nd(s, dims, *args, **kwargs): if dims == 2: @@ -112,3 +156,6 @@ class GroupNorm(disable_weight_init.GroupNorm): class LayerNorm(disable_weight_init.LayerNorm): comfy_cast_weights = True + + class ConvTranspose2d(disable_weight_init.ConvTranspose2d): + comfy_cast_weights = True diff --git a/comfy/supported_models.py b/comfy/supported_models.py index a8863e72b0b..7859bac90c1 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -338,6 +338,20 @@ def get_model(self, state_dict, prefix="", device=None): def clip_target(self): return None +class Stable_Cascade_B(Stable_Cascade_C): + unet_config = { + "stable_cascade_stage": 'b', + } + + unet_extra_config = {} + + latent_format = latent_formats.SC_B + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.StableCascade_B(self, device=device) + return out + -models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C] +models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B] models += [SVD_img2vid] diff --git a/comfy/utils.py b/comfy/utils.py index 1113bf0f52f..04cf76ed678 100644 --- a/comfy/utils.py +++ 
b/comfy/utils.py @@ -169,6 +169,8 @@ def transformers_convert(sd, prefix_from, prefix_to, number): } def unet_to_diffusers(unet_config): + if "num_res_blocks" not in unet_config: + return {} num_res_blocks = unet_config["num_res_blocks"] channel_mult = unet_config["channel_mult"] transformer_depth = unet_config["transformer_depth"][:] diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py new file mode 100644 index 00000000000..5d31c1e59c7 --- /dev/null +++ b/comfy_extras/nodes_stable_cascade.py @@ -0,0 +1,74 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import torch +import nodes + + +class StableCascade_EmptyLatentImage: + def __init__(self, device="cpu"): + self.device = device + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}), + "compression": ("INT", {"default": 42, "min": 32, "max": 64, "step": 1}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}) + }} + RETURN_TYPES = ("LATENT", "LATENT") + RETURN_NAMES = ("stage_c", "stage_b") + FUNCTION = "generate" + + CATEGORY = "_for_testing/stable_cascade" + + def generate(self, width, height, compression, batch_size=1): + c_latent = torch.zeros([batch_size, 16, height // compression, width // compression]) + b_latent = torch.zeros([batch_size, 4, height // 4, width // 4]) + return ({ + "samples": c_latent, + }, { + "samples": b_latent, + }) + +class StableCascade_StageB_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "conditioning": ("CONDITIONING",), + "stage_c": ("LATENT",), + }} + RETURN_TYPES = ("CONDITIONING",) + + FUNCTION = "set_prior" + + CATEGORY = "_for_testing/stable_cascade" + + def set_prior(self, conditioning, stage_c): + c = [] + for t in conditioning: + d = t[1].copy() + d['stable_cascade_prior'] = stage_c['samples'] + n = [t[0], d] + c.append(n) + return (c, ) + +NODE_CLASS_MAPPINGS = { + "StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage, + "StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning, +} diff --git a/nodes.py b/nodes.py index d4f6f3ab5af..7b8f09dead9 100644 --- a/nodes.py +++ b/nodes.py @@ -1967,6 +1967,7 @@ def init_custom_nodes(): "nodes_sdupscale.py", "nodes_photomaker.py", "nodes_cond.py", + "nodes_stable_cascade.py", ] for node_file in extras_files: From 97d03ae04a228d9a51581106eb9f4e90009ac4f6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 16 Feb 2024 13:29:04 -0500 Subject: [PATCH 37/64] StableCascade CLIP model support. 
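
This routes text-encoder loading through a new CLIPType enum so the same
CLIP-G checkpoint layout can be loaded either as the SDXL refiner text
encoder or as the Stable Cascade variant (hidden-layer output, no layer
norm on the hidden state, attention masks enabled). A rough usage sketch
based on the signatures added below; the checkpoint path is hypothetical:

    import comfy.sd

    clip = comfy.sd.load_clip(
        ckpt_paths=["models/clip/clip_g.safetensors"],  # hypothetical path
        embedding_directory=None,
        clip_type=comfy.sd.CLIPType.STABLE_CASCADE,
    )
    cond = clip.encode("a portrait photo")  # tokenize + encode_from_tokens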
--- comfy/sd.py | 14 +++++++++++--- comfy/sd1_clip.py | 4 ++-- comfy/sdxl_clip.py | 22 ++++++++++++++++++++++ comfy/supported_models.py | 2 +- nodes.py | 9 +++++++-- 5 files changed, 43 insertions(+), 8 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index f3ec62b3a36..d8c0bfa7ca3 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,4 +1,5 @@ import torch +from enum import Enum from comfy import model_management from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine @@ -309,8 +310,11 @@ def load_style_model(ckpt_path): model.load_state_dict(model_data) return StyleModel(model) +class CLIPType(Enum): + STABLE_DIFFUSION = 1 + STABLE_CASCADE = 2 -def load_clip(ckpt_paths, embedding_directory=None): +def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION): clip_data = [] for p in ckpt_paths: clip_data.append(comfy.utils.load_torch_file(p, safe_load=True)) @@ -326,8 +330,12 @@ class EmptyClass: clip_target.params = {} if len(clip_data) == 1: if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]: - clip_target.clip = sdxl_clip.SDXLRefinerClipModel - clip_target.tokenizer = sdxl_clip.SDXLTokenizer + if clip_type == CLIPType.STABLE_CASCADE: + clip_target.clip = sdxl_clip.StableCascadeClipModel + clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer + else: + clip_target.clip = sdxl_clip.SDXLRefinerClipModel + clip_target.tokenizer = sdxl_clip.SDXLTokenizer elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]: clip_target.clip = sd2_clip.SD2ClipModel clip_target.tokenizer = sd2_clip.SD2Tokenizer diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 65ea909febc..8287ad2e8b8 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -67,7 +67,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): ] def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel, - special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32 + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False): # clip-vit-base-patch32 super().__init__() assert layer in self.LAYERS @@ -88,7 +88,7 @@ def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_le self.special_tokens = special_tokens self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1])) self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) - self.enable_attention_masks = False + self.enable_attention_masks = enable_attention_masks self.layer_norm_hidden_state = layer_norm_hidden_state if layer == "hidden": diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index b35056bb9d6..3ce5c7e05e6 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -64,3 +64,25 @@ def load_sd(self, sd): class SDXLRefinerClipModel(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None): super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG) + + +class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer): + def __init__(self, tokenizer_path=None, embedding_directory=None): + super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') + +class StableCascadeTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None): + 
super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer) + +class StableCascadeClipG(sd1_clip.SDClipModel): + def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None): + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") + super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True) + + def load_sd(self, sd): + return super().load_sd(sd) + +class StableCascadeClipModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None): + super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 7859bac90c1..3a317edcf8b 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -336,7 +336,7 @@ def get_model(self, state_dict, prefix="", device=None): return out def clip_target(self): - return None + return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel) class Stable_Cascade_B(Stable_Cascade_C): unet_config = { diff --git a/nodes.py b/nodes.py index 7b8f09dead9..47203f4176c 100644 --- a/nodes.py +++ b/nodes.py @@ -854,15 +854,20 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ), + "type": (["stable_diffusion", "stable_cascade"], ), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" CATEGORY = "advanced/loaders" - def load_clip(self, clip_name): + def load_clip(self, clip_name, type="stable_diffusion"): + clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION + if type == "stable_cascade": + clip_type = comfy.sd.CLIPType.STABLE_CASCADE + clip_path = folder_paths.get_full_path("clip", clip_name) - clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings")) + clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) return (clip,) class DualCLIPLoader: From 0b3c50480c8cfcfe22c2a4059f91cec337114a78 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 16 Feb 2024 23:01:54 -0500 Subject: [PATCH 38/64] Make --force-fp32 disable loading models in bf16. --- comfy/model_management.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index eb7178b4442..f0f4ebf58ab 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -772,6 +772,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma return True def should_use_bf16(device=None): + if FORCE_FP32: + return False + if is_intel_xpu(): return True From f2d1d16f4f7c4221af3d6d121817e4bd8f6ddc88 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 16 Feb 2024 23:41:23 -0500 Subject: [PATCH 39/64] Support Stable Cascade Stage B lite. 
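
The full and lite Stage B checkpoints are told apart by the width of a
single channelwise projection weight. A minimal sketch of the dispatch
added to detect_unet_config (state_dict and key_prefix as used there):

    w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)]
    if w.shape[-1] == 640:    # full Stage B
        c_hidden = [320, 640, 1280, 1280]
    elif w.shape[-1] == 576:  # Stage B lite
        c_hidden = [320, 576, 1152, 1152]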
--- comfy/model_detection.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8d4fb7b66f5..8fca6d8c8e4 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -46,6 +46,18 @@ def detect_unet_config(state_dict, key_prefix): unet_config['c_cond'] = 2048 elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys: unet_config['stable_cascade_stage'] = 'b' + w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)] + if w.shape[-1] == 640: + unet_config['c_hidden'] = [320, 640, 1280, 1280] + unet_config['nhead'] = [-1, -1, 20, 20] + unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]] + unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]] + elif w.shape[-1] == 576: #stage b lite + unet_config['c_hidden'] = [320, 576, 1152, 1152] + unet_config['nhead'] = [-1, 9, 18, 18] + unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]] + unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]] + return unet_config unet_config = { From 805c36ac9c0611a483b1f494e2dfe6b67f09fe36 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 17 Feb 2024 00:42:22 -0500 Subject: [PATCH 40/64] Make Stable Cascade work on old pytorch 2.0 --- comfy/ldm/cascade/stage_a.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py index 55fdbf17dfc..260ccfc0b5d 100644 --- a/comfy/ldm/cascade/stage_a.py +++ b/comfy/ldm/cascade/stage_a.py @@ -150,7 +150,11 @@ def forward(self, x): mods = self.gammas x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] - x = x + self.depthwise(x_temp) * mods[2] + try: + x = x + self.depthwise(x_temp) * mods[2] + except: #operation not implemented for bf16 + x_temp = self.depthwise[0](x_temp.float()).to(x.dtype) + x = x + self.depthwise[1](x_temp) * mods[2] x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] From 6c875d846b54ee34eb032fdb790a2fd1621d18f4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 17 Feb 2024 07:46:17 -0500 Subject: [PATCH 41/64] Fix clip attention mask issues on some hardware. --- comfy/clip_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 09e7bbca152..9ba4e039033 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -97,7 +97,7 @@ def forward(self, input_tokens, attention_mask=None, intermediate_output=None, f x = self.embeddings(input_tokens) mask = None if attention_mask is not None: - mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], attention_mask.shape[-1], attention_mask.shape[-1]) mask = mask.masked_fill(mask.to(torch.bool), float("-inf")) causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) From 929e266f3e298478a1433fcff8b0209e52790068 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 17 Feb 2024 08:13:17 -0500 Subject: [PATCH 42/64] Manual cast for bf16 on older GPUs. 
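
should_use_bf16() now takes the same model_params / prioritize_performance /
manual_cast arguments as should_use_fp16(), and unet_dtype() passes
manual_cast=True, so bf16 storage can be chosen even where native bf16
compute is unavailable. Roughly (a sketch, not a new API):

    dtype = unet_dtype(device, model_params=params,
                       supported_dtypes=[torch.bfloat16, torch.float32])
    # may now return torch.bfloat16 on older GPUs, with the compute dtype
    # handled by the manual-cast operations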
--- comfy/model_management.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index f0f4ebf58ab..681208ea091 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -499,7 +499,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor if should_use_fp16(device=device, model_params=model_params, manual_cast=True): if torch.float16 in supported_dtypes: return torch.float16 - if should_use_bf16(device): + if should_use_bf16(device, model_params=model_params, manual_cast=True): if torch.bfloat16 in supported_dtypes: return torch.bfloat16 return torch.float32 @@ -771,10 +771,24 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma return True -def should_use_bf16(device=None): +def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): + if device is not None: + if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow + return False + + if device is not None: #TODO not sure about mps bf16 support + if is_device_mps(device): + return False + if FORCE_FP32: return False + if directml_enabled: + return False + + if cpu_mode() or mps_mode(): + return False + if is_intel_xpu(): return True @@ -785,6 +799,13 @@ def should_use_bf16(device=None): if props.major >= 8: return True + bf16_works = torch.cuda.is_bf16_supported() + + if bf16_works or manual_cast: + free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory()) + if (not prioritize_performance) or model_params * 4 > free_model_memory: + return True + return False def soft_empty_cache(force=False): From 5b40e7a5ed192e217575c55e061c17a52cf9a15d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 17 Feb 2024 11:38:47 -0500 Subject: [PATCH 43/64] Implement shift schedule for cascade stage C. 
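
The shift is applied in logSNR space: the cosine-schedule alpha_cumprod is
mapped to logSNR, offset by 2*ln(1/shift), and mapped back through a
sigmoid. A self-contained sketch of the transform implemented below:

    import torch

    def shifted_alpha_cumprod(alpha_cumprod, shift):
        # same transform as StableCascadeSampling.sigma() when shift != 1.0
        logsnr = (alpha_cumprod / (1 - alpha_cumprod)).log()
        logsnr += 2 * torch.log(1.0 / torch.tensor(shift))
        return logsnr.sigmoid()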
--- comfy/model_sampling.py | 25 ++++++++++++++++++++++--- comfy/supported_models.py | 8 ++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index b1fbf3e2113..f42f3015fab 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -136,9 +136,16 @@ def percent_to_sigma(self, percent): class StableCascadeSampling(ModelSamplingDiscrete): def __init__(self, model_config=None): super().__init__() + + if model_config is not None: + sampling_settings = model_config.sampling_settings + else: + sampling_settings = {} + self.num_timesteps = 1000 + self.shift = sampling_settings.get("shift", 1.0) cosine_s=8e-3 - self.cosine_s = torch.tensor([cosine_s]) + self.cosine_s = torch.tensor(cosine_s) sigmas = torch.empty((self.num_timesteps), dtype=torch.float32) self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 for x in range(self.num_timesteps): @@ -148,11 +155,23 @@ def __init__(self, model_config=None): self.set_sigmas(sigmas) def sigma(self, timestep): - alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod).clamp(0.0001, 0.9999) + alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod) + + if self.shift != 1.0: + var = alpha_cumprod + logSNR = (var/(1-var)).log() + logSNR += 2 * torch.log(1.0 / torch.tensor(self.shift)) + alpha_cumprod = logSNR.sigmoid() + + alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999) return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5 def timestep(self, sigma): - return super().timestep(sigma) / 1000.0 + var = 1 / ((sigma * sigma) + 1) + var = var.clamp(0, 1.0) + s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device) + t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s + return t def percent_to_sigma(self, percent): if percent <= 0.0: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 3a317edcf8b..1a673646ee5 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -316,6 +316,10 @@ class Stable_Cascade_C(supported_models_base.BASE): latent_format = latent_formats.SC_Prior supported_inference_dtypes = [torch.bfloat16, torch.float32] + sampling_settings = { + "shift": 2.0, + } + def process_unet_state_dict(self, state_dict): key_list = list(state_dict.keys()) for y in ["weight", "bias"]: @@ -348,6 +352,10 @@ class Stable_Cascade_B(Stable_Cascade_C): latent_format = latent_formats.SC_B supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + sampling_settings = { + "shift": 1.0, + } + def get_model(self, state_dict, prefix="", device=None): out = model_base.StableCascade_B(self, device=device) return out From 3b9969c1c5a0428fc1d8be79129a9e97cfcc5e7d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 17 Feb 2024 12:13:13 -0500 Subject: [PATCH 44/64] Properly fix attention masks in CLIP with batches. 
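
The CLIP attention mask is now built as an additive (batch, 1, seq, seq)
tensor so it broadcasts across heads for every batch item. A sketch of the
shape transform, assuming a 0/1 padding mask of shape (batch, seq):

    b, t = attention_mask.shape
    mask = 1.0 - attention_mask.float().reshape(b, 1, -1, t).expand(b, 1, t, t)
    mask = mask.masked_fill(mask.bool(), float("-inf"))  # 0 keep, -inf drop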
--- comfy/clip_model.py | 2 +- comfy/ldm/modules/attention.py | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 9ba4e039033..9b82a246b2c 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -97,7 +97,7 @@ def forward(self, input_tokens, attention_mask=None, intermediate_output=None, f x = self.embeddings(input_tokens) mask = None if attention_mask is not None: - mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], attention_mask.shape[-1], attention_mask.shape[-1]) + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) mask = mask.masked_fill(mask.to(torch.bool), float("-inf")) causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 9c9cb761dd7..bb539def6e2 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -114,7 +114,8 @@ def attention_basic(q, k, v, heads, mask=None): mask = repeat(mask, 'b j -> (b h) () j', h=h) sim.masked_fill_(~mask, max_neg_value) else: - sim += mask + mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(sim.shape) + sim.add_(mask) # attention, what we cannot get enough of sim = sim.softmax(dim=-1) @@ -165,6 +166,9 @@ def attention_sub_quad(query, key, value, heads, mask=None): if query_chunk_size is None: query_chunk_size = 512 + if mask is not None: + mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) + hidden_states = efficient_dot_product_attention( query, key, @@ -223,6 +227,9 @@ def attention_split(q, k, v, heads, mask=None): raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') + if mask is not None: + mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) + # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size) first_op_done = False cleared_cache = False From f8706546f3842fdc160c7ab831c2100701d5456e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 17 Feb 2024 15:22:21 -0500 Subject: [PATCH 45/64] Fix attention mask batch size in some attention functions. 
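
Masks may arrive as 2-D (seq, seq), shared across the whole batch, or with
a leading batch dimension; the attention helpers now pick the batch size
accordingly before expanding per head. The shared normalization, as a
sketch of the code below:

    if len(mask.shape) == 2:
        bs = 1                    # one mask shared by every batch item
    else:
        bs = mask.shape[0]
    mask = (mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1])
                .expand(-1, heads, -1, -1)
                .reshape(-1, mask.shape[-2], mask.shape[-1]))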
--- comfy/ldm/modules/attention.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index bb539def6e2..f1dca2c2823 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -114,7 +114,11 @@ def attention_basic(q, k, v, heads, mask=None): mask = repeat(mask, 'b j -> (b h) () j', h=h) sim.masked_fill_(~mask, max_neg_value) else: - mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(sim.shape) + if len(mask.shape) == 2: + bs = 1 + else: + bs = mask.shape[0] + mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) sim.add_(mask) # attention, what we cannot get enough of @@ -167,7 +171,11 @@ def attention_sub_quad(query, key, value, heads, mask=None): query_chunk_size = 512 if mask is not None: - mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) + if len(mask.shape) == 2: + bs = 1 + else: + bs = mask.shape[0] + mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) hidden_states = efficient_dot_product_attention( query, @@ -228,7 +236,11 @@ def attention_split(q, k, v, heads, mask=None): f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') if mask is not None: - mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) + if len(mask.shape) == 2: + bs = 1 + else: + bs = mask.shape[0] + mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size) first_op_done = False From 11e3221f1fd05d49b261ccec7dd99b704a86a89f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 17 Feb 2024 15:27:31 -0500 Subject: [PATCH 46/64] fp8 weight support for Stable Cascade. --- comfy/ldm/cascade/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/cascade/common.py b/comfy/ldm/cascade/common.py index c2ef3ec4b94..124902c09a4 100644 --- a/comfy/ldm/cascade/common.py +++ b/comfy/ldm/cascade/common.py @@ -84,7 +84,7 @@ def __init__(self, dim, dtype=None, device=None): def forward(self, x): Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) - return self.gamma.to(x.device) * (x * Nx) + self.beta.to(x.device) + x + return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x class ResBlock(nn.Module): From 6bcf57ff10f4488f7dfc6e3c47ac516967567c22 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 17 Feb 2024 16:15:18 -0500 Subject: [PATCH 47/64] Fix attention masks properly for multiple batches. 
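
When a shared mask (bs == 1 after the previous fix) meets a real batch of
size b, expand(-1, heads, ...) left the batch dimension at 1, so the final
reshape produced too few rows. Expanding to the attention batch fixes it:

    # before: .expand(-1, heads, -1, -1)  -> shape (1, heads, t, t)
    # after:  .expand(b,  heads, -1, -1)  -> shape (b, heads, t, t)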
--- comfy/ldm/modules/attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index f1dca2c2823..48399bc07e3 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -118,7 +118,7 @@ def attention_basic(q, k, v, heads, mask=None): bs = 1 else: bs = mask.shape[0] - mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) + mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) sim.add_(mask) # attention, what we cannot get enough of @@ -175,7 +175,7 @@ def attention_sub_quad(query, key, value, heads, mask=None): bs = 1 else: bs = mask.shape[0] - mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) + mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) hidden_states = efficient_dot_product_attention( query, @@ -240,7 +240,7 @@ def attention_split(q, k, v, heads, mask=None): bs = 1 else: bs = mask.shape[0] - mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) + mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size) first_op_done = False From 8b60d33bb7ce969a53fc5e25bfa0e2dca7a17b23 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 18 Feb 2024 00:55:23 -0500 Subject: [PATCH 48/64] Add ModelSamplingStableCascade to control the shift sampling parameter. shift is 2.0 by default on Stage C and 1.0 by default on Stage B. 
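
A sketch of what the node does to a model patcher, mirroring the patch()
method added below (the shift value is an example):

    import comfy.model_sampling

    class ModelSamplingAdvanced(comfy.model_sampling.StableCascadeSampling,
                                comfy.model_sampling.EPS):
        pass

    m = model.clone()
    model_sampling = ModelSamplingAdvanced(model.model.model_config)
    model_sampling.set_parameters(shift=2.0)  # example value
    m.add_object_patch("model_sampling", model_sampling)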
--- comfy/model_sampling.py | 12 ++++++++---- comfy_extras/nodes_model_advanced.py | 27 +++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index f42f3015fab..ae42d81f200 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -142,12 +142,16 @@ def __init__(self, model_config=None): else: sampling_settings = {} - self.num_timesteps = 1000 - self.shift = sampling_settings.get("shift", 1.0) - cosine_s=8e-3 + self.set_parameters(sampling_settings.get("shift", 1.0)) + + def set_parameters(self, shift=1.0, cosine_s=8e-3): + self.shift = shift self.cosine_s = torch.tensor(cosine_s) - sigmas = torch.empty((self.num_timesteps), dtype=torch.float32) self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 + + #This part is just for compatibility with some schedulers in the codebase + self.num_timesteps = 1000 + sigmas = torch.empty((self.num_timesteps), dtype=torch.float32) for x in range(self.num_timesteps): t = x / self.num_timesteps sigmas[x] = self.sigma(t) diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 541ce8fa5cc..ac7c1c17a16 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -99,6 +99,32 @@ class ModelSamplingAdvanced(sampling_base, sampling_type): m.add_object_patch("model_sampling", model_sampling) return (m, ) +class ModelSamplingStableCascade: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "shift": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step":0.01}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, shift): + m = model.clone() + + sampling_base = comfy.model_sampling.StableCascadeSampling + sampling_type = comfy.model_sampling.EPS + + class ModelSamplingAdvanced(sampling_base, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced(model.model.model_config) + model_sampling.set_parameters(shift) + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + class ModelSamplingContinuousEDM: @classmethod def INPUT_TYPES(s): @@ -171,5 +197,6 @@ def rescale_cfg(args): NODE_CLASS_MAPPINGS = { "ModelSamplingDiscrete": ModelSamplingDiscrete, "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM, + "ModelSamplingStableCascade": ModelSamplingStableCascade, "RescaleCFG": RescaleCFG, } From dccca1daa5af1954d55918f365e83a3331019549 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 18 Feb 2024 02:20:23 -0500 Subject: [PATCH 49/64] Fix gligen lowvram mode. 
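
The GLIGEN modules are switched from plain nn.Linear / nn.LayerNorm to the
comfy.ops.manual_cast versions, which cast weight and bias to the input's
dtype and device inside forward(). That lets the weights stay in their
storage dtype while offloaded, which is what lowvram mode needs. The
pattern, as a minimal sketch:

    import comfy.ops
    ops = comfy.ops.manual_cast

    proj = ops.Linear(dim_in, dim_out * 2)  # weights cast per forward call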
--- comfy/gligen.py | 52 +++++++++++++++++++++++++------------------------ 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/comfy/gligen.py b/comfy/gligen.py index 71892dfb1d4..592522767e9 100644 --- a/comfy/gligen.py +++ b/comfy/gligen.py @@ -2,7 +2,8 @@ from torch import nn from .ldm.modules.attention import CrossAttention from inspect import isfunction - +import comfy.ops +ops = comfy.ops.manual_cast def exists(val): return val is not None @@ -22,7 +23,7 @@ def default(val, d): class GEGLU(nn.Module): def __init__(self, dim_in, dim_out): super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) + self.proj = ops.Linear(dim_in, dim_out * 2) def forward(self, x): x, gate = self.proj(x).chunk(2, dim=-1) @@ -35,14 +36,14 @@ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): inner_dim = int(dim * mult) dim_out = default(dim_out, dim) project_in = nn.Sequential( - nn.Linear(dim, inner_dim), + ops.Linear(dim, inner_dim), nn.GELU() ) if not glu else GEGLU(dim, inner_dim) self.net = nn.Sequential( project_in, nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) + ops.Linear(inner_dim, dim_out) ) def forward(self, x): @@ -57,11 +58,12 @@ def __init__(self, query_dim, context_dim, n_heads, d_head): query_dim=query_dim, context_dim=context_dim, heads=n_heads, - dim_head=d_head) + dim_head=d_head, + operations=ops) self.ff = FeedForward(query_dim, glu=True) - self.norm1 = nn.LayerNorm(query_dim) - self.norm2 = nn.LayerNorm(query_dim) + self.norm1 = ops.LayerNorm(query_dim) + self.norm2 = ops.LayerNorm(query_dim) self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) @@ -87,17 +89,18 @@ def __init__(self, query_dim, context_dim, n_heads, d_head): # we need a linear projection since we need cat visual feature and obj # feature - self.linear = nn.Linear(context_dim, query_dim) + self.linear = ops.Linear(context_dim, query_dim) self.attn = CrossAttention( query_dim=query_dim, context_dim=query_dim, heads=n_heads, - dim_head=d_head) + dim_head=d_head, + operations=ops) self.ff = FeedForward(query_dim, glu=True) - self.norm1 = nn.LayerNorm(query_dim) - self.norm2 = nn.LayerNorm(query_dim) + self.norm1 = ops.LayerNorm(query_dim) + self.norm2 = ops.LayerNorm(query_dim) self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) @@ -126,14 +129,14 @@ def __init__(self, query_dim, context_dim, n_heads, d_head): # we need a linear projection since we need cat visual feature and obj # feature - self.linear = nn.Linear(context_dim, query_dim) + self.linear = ops.Linear(context_dim, query_dim) self.attn = CrossAttention( - query_dim=query_dim, context_dim=query_dim, dim_head=d_head) + query_dim=query_dim, context_dim=query_dim, dim_head=d_head, operations=ops) self.ff = FeedForward(query_dim, glu=True) - self.norm1 = nn.LayerNorm(query_dim) - self.norm2 = nn.LayerNorm(query_dim) + self.norm1 = ops.LayerNorm(query_dim) + self.norm2 = ops.LayerNorm(query_dim) self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) @@ -201,11 +204,11 @@ def __init__(self, in_dim, out_dim, fourier_freqs=8): self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy self.linears = nn.Sequential( - nn.Linear(self.in_dim + self.position_dim, 512), + ops.Linear(self.in_dim + self.position_dim, 512), nn.SiLU(), - nn.Linear(512, 512), + 
ops.Linear(512, 512), nn.SiLU(), - nn.Linear(512, out_dim), + ops.Linear(512, out_dim), ) self.null_positive_feature = torch.nn.Parameter( @@ -215,16 +218,15 @@ def __init__(self, in_dim, out_dim, fourier_freqs=8): def forward(self, boxes, masks, positive_embeddings): B, N, _ = boxes.shape - dtype = self.linears[0].weight.dtype - masks = masks.unsqueeze(-1).to(dtype) - positive_embeddings = positive_embeddings.to(dtype) + masks = masks.unsqueeze(-1) + positive_embeddings = positive_embeddings # embedding position (it may includes padding as placeholder) - xyxy_embedding = self.fourier_embedder(boxes.to(dtype)) # B*N*4 --> B*N*C + xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C # learnable null embedding - positive_null = self.null_positive_feature.view(1, 1, -1) - xyxy_null = self.null_position_feature.view(1, 1, -1) + positive_null = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1) + xyxy_null = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1) # replace padding with learnable null embedding positive_embeddings = positive_embeddings * \ @@ -251,7 +253,7 @@ def _set_position(self, boxes, masks, positive_embeddings): def func(x, extra_options): key = extra_options["transformer_index"] module = self.module_list[key] - return module(x, objs) + return module(x, objs.to(device=x.device, dtype=x.dtype)) return func def set_position(self, latent_image_shape, position_params, device): From 51714141435d050c255bba0cf588349f77bbc634 Mon Sep 17 00:00:00 2001 From: shiimizu Date: Sat, 17 Feb 2024 22:03:34 -0800 Subject: [PATCH 50/64] Support additional PNG info. --- web/scripts/pnginfo.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 83a4ebc86c4..1696092098f 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -24,7 +24,7 @@ export function getPngMetadata(file) { const length = dataView.getUint32(offset); // Get the chunk type const type = String.fromCharCode(...pngData.slice(offset + 4, offset + 8)); - if (type === "tEXt" || type == "comf") { + if (type === "tEXt" || type == "comf" || type === "iTXt") { // Get the keyword let keyword_end = offset + 8; while (pngData[keyword_end] !== 0) { @@ -33,7 +33,7 @@ export function getPngMetadata(file) { const keyword = String.fromCharCode(...pngData.slice(offset + 8, keyword_end)); // Get the text const contentArraySegment = pngData.slice(keyword_end + 1, offset + 8 + length); - const contentJson = Array.from(contentArraySegment).map(s=>String.fromCharCode(s)).join('') + const contentJson = new TextDecoder("utf-8").decode(contentArraySegment); txt_chunks[keyword] = contentJson; } From 3b2e579926d5cf8231de0e68e79096d1ee8091f2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 19 Feb 2024 04:06:49 -0500 Subject: [PATCH 51/64] Support loading the Stable Cascade effnet and previewer as a VAE. The effnet can be used to encode images for img2img with Stage C. 
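
The VAE class now recognizes three extra state-dict layouts: the effnet
encoder, the previewer decoder, and a combined file. A hedged usage sketch
for Stage C img2img; the file name is hypothetical:

    import comfy.sd
    import comfy.utils

    sd = comfy.utils.load_torch_file("models/vae/effnet_encoder.safetensors")
    vae = comfy.sd.VAE(sd=sd)                        # downscale_ratio 32, 16ch
    stage_c_latent = vae.encode(image[:, :, :, :3])  # image is BHWC, 0..1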
--- comfy/sd.py | 42 ++++++++++++++++++++++++++++++++++++++---- nodes.py | 20 ++++---------------- 2 files changed, 42 insertions(+), 20 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index d8c0bfa7ca3..00633e10768 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -4,6 +4,7 @@ from comfy import model_management from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine from .ldm.cascade.stage_a import StageA +from .ldm.cascade.stage_c_coder import StageC_coder import yaml @@ -158,6 +159,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None): self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower) self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) self.downscale_ratio = 8 + self.upscale_ratio = 8 self.latent_channels = 4 self.process_input = lambda image: image * 2.0 - 1.0 self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) @@ -176,11 +178,31 @@ def __init__(self, sd=None, device=None, config=None, dtype=None): elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade self.first_stage_model = StageA() self.downscale_ratio = 4 + self.upscale_ratio = 4 #TODO #self.memory_used_encode #self.memory_used_decode self.process_input = lambda image: image self.process_output = lambda image: image + elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade + self.first_stage_model = StageC_coder() + self.downscale_ratio = 32 + self.latent_channels = 16 + new_sd = {} + for k in sd: + new_sd["encoder.{}".format(k)] = sd[k] + sd = new_sd + elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade + self.first_stage_model = StageC_coder() + self.latent_channels = 16 + new_sd = {} + for k in sd: + new_sd["previewer.{}".format(k)] = sd[k] + sd = new_sd + elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade + self.first_stage_model = StageC_coder() + self.downscale_ratio = 32 + self.latent_channels = 16 else: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} @@ -188,6 +210,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None): if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE ddconfig['ch_mult'] = [1, 2, 4] self.downscale_ratio = 4 + self.upscale_ratio = 4 self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) else: @@ -213,6 +236,15 @@ def __init__(self, sd=None, device=None, config=None, dtype=None): self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) + def vae_encode_crop_pixels(self, pixels): + x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio + y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio + if pixels.shape[1] != x or pixels.shape[2] != y: + x_offset = (pixels.shape[1] % self.downscale_ratio) // 2 + y_offset = (pixels.shape[2] % self.downscale_ratio) // 2 + pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :] + return pixels + def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps = 
samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) @@ -221,9 +253,9 @@ def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() output = self.process_output( - (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) + - comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) + - comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar)) + (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + + comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar)) / 3.0) return output @@ -248,7 +280,7 @@ def decode(self, samples_in): batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) - pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device) + pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device) for x in range(0, samples_in.shape[0], batch_number): samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float()) @@ -265,6 +297,7 @@ def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): return output.movedim(1,-1) def encode(self, pixel_samples): + pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1,1) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) @@ -284,6 +317,7 @@ def encode(self, pixel_samples): return samples def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): + pixel_samples = self.vae_encode_crop_pixels(pixel_samples) model_management.load_model_gpu(self.patcher) pixel_samples = pixel_samples.movedim(-1,1) samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) diff --git a/nodes.py b/nodes.py index 47203f4176c..a577c212628 100644 --- a/nodes.py +++ b/nodes.py @@ -309,18 +309,7 @@ def INPUT_TYPES(s): CATEGORY = "latent" - @staticmethod - def vae_encode_crop_pixels(pixels): - x = (pixels.shape[1] // 8) * 8 - y = (pixels.shape[2] // 8) * 8 - if pixels.shape[1] != x or pixels.shape[2] != y: - x_offset = (pixels.shape[1] % 8) // 2 - y_offset = (pixels.shape[2] % 8) // 2 - pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :] - return pixels - def encode(self, vae, pixels): - pixels = self.vae_encode_crop_pixels(pixels) t = vae.encode(pixels[:,:,:,:3]) 
return ({"samples":t}, ) @@ -336,7 +325,6 @@ def INPUT_TYPES(s): CATEGORY = "_for_testing" def encode(self, vae, pixels, tile_size): - pixels = VAEEncode.vae_encode_crop_pixels(pixels) t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, ) return ({"samples":t}, ) @@ -350,14 +338,14 @@ def INPUT_TYPES(s): CATEGORY = "latent/inpaint" def encode(self, vae, pixels, mask, grow_mask_by=6): - x = (pixels.shape[1] // 8) * 8 - y = (pixels.shape[2] // 8) * 8 + x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio + y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") pixels = pixels.clone() if pixels.shape[1] != x or pixels.shape[2] != y: - x_offset = (pixels.shape[1] % 8) // 2 - y_offset = (pixels.shape[2] % 8) // 2 + x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2 + y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2 pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] From a7b5eaa7e3d5ecc1ae1d395e50d781124bb4e611 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 19 Feb 2024 04:25:46 -0500 Subject: [PATCH 52/64] Forgot to commit this. --- comfy/ldm/cascade/stage_c_coder.py | 96 ++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 comfy/ldm/cascade/stage_c_coder.py diff --git a/comfy/ldm/cascade/stage_c_coder.py b/comfy/ldm/cascade/stage_c_coder.py new file mode 100644 index 00000000000..98c9a0b6147 --- /dev/null +++ b/comfy/ldm/cascade/stage_c_coder.py @@ -0,0 +1,96 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" +import torch +import torchvision +from torch import nn + + +# EfficientNet +class EfficientNetEncoder(nn.Module): + def __init__(self, c_latent=16): + super().__init__() + self.backbone = torchvision.models.efficientnet_v2_s().features.eval() + self.mapper = nn.Sequential( + nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1 + ) + self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406])) + self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225])) + + def forward(self, x): + x = x * 0.5 + 0.5 + x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1]) + o = self.mapper(self.backbone(x)) + print(o.shape) + return o + + +# Fast Decoder for Stage C latents. E.g. 
16 x 24 x 24 -> 3 x 192 x 192 +class Previewer(nn.Module): + def __init__(self, c_in=16, c_hidden=512, c_out=3): + super().__init__() + self.blocks = nn.Sequential( + nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels + nn.GELU(), + nn.BatchNorm2d(c_hidden), + + nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden), + + nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 2), + + nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 2), + + nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_out, kernel_size=1), + ) + + def forward(self, x): + return (self.blocks(x) - 0.5) * 2.0 + +class StageC_coder(nn.Module): + def __init__(self): + super().__init__() + self.previewer = Previewer() + self.encoder = EfficientNetEncoder() + + def encode(self, x): + return self.encoder(x) + + def decode(self, x): + return self.previewer(x) From dbe0979b3f8ee4215e55012700f6d0afb0fec5b2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 19 Feb 2024 08:59:34 -0500 Subject: [PATCH 53/64] Larger range for min/max compression for StableCascade_EmptyLatentImage. --- comfy_extras/nodes_stable_cascade.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py index 5d31c1e59c7..efe3586d282 100644 --- a/comfy_extras/nodes_stable_cascade.py +++ b/comfy_extras/nodes_stable_cascade.py @@ -29,8 +29,8 @@ def INPUT_TYPES(s): return {"required": { "width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}), "height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}), - "compression": ("INT", {"default": 42, "min": 32, "max": 64, "step": 1}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}) + "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}) }} RETURN_TYPES = ("LATENT", "LATENT") RETURN_NAMES = ("stage_c", "stage_b") From d91f45ef280a5acbdc22f3cc757f8fdbb254261b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 19 Feb 2024 10:29:18 -0500 Subject: [PATCH 54/64] Some cleanups to how the text encoders are loaded. 
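
The load path no longer routes CLIP weights through a throwaway WeightsLoader
module: each model config strips its checkpoint-specific text-encoder prefix
(e.g. "cond_stage_model.", "conditioner.embedders.N.") down to bare
"clip_l."/"clip_g."/"clip_h." keys, and the result is loaded with
load_state_dict(strict=False), printing missing/unexpected keys. A minimal
sketch of the prefix helper these diffs lean on (an assumption inferred from
the call sites below; the real comfy.utils implementation may differ):

    # Rename keys that start with a prefix in replace_prefix; with
    # filter_keys=True, keys matching no prefix are dropped instead of
    # passed through, which is what lets process_clip_state_dict return
    # only text-encoder weights.
    def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
        out = {}
        for key, value in state_dict.items():
            for old, new in replace_prefix.items():
                if key.startswith(old):
                    out[new + key[len(old):]] = value
                    break
            else:
                if not filter_keys:
                    out[key] = value
        return out
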
--- comfy/sd.py | 23 +++++++++++++---------- comfy/supported_models.py | 31 +++++++++++++++---------------- comfy/supported_models_base.py | 6 ++++-- 3 files changed, 32 insertions(+), 28 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 00633e10768..7a77bb177a3 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -138,8 +138,11 @@ def encode(self, text): tokens = self.tokenize(text) return self.encode_from_tokens(tokens) - def load_sd(self, sd): - return self.cond_stage_model.load_sd(sd) + def load_sd(self, sd, full_model=False): + if full_model: + return self.cond_stage_model.load_state_dict(sd, strict=False) + else: + return self.cond_stage_model.load_sd(sd) def get_sd(self): return self.cond_stage_model.state_dict() @@ -494,9 +497,6 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.") load_device = model_management.get_torch_device() - class WeightsLoader(torch.nn.Module): - pass - model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.") unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes) manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) @@ -521,14 +521,17 @@ class WeightsLoader(torch.nn.Module): vae = VAE(sd=vae_sd) if output_clip: - w = WeightsLoader() clip_target = model_config.clip_target() if clip_target is not None: - sd = model_config.process_clip_state_dict(sd) - if any(k.startswith('cond_stage_model.') for k in sd): + clip_sd = model_config.process_clip_state_dict(sd) + if len(clip_sd) > 0: clip = CLIP(clip_target, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model - load_model_weights(w, sd) + m, u = clip.load_sd(clip_sd, full_model=True) + if len(m) > 0: + print("clip missing:", m) + + if len(u) > 0: + print("clip unexpected:", u) else: print("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1a673646ee5..f29f7f3d9ea 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -40,8 +40,8 @@ def process_clip_state_dict(self, state_dict): state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() replace_prefix = {} - replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l." - state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) + replace_prefix["cond_stage_model."] = "clip_l." + state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) return state_dict def process_clip_state_dict_for_saving(self, state_dict): @@ -72,10 +72,10 @@ def model_type(self, state_dict, prefix=""): def process_clip_state_dict(self, state_dict): replace_prefix = {} - replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format - state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) - - state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24) + replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format + replace_prefix["cond_stage_model.model."] = "clip_h." 
+ state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) + state_dict = utils.transformers_convert(state_dict, "clip_h.", "clip_h.transformer.text_model.", 24) return state_dict def process_clip_state_dict_for_saving(self, state_dict): @@ -131,11 +131,10 @@ def get_model(self, state_dict, prefix="", device=None): def process_clip_state_dict(self, state_dict): keys_to_replace = {} replace_prefix = {} + replace_prefix["conditioner.embedders.0.model."] = "clip_g." + state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) - state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32) - keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection" - keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale" - + state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32) state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) return state_dict @@ -179,13 +178,13 @@ def process_clip_state_dict(self, state_dict): keys_to_replace = {} replace_prefix = {} - replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model" - state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32) - keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection" - keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection" - keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale" + replace_prefix["conditioner.embedders.0.transformer.text_model"] = "clip_l.transformer.text_model" + replace_prefix["conditioner.embedders.1.model."] = "clip_g." 
+ state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) + + state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32) + keys_to_replace["clip_g.text_projection.weight"] = "clip_g.text_projection" - state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) return state_dict diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 3bd4f9c6523..4d7e2593669 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -22,6 +22,7 @@ class BASE: sampling_settings = {} latent_format = latent_formats.LatentFormat vae_key_prefix = ["first_stage_model."] + text_encoder_key_prefix = ["cond_stage_model."] supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] manual_cast_dtype = None @@ -55,6 +56,7 @@ def get_model(self, state_dict, prefix="", device=None): return out def process_clip_state_dict(self, state_dict): + state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True) return state_dict def process_unet_state_dict(self, state_dict): @@ -64,7 +66,7 @@ def process_vae_state_dict(self, state_dict): return state_dict def process_clip_state_dict_for_saving(self, state_dict): - replace_prefix = {"": "cond_stage_model."} + replace_prefix = {"": self.text_encoder_key_prefix[0]} return utils.state_dict_prefix_replace(state_dict, replace_prefix) def process_clip_vision_state_dict_for_saving(self, state_dict): @@ -78,7 +80,7 @@ def process_unet_state_dict_for_saving(self, state_dict): return utils.state_dict_prefix_replace(state_dict, replace_prefix) def process_vae_state_dict_for_saving(self, state_dict): - replace_prefix = {"": "first_stage_model."} + replace_prefix = {"": self.vae_key_prefix[0]} return utils.state_dict_prefix_replace(state_dict, replace_prefix) def set_inference_dtype(self, dtype, manual_cast_dtype): From 3711b31dff3530ba584e9a30f1bb32feaf2b5886 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 19 Feb 2024 11:20:48 -0500 Subject: [PATCH 55/64] Support Stable Cascade in checkpoint format. --- comfy/supported_models.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index f29f7f3d9ea..5bb98d88a96 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -319,6 +319,10 @@ class Stable_Cascade_C(supported_models_base.BASE): "shift": 2.0, } + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoder."] + clip_vision_prefix = "clip_l_vision." + def process_unet_state_dict(self, state_dict): key_list = list(state_dict.keys()) for y in ["weight", "bias"]: @@ -355,6 +359,8 @@ class Stable_Cascade_B(Stable_Cascade_C): "shift": 1.0, } + clip_vision_prefix = None + def get_model(self, state_dict, prefix="", device=None): out = model_base.StableCascade_B(self, device=device) return out From e93cdd0ad043a637a81c9d398498b1e8a8eca3b0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 19 Feb 2024 11:47:26 -0500 Subject: [PATCH 56/64] Remove print. 
--- comfy/ldm/cascade/stage_c_coder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/ldm/cascade/stage_c_coder.py b/comfy/ldm/cascade/stage_c_coder.py index 98c9a0b6147..0cb7c49fc90 100644 --- a/comfy/ldm/cascade/stage_c_coder.py +++ b/comfy/ldm/cascade/stage_c_coder.py @@ -36,7 +36,6 @@ def forward(self, x): x = x * 0.5 + 0.5 x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1]) o = self.mapper(self.backbone(x)) - print(o.shape) return o From 88f300401c0815eb5185683cb69ecf8b52cb6e7b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 19 Feb 2024 12:00:48 -0500 Subject: [PATCH 57/64] Enable fp16 by default on mps. --- comfy/model_management.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 681208ea091..adcc0e8ace2 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -720,9 +720,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma if FORCE_FP16: return True - if device is not None: #TODO + if device is not None: if is_device_mps(device): - return False + return True if FORCE_FP32: return False @@ -730,8 +730,11 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma if directml_enabled: return False - if cpu_mode() or mps_mode(): - return False #TODO ? + if mps_mode(): + return True + + if cpu_mode(): + return False if is_intel_xpu(): return True From a31152496990913211c6deb3267144bd3095c1ee Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 19 Feb 2024 13:36:20 -0500 Subject: [PATCH 58/64] Node to make stable cascade image to image easier. --- comfy_extras/nodes_stable_cascade.py | 35 ++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py index efe3586d282..b795d008335 100644 --- a/comfy_extras/nodes_stable_cascade.py +++ b/comfy_extras/nodes_stable_cascade.py @@ -18,6 +18,7 @@ import torch import nodes +import comfy.utils class StableCascade_EmptyLatentImage: @@ -47,6 +48,39 @@ def generate(self, width, height, compression, batch_size=1): "samples": b_latent, }) +class StableCascade_StageC_VAEEncode: + def __init__(self, device="cpu"): + self.device = device + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "image": ("IMAGE",), + "vae": ("VAE", ), + "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}), + }} + RETURN_TYPES = ("LATENT", "LATENT") + RETURN_NAMES = ("stage_c", "stage_b") + FUNCTION = "generate" + + CATEGORY = "_for_testing/stable_cascade" + + def generate(self, image, vae, compression): + width = image.shape[-2] + height = image.shape[-3] + out_width = (width // compression) * vae.downscale_ratio + out_height = (height // compression) * vae.downscale_ratio + + s = comfy.utils.common_upscale(image.movedim(-1,1), out_width, out_height, "bicubic", "center").movedim(1,-1) + + c_latent = vae.encode(s[:,:,:,:3]) + b_latent = torch.zeros([c_latent.shape[0], 4, height // 4, width // 4]) + return ({ + "samples": c_latent, + }, { + "samples": b_latent, + }) + class StableCascade_StageB_Conditioning: @classmethod def INPUT_TYPES(s): @@ -71,4 +105,5 @@ def set_prior(self, conditioning, stage_c): NODE_CLASS_MAPPINGS = { "StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage, "StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning, + "StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode, } From 
ec4d89cee946faacff60681a7f16443305350260 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 19 Feb 2024 13:41:55 -0500 Subject: [PATCH 59/64] Add to Readme that stable cascade is supported. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ff3ab64204e..a94a212ad3d 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin ## Features - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. -- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/) and [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) +- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) and [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/) - Asynchronous Queue system - Many optimizations: Only re-executes the parts of the workflow that changes between executions. - Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram) From c6b7a157ed30eb2dd59891ba465b7b5be97a687a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 20 Feb 2024 04:05:39 -0500 Subject: [PATCH 60/64] Align simple scheduling closer to official stable cascade scheduler. --- comfy/model_sampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index ae42d81f200..97e91a01d67 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -150,10 +150,10 @@ def set_parameters(self, shift=1.0, cosine_s=8e-3): self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 #This part is just for compatibility with some schedulers in the codebase - self.num_timesteps = 1000 + self.num_timesteps = 10000 sigmas = torch.empty((self.num_timesteps), dtype=torch.float32) for x in range(self.num_timesteps): - t = x / self.num_timesteps + t = (x + 1) / self.num_timesteps sigmas[x] = self.sigma(t) self.set_sigmas(sigmas) From 0d0fbabd1d153611a1c21aea3515d16339abc84f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 20 Feb 2024 04:23:25 -0500 Subject: [PATCH 61/64] Pass pooled CLIP to stage b. 
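
Stage B appears to condition on the pooled CLIP vector rather than on a
cross-attention token sequence, so the pooled output is surfaced under the
"clip" key the model reads and the unused cross_attn branch is dropped. A
condensed, non-authoritative sketch of the resulting hook (names follow the
diff below; the batching behaviour of CONDRegular is assumed):

    # Sketch: extra_conds turns conditioning kwargs into the tensors the
    # model consumes; CONDRegular wraps a tensor so it can be batched.
    def stage_b_extra_conds_sketch(**kwargs):
        out = {}
        pooled = kwargs.get("pooled_output")
        if pooled is not None:
            out["clip"] = comfy.conds.CONDRegular(pooled)  # was "clip_text_pooled"
        return out
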
---
 comfy/model_base.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index fefce76378c..421f271b28a 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -474,15 +474,11 @@ def extra_conds(self, **kwargs):
 
         clip_text_pooled = kwargs["pooled_output"]
         if clip_text_pooled is not None:
-            out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
+            out['clip'] = comfy.conds.CONDRegular(clip_text_pooled)
 
         #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
         prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
 
         out["effnet"] = comfy.conds.CONDRegular(prior)
         out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
-
-        cross_attn = kwargs.get("cross_attn", None)
-        if cross_attn is not None:
-            out['clip'] = comfy.conds.CONDCrossAttn(cross_attn)
         return out

From 18c151b3e3f6838fab4028e7a8ba526e30e610d3 Mon Sep 17 00:00:00 2001
From: comfyanonymous <comfyanonymous@protonmail.com>
Date: Tue, 20 Feb 2024 10:57:24 -0500
Subject: [PATCH 62/64] Add some latent2rgb matrices for previews.

---
 comfy/latent_formats.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index 68fd73d0b5d..03fd59e3da0 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -37,11 +37,41 @@ def __init__(self):
 class SD_X4(LatentFormat):
     def __init__(self):
         self.scale_factor = 0.08333
+        self.latent_rgb_factors = [
+            [-0.2340, -0.3863, -0.3257],
+            [ 0.0994, 0.0885, -0.0908],
+            [-0.2833, -0.2349, -0.3741],
+            [ 0.2523, -0.0055, -0.1651]
+        ]
 
 class SC_Prior(LatentFormat):
     def __init__(self):
         self.scale_factor = 1.0
+        self.latent_rgb_factors = [
+            [-0.0326, -0.0204, -0.0127],
+            [-0.1592, -0.0427, 0.0216],
+            [ 0.0873, 0.0638, -0.0020],
+            [-0.0602, 0.0442, 0.1304],
+            [ 0.0800, -0.0313, -0.1796],
+            [-0.0810, -0.0638, -0.1581],
+            [ 0.1791, 0.1180, 0.0967],
+            [ 0.0740, 0.1416, 0.0432],
+            [-0.1745, -0.1888, -0.1373],
+            [ 0.2412, 0.1577, 0.0928],
+            [ 0.1908, 0.0998, 0.0682],
+            [ 0.0209, 0.0365, -0.0092],
+            [ 0.0448, -0.0650, -0.1728],
+            [-0.1658, -0.1045, -0.1308],
+            [ 0.0542, 0.1545, 0.1325],
+            [-0.0352, -0.1672, -0.2541]
+        ]
 
 class SC_B(LatentFormat):
     def __init__(self):
         self.scale_factor = 1.0
+        self.latent_rgb_factors = [
+            [ 0.1121, 0.2006, 0.1023],
+            [-0.2093, -0.0222, -0.0195],
+            [-0.3087, -0.1535, 0.0366],
+            [ 0.0290, -0.1574, -0.4078]
+        ]

From 7faa4507ecbd2ad67afcdea44b46ecdceec75232 Mon Sep 17 00:00:00 2001
From: comfyanonymous <comfyanonymous@protonmail.com>
Date: Wed, 21 Feb 2024 08:05:43 -0500
Subject: [PATCH 63/64] ModelSamplingDiscrete: support x0 models that predict a denoised image.
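
For reference, the relationship between the two parameterizations (a sketch
following the conventions of comfy/model_sampling.py; the X0 class is taken
from the diff below, while the EPS formula is paraphrased from the existing
class and should be treated as an assumption rather than a quote):

    # eps models predict the noise, so the clean image is recovered as
    # x0 = x_t - sigma * eps; x0 models skip that step and predict the
    # clean image directly.
    class EPS:
        def calculate_denoised(self, sigma, model_output, model_input):
            sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
            return model_input - model_output * sigma

    class X0(EPS):
        def calculate_denoised(self, sigma, model_output, model_input):
            return model_output  # the network output already is x0
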
---
 comfy_extras/nodes_model_advanced.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py
index ac7c1c17a16..1b3f3945e38 100644
--- a/comfy_extras/nodes_model_advanced.py
+++ b/comfy_extras/nodes_model_advanced.py
@@ -17,6 +17,10 @@ def calculate_denoised(self, sigma, model_output, model_input):
         return c_out * x0 + c_skip * model_input
 
+class X0(comfy.model_sampling.EPS):
+    def calculate_denoised(self, sigma, model_output, model_input):
+        return model_output
+
 
 class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
     original_timesteps = 50
 
@@ -68,7 +72,7 @@ class ModelSamplingDiscrete:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "model": ("MODEL",),
-                              "sampling": (["eps", "v_prediction", "lcm"],),
+                              "sampling": (["eps", "v_prediction", "lcm", "x0"],),
                               "zsnr": ("BOOLEAN", {"default": False}),
                               }}
 
@@ -88,6 +92,8 @@ def patch(self, model, sampling, zsnr):
         elif sampling == "lcm":
             sampling_type = LCM
             sampling_base = ModelSamplingDiscreteDistilled
+        elif sampling == "x0":
+            sampling_type = X0
 
         class ModelSamplingAdvanced(sampling_base, sampling_type):
             pass

From f81dbe26e2e363c28ad043db67b59c11bb33f446 Mon Sep 17 00:00:00 2001
From: Rick Love
Date: Wed, 21 Feb 2024 19:21:24 -0600
Subject: [PATCH 64/64] FIX recursive_will_execute performance (simple ~300x performance increase) (#2852)

* FIX recursive_will_execute performance

* Minimize code changes

* memo must be created outside lambda
---
 execution.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/execution.py b/execution.py
index 00908eadd46..3e9d53b0ed2 100644
--- a/execution.py
+++ b/execution.py
@@ -194,8 +194,12 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
 
     return (True, None, None)
 
-def recursive_will_execute(prompt, outputs, current_item):
+def recursive_will_execute(prompt, outputs, current_item, memo={}):
     unique_id = current_item
+
+    if unique_id in memo:
+        return memo[unique_id]
+
     inputs = prompt[unique_id]['inputs']
     will_execute = []
     if unique_id in outputs:
@@ -207,9 +211,10 @@ def recursive_will_execute(prompt, outputs, current_item):
             input_unique_id = input_data[0]
             output_index = input_data[1]
             if input_unique_id not in outputs:
-                will_execute += recursive_will_execute(prompt, outputs, input_unique_id)
+                will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo)
 
-    return will_execute + [unique_id]
+    memo[unique_id] = will_execute + [unique_id]
+    return memo[unique_id]
 
 def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
     unique_id = current_item
@@ -377,7 +382,8 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
 
             while len(to_execute) > 0:
                 #always execute the output that depends on the least amount of unexecuted nodes first
-                to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
+                memo = {}
+                to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute)))
                 output_node_id = to_execute.pop(0)[-1]
 
                 # This call shouldn't raise anything if there's an error deep in
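
The fix above is a textbook memoization of a recursive walk: without the
cache, every output node re-expands its entire ancestor subgraph, which is
roughly quadratic in graph size, while the shared memo computes each node's
dependency list once per scheduling pass. A condensed sketch of the pattern
(assuming the surrounding execution.py structure shown above; the real
function also reads the output index from each link):

    def recursive_will_execute(prompt, outputs, current_item, memo):
        unique_id = current_item
        if unique_id in memo:
            return memo[unique_id]
        will_execute = []
        if unique_id not in outputs:
            for input_data in prompt[unique_id]["inputs"].values():
                if isinstance(input_data, list):  # links are [node_id, output_index]
                    input_unique_id = input_data[0]
                    if input_unique_id not in outputs:
                        will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo)
        memo[unique_id] = will_execute + [unique_id]
        return memo[unique_id]

Note the design choice the PR bullet points call out: memo = {} is created
outside the sort lambda, so one cache is shared across all candidate outputs
within a pass but never across executions. Relying on the memo={} default
argument instead would persist results between prompts, since Python default
arguments are evaluated once at function definition time.
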