Skip to content

Commit

Permalink
Merge pull request #14 from tctien342/revert-13-revert-11-feature/Nod…
Browse files Browse the repository at this point in the history
…eBypass

Revert "Revert "feat(builder): bypass nodes""
  • Loading branch information
tctien342 authored Dec 2, 2024
2 parents 2d13a45 + 5cf7bfc commit b9463bf
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 32 deletions.
143 changes: 116 additions & 27 deletions src/call-wrapper.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { NodeProgress } from "./types/api";
import { NodeData, NodeDef, NodeProgress } from "./types/api";
import { ComfyApi } from "./client";
import { PromptBuilder } from "./prompt-builder";
import { TExecutionCached } from "./types/event";
Expand All @@ -10,29 +10,30 @@ import {
CustomEventError,
ExecutionFailedError,
ExecutionInterruptedError,
MissingNodeError,
} from "./types/error";

/**
* Represents a wrapper class for making API calls using the ComfyApi client.
* Provides methods for setting callback functions and executing the job.
*/
export class CallWrapper<T extends PromptBuilder<string, string, object>> {
export class CallWrapper<I extends string, O extends string, T extends NodeData> {
private client: ComfyApi;
private prompt: T;
private prompt: PromptBuilder<I, O, T>;
private started = false;
private promptId?: string;
private output: Record<keyof T["mapOutputKeys"], any> = {} as any;
private output: Record<keyof PromptBuilder<I, O, T>["mapOutputKeys"], any> = {} as any;

private onPreviewFn?: (ev: Blob, promptId?: string) => void;
private onPendingFn?: (promptId?: string) => void;
private onStartFn?: (promptId?: string) => void;
private onOutputFn?: (
key: keyof T["mapOutputKeys"],
key: keyof PromptBuilder<I, string, T>["mapOutputKeys"],
data: any,
promptId?: string
) => void;
private onFinishedFn?: (
data: Record<keyof T["mapOutputKeys"], any>,
data: Record<keyof PromptBuilder<I, O, T>["mapOutputKeys"], any>,
promptId?: string
) => void;
private onFailedFn?: (err: Error, promptId?: string) => void;
Expand All @@ -54,7 +55,7 @@ export class CallWrapper<T extends PromptBuilder<string, string, object>> {
* @param client The ComfyApi client.
* @param workflow The workflow object.
*/
constructor(client: ComfyApi, workflow: T) {
constructor(client: ComfyApi, workflow: PromptBuilder<I, O, T>) {
this.client = client;
this.prompt = workflow;
return this;
Expand Down Expand Up @@ -102,7 +103,7 @@ export class CallWrapper<T extends PromptBuilder<string, string, object>> {
* @returns The current instance of the class.
*/
onOutput(
fn: (key: keyof T["mapOutputKeys"], data: any, promptId?: string) => void
fn: (key: keyof PromptBuilder<I, O, T>["mapOutputKeys"], data: any, promptId?: string) => void
) {
this.onOutputFn = fn;
return this;
Expand All @@ -116,7 +117,7 @@ export class CallWrapper<T extends PromptBuilder<string, string, object>> {
* @returns The current instance of the CallWrapper.
*/
onFinished(
fn: (data: Record<keyof T["mapOutputKeys"], any>, promptId?: string) => void
fn: (data: Record<keyof PromptBuilder<I, O, T>["mapOutputKeys"], any>, promptId?: string) => void
) {
this.onFinishedFn = fn;
return this;
Expand Down Expand Up @@ -154,7 +155,7 @@ export class CallWrapper<T extends PromptBuilder<string, string, object>> {
* or `false` if the job execution fails.
*/
async run(): Promise<
Record<keyof T["mapOutputKeys"], any> | undefined | false
Record<keyof PromptBuilder<I, O, T>["mapOutputKeys"], any> | undefined | false
> {
/**
* Start the job execution.
Expand All @@ -170,8 +171,8 @@ export class CallWrapper<T extends PromptBuilder<string, string, object>> {
promptLoadTrigger = resolve;
})

let jobDoneTrigger!: (value: Record<keyof T["mapOutputKeys"], any> | false) => void;
const jobDonePromise: Promise<Record<keyof T["mapOutputKeys"], any> | false> = new Promise(
let jobDoneTrigger!: (value: Record<keyof PromptBuilder<I, O, T>["mapOutputKeys"], any> | false) => void;
const jobDonePromise: Promise<Record<keyof PromptBuilder<I, O, T>["mapOutputKeys"], any> | false> = new Promise(
(resolve) => {
jobDoneTrigger = resolve;
}
Expand Down Expand Up @@ -212,7 +213,7 @@ export class CallWrapper<T extends PromptBuilder<string, string, object>> {
// race condition handling
let wentMissing = false;
let cachedOutputDone = false;
let cachedOutputPromise: Promise<false | Record<keyof T["mapOutputKeys"], any> | null> = Promise.resolve(null);
let cachedOutputPromise: Promise<false | Record<keyof PromptBuilder<I, O, T>["mapOutputKeys"], any> | null> = Promise.resolve(null);

const statusHandler = async () => {
const queue = await this.client.getQueue();
Expand Down Expand Up @@ -283,9 +284,88 @@ export class CallWrapper<T extends PromptBuilder<string, string, object>> {
return jobDonePromise;
}

private async bypassWorkflowNodes(workflow: NodeData) {
const nodeDefs: Record<string, NodeDef> = {}; // cache node definitions

for (const nodeId of this.prompt.bypassNodes) {
if (!workflow[nodeId as string]) {
throw new MissingNodeError(`Node ${nodeId.toString()} is missing from the workflow!`);
}

const classType = workflow[nodeId as string].class_type;

const def = nodeDefs[classType] || (await this.client.getNodeDefs(classType))?.[classType];
if (!def) {
throw new MissingNodeError(`Node type ${workflow[nodeId as string].class_type} is missing from server!`);
}
nodeDefs[classType] = def;

const connections = new Map<number, any>();
const connectedInputs: string[] = [];

// connect output nodes to matching input nodes
for (const [outputIdx, outputType] of def.output.entries()) {
for (const [inputName, inputValue] of Object.entries(workflow[nodeId as string].inputs)) {
if (connectedInputs.includes(inputName)) {
continue;
}

if (def.input.required[inputName]?.[0] === outputType) {
connections.set(outputIdx, inputValue);
connectedInputs.push(inputName);
break;
}

if (def.input.optional?.[inputName]?.[0] === outputType) {
connections.set(outputIdx, inputValue);
connectedInputs.push(inputName);
break;
}
}
}

// search and replace all nodes' inputs referencing this node based on matching output type, or remove reference
// if no matching output type was found
for (const [conNodeId, conNode] of Object.entries(workflow)) {
for (const [conInputName, conInputValue] of Object.entries(conNode.inputs)) {
if (!Array.isArray(conInputValue) || conInputValue[0] !== nodeId) {
continue;
}

if (connections.has(conInputValue[1])) {
workflow[conNodeId].inputs[conInputName] = connections.get(conInputValue[1]);
} else {
delete workflow[conNodeId].inputs[conInputName];
}
}
}

delete workflow[nodeId as string];
}

return workflow;
}

private async enqueueJob() {
let workflow = structuredClone(this.prompt.workflow) as NodeData;

if (this.prompt.bypassNodes.length > 0) {
try {
workflow = await this.bypassWorkflowNodes(workflow);
} catch (e) {
if (e instanceof Response) {
this.onFailedFn?.(
new MissingNodeError("Failed to get workflow node definitions", { cause: await e.json() })
);
} else {
this.onFailedFn?.(new MissingNodeError("There was a missing node in the workflow bypass.", { cause: e }));
}
return null;
}
}

const job = await this.client
.appendPrompt(this.prompt.workflow)
.appendPrompt(workflow)
.catch(async (e) => {
if (e instanceof Response) {
this.onFailedFn?.(
Expand All @@ -310,7 +390,7 @@ export class CallWrapper<T extends PromptBuilder<string, string, object>> {

private async handleCachedOutput(
promptId: string
): Promise<Record<keyof T["mapOutputKeys"], any> | false | null> {
): Promise<Record<keyof PromptBuilder<I, O, T>["mapOutputKeys"], any> | false | null> {
const hisData = await this.client.getHistory(promptId);
if (hisData?.status?.completed) {
const output = this.mapOutput(hisData.outputs);
Expand All @@ -324,14 +404,14 @@ export class CallWrapper<T extends PromptBuilder<string, string, object>> {
return null;
}

private mapOutput(outputNodes: any): Record<keyof T["mapOutputKeys"], any> {
private mapOutput(outputNodes: any): Record<keyof PromptBuilder<I, O, T>["mapOutputKeys"], any> {
const outputMapped = this.prompt.mapOutputKeys;
const output: Record<keyof T["mapOutputKeys"], any> = {} as any;
const output: Record<keyof PromptBuilder<I, O, T>["mapOutputKeys"], any> = {} as any;

for (const key in outputMapped) {
const node = outputMapped[key];
if (node) {
output[key as keyof T["mapOutputKeys"]] = outputNodes[node];
output[key as keyof PromptBuilder<I, O, T>["mapOutputKeys"]] = outputNodes[node];
}
}

Expand All @@ -340,7 +420,7 @@ export class CallWrapper<T extends PromptBuilder<string, string, object>> {

private handleJobExecution(
promptId: string,
jobDoneTrigger: (value: Record<keyof T["mapOutputKeys"], any> | false) => void
jobDoneTrigger: (value: Record<keyof PromptBuilder<I, O, T>["mapOutputKeys"], any> | false) => void
): void {
const reverseOutputMapped = this.reverseMapOutputKeys();

Expand All @@ -351,7 +431,8 @@ export class CallWrapper<T extends PromptBuilder<string, string, object>> {
this.onPreviewFn?.(ev.detail, this.promptId)
);

let totalOutput = Object.keys(reverseOutputMapped).length;
const totalOutput = Object.keys(reverseOutputMapped).length;
let remainingOutput = totalOutput;

const executionHandler = (ev: CustomEvent) => {
if (ev.detail.prompt_id !== promptId) return;
Expand All @@ -361,21 +442,29 @@ export class CallWrapper<T extends PromptBuilder<string, string, object>> {
ev.detail.node as keyof typeof this.prompt.mapOutputKeys
];
if (outputKey) {
this.output[outputKey as keyof T["mapOutputKeys"]] =
this.output[outputKey as keyof PromptBuilder<I, O, T>["mapOutputKeys"]] =
ev.detail.output;
this.onOutputFn?.(outputKey, ev.detail.output, this.promptId);
totalOutput--;
remainingOutput--;
}

if (totalOutput === 0) {
if (remainingOutput === 0) {
this.cleanupListeners();
this.onFinishedFn?.(this.output, this.promptId);
jobDoneTrigger(this.output);
}
};

const executedEnd = () => {
if (totalOutput !== 0) {
const executedEnd = async () => {
if (remainingOutput !== 0) {
// some cached output nodes might output after executedEnd, so check history data if an output is really missing
const hisData = await this.client.getHistory(promptId);
if (hisData?.status?.completed) {
const outputCount = Object.keys(hisData.outputs).length;
if (outputCount > 0 && outputCount - totalOutput === 0) {
return;
}
}
this.onFailedFn?.(new ExecutionFailedError("Execution failed"), this.promptId);
this.cleanupListeners();
jobDoneTrigger(false);
Expand Down Expand Up @@ -405,7 +494,7 @@ export class CallWrapper<T extends PromptBuilder<string, string, object>> {
}

private reverseMapOutputKeys(): Record<string, string> {
const outputMapped = this.prompt.mapOutputKeys;
const outputMapped: Partial<Record<string, string>> = this.prompt.mapOutputKeys;
return Object.entries(outputMapped).reduce((acc, [k, v]) => {
if (v) acc[v] = k;
return acc;
Expand All @@ -423,7 +512,7 @@ export class CallWrapper<T extends PromptBuilder<string, string, object>> {
private handleError(
ev: CustomEvent,
promptId: string,
resolve: (value: Record<keyof T["mapOutputKeys"], any> | false) => void
resolve: (value: Record<keyof PromptBuilder<I, O, T>["mapOutputKeys"], any> | false) => void
) {
if (ev.detail.prompt_id !== promptId) return;
this.onFailedFn?.(
Expand Down
58 changes: 55 additions & 3 deletions src/prompt-builder.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import { encodeNTPath, encodePosixPath } from "./tools";
import { OSType } from "./types/api";
import { NodeData, OSType } from "./types/api";
import { DeepKeys, Simplify } from "./types/tool";

export class PromptBuilder<I extends string, O extends string, T = unknown> {
export class PromptBuilder<I extends string, O extends string, T extends NodeData> {
prompt: T;
mapInputKeys: Partial<Record<I, string | string[]>> = {};
mapOutputKeys: Partial<Record<O, string>> = {};
bypassNodes: (keyof T)[] = [];

constructor(prompt: T, inputKeys: I[], outputKeys: O[]) {
this.prompt = prompt;
this.prompt = structuredClone(prompt);
inputKeys.forEach((key) => {
this.mapInputKeys[key] = undefined;
});
Expand All @@ -31,6 +32,57 @@ export class PromptBuilder<I extends string, O extends string, T = unknown> {
);
newBuilder.mapInputKeys = { ...this.mapInputKeys };
newBuilder.mapOutputKeys = { ...this.mapOutputKeys };
newBuilder.bypassNodes = [ ...this.bypassNodes ];
return newBuilder;
}


/**
* Marks a node to be bypassed at generation.
*
* @param node Node which will be bypassed.
*/
bypass(node: keyof T): PromptBuilder<I, O, T>;

/**
* Marks multiple nodes to be bypassed at generation.
*
* @param nodes Array of nodes which will be bypassed.
*/
bypass(nodes: (keyof T)[]): PromptBuilder<I, O, T>;

bypass(nodes: keyof T | (keyof T)[]) {
if (!Array.isArray(nodes)) {
nodes = [nodes];
}
const newBuilder = this.clone();
newBuilder.bypassNodes.push(...nodes);
return newBuilder;
}

/**
* Unmarks a node from bypass at generation.
*
* @param node Node to reverse bypass on.
*/
reinstate(node: keyof T): PromptBuilder<I, O, T>;

/**
* Unmarks a collection of nodes from bypass at generation.
*
* @param nodes Array of nodes to reverse bypass on.
*/
reinstate(nodes: (keyof T)[]): PromptBuilder<I, O, T>;

reinstate(nodes: keyof T | (keyof T)[]) {
if (!Array.isArray(nodes)) {
nodes = [nodes];
}

const newBuilder = this.clone();
for (const node of nodes) {
newBuilder.bypassNodes.splice(newBuilder.bypassNodes.indexOf(node), 1);
}
return newBuilder;
}

Expand Down
9 changes: 9 additions & 0 deletions src/types/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,21 @@ export interface NodeDef {
| TBoolInput
| TNumberInput;
};
optional?: {
[key: string]:
| [string[], { tooltip?: string }]
| [string, { tooltip?: string }]
| TStringInput
| TBoolInput
| TNumberInput;
};
hidden: {
[key: string]: string;
};
};
input_order: {
required: string[];
optional?: string[],
hidden: string[];
};
output: string[];
Expand Down
8 changes: 6 additions & 2 deletions src/types/error.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,9 @@ export class CustomEventError extends CallWrapperError {
}

export class ExecutionInterruptedError extends CallWrapperError {
name = ' ExecutionInterruptedError';
}
name = 'ExecutionInterruptedError';
}

export class MissingNodeError extends CallWrapperError {
name = 'MissingNodeError';
}

0 comments on commit b9463bf

Please sign in to comment.