Skip to content

Commit

Permalink
Implement creating inputs by dragging link to widget (#1021)
Browse files Browse the repository at this point in the history
* Implement creating inputs by dragging link to widget

* Update litegraph

* Add playwright test

* Update test expectations [skip ci]

---------

Co-authored-by: huchenlei <huchenlei@proton.me>
Co-authored-by: github-actions <github-actions@github.com>
  • Loading branch information
3 people authored Oct 1, 2024
1 parent c42222c commit a2bd2a9
Show file tree
Hide file tree
Showing 9 changed files with 247 additions and 10 deletions.
51 changes: 49 additions & 2 deletions browser_tests/ComfyPage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,8 @@ export class ComfyPage {
return this.getNodeRefById(id)
}
}
class NodeSlotReference {

export class NodeSlotReference {
constructor(
readonly type: 'input' | 'output',
readonly index: number,
Expand Down Expand Up @@ -980,7 +981,37 @@ class NodeSlotReference {
)
}
}
class NodeReference {

export class NodeWidgetReference {
constructor(
readonly index: number,
readonly node: NodeReference
) {}

async getPosition(): Promise<Position> {
const pos: [number, number] = await this.node.comfyPage.page.evaluate(
([id, index]) => {
const node = window['app'].graph.getNodeById(id)
if (!node) throw new Error(`Node ${id} not found.`)
const widget = node.widgets[index]
if (!widget) throw new Error(`Widget ${index} not found.`)

const [x, y, w, h] = node.getBounding()
return window['app'].canvas.ds.convertOffsetToCanvas([
x + w / 2,
y + window['LiteGraph']['NODE_TITLE_HEIGHT'] + widget.last_y + 1
])
},
[this.node.id, this.index] as const
)
return {
x: pos[0],
y: pos[1]
}
}
}

export class NodeReference {
constructor(
readonly id: NodeId,
readonly comfyPage: ComfyPage
Expand Down Expand Up @@ -1026,6 +1057,9 @@ class NodeReference {
async getInput(index: number) {
return new NodeSlotReference('input', index, this)
}
async getWidget(index: number) {
return new NodeWidgetReference(index, this)
}
async click(position: 'title', options?: Parameters<Page['click']>[1]) {
const nodePos = await this.getPosition()
const nodeSize = await this.getSize()
Expand All @@ -1043,6 +1077,19 @@ class NodeReference {
})
await this.comfyPage.nextFrame()
}
async connectWidget(
originSlotIndex: number,
targetNode: NodeReference,
targetWidgetIndex: number
) {
const originSlot = await this.getOutput(originSlotIndex)
const targetWidget = await targetNode.getWidget(targetWidgetIndex)
await this.comfyPage.dragAndDrop(
await originSlot.getPosition(),
await targetWidget.getPosition()
)
return originSlot
}
async connectOutput(
originSlotIndex: number,
targetNode: NodeReference,
Expand Down
104 changes: 104 additions & 0 deletions browser_tests/assets/primitive_node_unconnected.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
{
"last_node_id": 2,
"last_link_id": 1,
"nodes": [
{
"id": 2,
"type": "KSampler",
"pos": {
"0": 304.3653259277344,
"1": 42.15586471557617
},
"size": [
315,
262
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": null
},
{
"name": "positive",
"type": "CONDITIONING",
"link": null
},
{
"name": "negative",
"type": "CONDITIONING",
"link": null
},
{
"name": "latent_image",
"type": "LATENT",
"link": null
}
],
"outputs": [
{
"name": "LATENT",
"type": "LATENT",
"links": null,
"shape": 3
}
],
"properties": {
"Node name for S&R": "KSampler"
},
"widgets_values": [
0,
"randomize",
20,
8,
"euler",
"normal",
1
]
},
{
"id": 1,
"type": "PrimitiveNode",
"pos": {
"0": 14,
"1": 43
},
"size": [
203.1999969482422,
40.368401303242536
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "connect to widget input",
"type": "*",
"links": [],
"slot_index": 0
}
],
"properties": {
"Run widget replace on values": false
},
"widgets_values": []
}
],
"links": [],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 1,
"offset": [
0,
0
]
}
},
"version": 0.4
}
15 changes: 14 additions & 1 deletion browser_tests/primitiveNode.spec.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
import { expect } from '@playwright/test'
import { comfyPageFixture as test } from './ComfyPage'
import { type NodeReference, comfyPageFixture as test } from './ComfyPage'

test.describe('Primitive Node', () => {
test('Can load with correct size', async ({ comfyPage }) => {
await comfyPage.loadWorkflow('primitive_node')
await expect(comfyPage.canvas).toHaveScreenshot('primitive_node.png')
})

// When link is dropped on widget, it should automatically convert the widget
// to input.
test('Can connect to widget', async ({ comfyPage }) => {
await comfyPage.loadWorkflow('primitive_node_unconnected')
const primitiveNode: NodeReference = await comfyPage.getNodeRefById(1)
const ksamplerNode: NodeReference = await comfyPage.getNodeRefById(2)
// Connect the output of the primitive node to the input of first widget of the ksampler node
await primitiveNode.connectWidget(0, ksamplerNode, 0)
await expect(comfyPage.canvas).toHaveScreenshot(
'primitive_node_connected.png'
)
})
})
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 4 additions & 4 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
},
"dependencies": {
"@atlaskit/pragmatic-drag-and-drop": "^1.2.1",
"@comfyorg/litegraph": "^0.7.82",
"@comfyorg/litegraph": "^0.7.83",
"@primevue/themes": "^4.0.5",
"@vitejs/plugin-vue": "^5.0.5",
"@vueuse/core": "^11.0.0",
Expand Down
29 changes: 27 additions & 2 deletions src/extensions/core/widgetInputs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,26 @@ class PrimitiveNode extends LGraphNode {
}
}

isValidWidgetLink(
originSlot: number,
targetNode: LGraphNode,
targetWidget: IWidget
) {
const config2 = getConfig.call(targetNode, targetWidget.name) ?? [
targetWidget.type,
targetWidget.options || {}
]
if (!isConvertibleWidget(targetWidget, config2)) return false

const output = this.outputs[originSlot]
if (!(output.widget?.[CONFIG] ?? output.widget?.[GET_CONFIG]())) {
// No widget defined for this primitive yet so allow it
return true
}

return !!mergeIfValid.call(this, output, config2)
}

#isValidConnection(input: INodeInputSlot, forceUpdate?: boolean) {
// Only allow connections where the configs match
const output = this.outputs[0]
Expand Down Expand Up @@ -448,15 +468,19 @@ function showWidget(widget) {
}
}

function convertToInput(node: LGraphNode, widget: IWidget, config: InputSpec) {
export function convertToInput(
node: LGraphNode,
widget: IWidget,
config: InputSpec
) {
hideWidget(node, widget)

const { type } = getWidgetType(config)

// Add input and store widget config for creating on primitive node
const sz = node.size
const inputIsOptional = !!widget.options?.inputIsOptional
node.addInput(widget.name, type, {
const input = node.addInput(widget.name, type, {
// @ts-expect-error GET_CONFIG is not defined in LiteGraph
widget: { name: widget.name, [GET_CONFIG]: () => config },
// @ts-expect-error LiteGraph.SlotShape is not typed.
Expand All @@ -469,6 +493,7 @@ function convertToInput(node: LGraphNode, widget: IWidget, config: InputSpec) {

// Restore original size but grow if needed
node.setSize([Math.max(sz[0], node.size[0]), Math.max(sz[1], node.size[1])])
return input
}

function convertToWidget(node, widget) {
Expand Down
48 changes: 48 additions & 0 deletions src/scripts/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import { ModelStore, useModelStore } from '@/stores/modelStore'
import type { ToastMessageOptions } from 'primevue/toast'
import { useWorkspaceStore } from '@/stores/workspaceStateStore'
import { useExecutionStore } from '@/stores/executionStore'
import { IWidget } from '@comfyorg/litegraph'

export const ANIM_PREVIEW_WIDGET = '$$comfy_animation_preview'

Expand Down Expand Up @@ -1694,6 +1695,52 @@ export class ComfyApp {
}
}

#addWidgetLinkHandling() {
app.canvas.getWidgetLinkType = function (widget, node) {
const nodeDefStore = useNodeDefStore()
const nodeDef = nodeDefStore.nodeDefsByName[node.type]
const input = nodeDef.input.getInput(widget.name)
return input?.type
}

type ConnectingWidgetLink = {
subType: 'connectingWidgetLink'
widget: IWidget
node: LGraphNode
link: { node: LGraphNode; slot: number }
}

document.addEventListener(
'litegraph:canvas',
async (e: CustomEvent<ConnectingWidgetLink>) => {
if (e.detail.subType === 'connectingWidgetLink') {
const { convertToInput } = await import(
'@/extensions/core/widgetInputs'
)

const { node, link, widget } = e.detail
if (!node || !link || !widget) return

const nodeData = node.constructor.nodeData
if (!nodeData) return
const all = {
...nodeData?.input?.required,
...nodeData?.input?.optional
}
const inputSpec = all[widget.name]
if (!inputSpec) return

const input = convertToInput(node, widget, inputSpec)
if (!input) return

const originNode = link.node

originNode.connect(link.slot, node, node.inputs.lastIndexOf(input))
}
}
)
}

#addAfterConfigureHandler() {
const app = this
// @ts-expect-error
Expand Down Expand Up @@ -1915,6 +1962,7 @@ export class ComfyApp {
this.#addCopyHandler()
this.#addPasteHandler()
this.#addKeyboardHandler()
this.#addWidgetLinkHandling()

await this.#invokeExtensionsAsync('setup')
}
Expand Down

0 comments on commit a2bd2a9

Please sign in to comment.