Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add the workflow tool of comfyUI #9447

Merged
merged 3 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions api/core/tools/provider/builtin/comfyui/comfyui.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from typing import Any

import websocket
from yarl import URL

from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.comfyui.tools.comfyui_stable_diffusion import ComfyuiStableDiffusionTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


# NOTE(review): this span is a rendered diff hunk whose indentation and +/- markers were
# stripped. The hunk header reports 11 additions and 7 deletions, so removed and added
# lines presumably appear interleaved below rather than as one coherent version.
# Provider controller whose only visible method validates the ComfyUI credentials.
class ComfyUIProvider(BuiltinToolProviderController):
# Validates `credentials`; any failure inside the try block below is re-raised as
# ToolProviderCredentialValidationError.
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
ws = websocket.WebSocket()
# Only the authority component of the configured base_url is reused for the WebSocket URL;
# the scheme is always "ws" and the client id is a fixed placeholder.
base_url = URL(credentials.get("base_url"))
ws_address = f"ws://{base_url.authority}/ws?clientId=test123"

try:
# presumably the removed (pre-change) check: model validation through the
# stable-diffusion tool — TODO confirm against the actual diff markers
ComfyuiStableDiffusionTool().fork_tool_runtime(
runtime={
"credentials": credentials,
}
).validate_models()
# presumably the added (post-change) check: the credentials are accepted when the
# WebSocket connection can be established — TODO confirm
ws.connect(ws_address)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
# As written, this second raise is unreachable because the line above always raises;
# presumably it is the added replacement for that line — TODO confirm.
raise ToolProviderCredentialValidationError(f"can not connect to {ws_address}")
finally:
# The socket is closed whether or not the connection attempt succeeded.
ws.close()
25 changes: 3 additions & 22 deletions api/core/tools/provider/builtin/comfyui/comfyui.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@ identity:
label:
en_US: ComfyUI
zh_Hans: ComfyUI
pt_BR: ComfyUI
description:
en_US: ComfyUI is a tool for generating images which can be deployed locally.
zh_Hans: ComfyUI 是一个可以在本地部署的图片生成的工具。
pt_BR: ComfyUI is a tool for generating images which can be deployed locally.
icon: icon.png
tags:
- image
Expand All @@ -17,26 +15,9 @@ credentials_for_provider:
type: text-input
required: true
label:
en_US: Base URL
zh_Hans: ComfyUI服务器的Base URL
pt_BR: Base URL
en_US: The URL of ComfyUI Server
zh_Hans: ComfyUI服务器的URL
placeholder:
en_US: Please input your ComfyUI server's Base URL
zh_Hans: 请输入你的 ComfyUI 服务器的 Base URL
pt_BR: Please input your ComfyUI server's Base URL
model:
type: text-input
required: true
label:
en_US: Model with suffix
zh_Hans: 模型, 需要带后缀
pt_BR: Model with suffix
placeholder:
en_US: Please input your model
zh_Hans: 请输入你的模型名称
pt_BR: Please input your model
help:
en_US: The checkpoint name of the ComfyUI server, e.g. xxx.safetensors
zh_Hans: ComfyUI服务器的模型名称, 比如 xxx.safetensors
pt_BR: The checkpoint name of the ComfyUI server, e.g. xxx.safetensors
url: https://github.com/comfyanonymous/ComfyUI#installing
url: https://docs.dify.ai/guides/tools/tool-configuration/comfyui
105 changes: 105 additions & 0 deletions api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import copy
import json
import logging
import random
import uuid

import httpx
from websocket import WebSocket
from yarl import URL


class ComfyUiClient:
    """Small HTTP/WebSocket client for a ComfyUI server.

    HTTP endpoints are called with ``httpx``; execution progress is followed
    over a WebSocket opened against the same authority as ``base_url``.
    """

    def __init__(self, base_url: str):
        """Store ``base_url`` as a ``yarl.URL`` so paths can be joined with ``/``."""
        self.base_url = URL(base_url)

    def get_history(self, prompt_id: str):
        """Return the history entry recorded by the server for ``prompt_id``.

        Raises:
            httpx.HTTPStatusError: the server answered with a 4xx/5xx status.
            KeyError: the response does not contain ``prompt_id``.
        """
        res = httpx.get(str(self.base_url / "history"), params={"prompt_id": prompt_id})
        # Fail loudly on an error status instead of indexing into an error payload.
        res.raise_for_status()
        return res.json()[prompt_id]

    def get_image(self, filename: str, subfolder: str, folder_type: str):
        """Download one image from the ``view`` endpoint and return its raw bytes.

        Raises:
            httpx.HTTPStatusError: the server answered with a 4xx/5xx status.
        """
        response = httpx.get(
            str(self.base_url / "view"),
            params={"filename": filename, "subfolder": subfolder, "type": folder_type},
        )
        # Without this check an error body would be handed back as if it were image data.
        response.raise_for_status()
        return response.content

    def upload_image(self, input_path: str, name: str, image_type: str = "input", overwrite: bool = False):
        """Upload the file at ``input_path`` under ``name`` and return the HTTP response.

        The response is returned unchecked so callers can inspect the status themselves.
        """
        # plan to support img2img in dify 0.10.0
        data = {"type": image_type, "overwrite": str(overwrite).lower()}
        with open(input_path, "rb") as file:
            files = {"image": (name, file, "image/png")}
            # The request must be sent while the file handle is still open.
            return httpx.post(str(self.base_url / "upload/image"), data=data, files=files)

    def queue_prompt(self, client_id: str, prompt: dict):
        """Submit ``prompt`` for execution and return the ``prompt_id`` from the response.

        Raises:
            httpx.HTTPStatusError: the server rejected the request (4xx/5xx status).
        """
        res = httpx.post(str(self.base_url / "prompt"), json={"client_id": client_id, "prompt": prompt})
        res.raise_for_status()
        return res.json()["prompt_id"]

    def open_websocket_connection(self):
        """Open a WebSocket with a fresh client id and return ``(ws, client_id)``.

        The caller owns the returned socket and must close it.
        """
        client_id = str(uuid.uuid4())
        ws = WebSocket()
        # Follow the HTTP scheme: an https base URL needs a TLS WebSocket.
        ws_scheme = "wss" if self.base_url.scheme == "https" else "ws"
        ws_address = f"{ws_scheme}://{self.base_url.authority}/ws?clientId={client_id}"
        try:
            ws.connect(ws_address)
        except Exception:
            # Do not leave a half-initialised socket behind when the handshake fails.
            ws.close()
            raise
        return ws, client_id

    def set_prompt(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = ""):
        """Return a copy of ``origin_prompt`` with a new seed and prompt texts applied.

        The first node whose ``class_type`` is ``"KSampler"`` is located; the nodes
        referenced by its ``positive`` / ``negative`` inputs receive the given texts.
        ``origin_prompt`` itself is left untouched. An empty or ``None``
        ``negative_prompt`` keeps the negative text already present in the workflow.

        Raises:
            ValueError: the workflow contains no ``KSampler`` node.
        """
        # Deep copy: the nested "inputs" dicts are mutated below and must not leak
        # back into the caller's workflow.
        prompt = copy.deepcopy(origin_prompt)

        k_sampler_id = next(
            (node_id for node_id, node in prompt.items() if node.get("class_type") == "KSampler"),
            None,
        )
        if k_sampler_id is None:
            raise ValueError("the workflow does not contain a KSampler node")

        sampler_inputs = prompt[k_sampler_id]["inputs"]
        sampler_inputs["seed"] = random.randint(10**14, 10**15 - 1)

        # Each link is a list whose first element is the id of the source node.
        positive_input_id = sampler_inputs["positive"][0]
        prompt[positive_input_id]["inputs"]["text"] = positive_prompt

        if negative_prompt:
            negative_input_id = sampler_inputs["negative"][0]
            prompt[negative_input_id]["inputs"]["text"] = negative_prompt
        return prompt

    def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str):
        """Block on ``ws`` until the server reports that ``prompt_id`` has finished.

        Completion is signalled by an ``executing`` message whose ``node`` is ``None``
        and whose ``prompt_id`` matches. Progress is reported through ``logging``.
        """
        # NOTE(review): there is no timeout and no handling of failure messages here;
        # if the server never sends the completion message this loop does not return.
        logger = logging.getLogger(__name__)
        total_nodes = len(prompt)
        finished_nodes = set()

        while True:
            out = ws.recv()
            if not isinstance(out, str):
                # Non-text frames carry no status information for this loop.
                continue

            message = json.loads(out)
            message_type = message.get("type")

            if message_type == "progress":
                data = message["data"]
                logger.debug("In K-Sampler -> Step: %s of: %s", data["value"], data["max"])
            elif message_type == "execution_cached":
                for node_id in message["data"]["nodes"]:
                    if node_id not in finished_nodes:
                        finished_nodes.add(node_id)
                        logger.debug("Progress: %s/%s Tasks done", len(finished_nodes), total_nodes)
            elif message_type == "executing":
                data = message["data"]
                node_id = data["node"]
                if node_id is None:
                    if data.get("prompt_id") == prompt_id:
                        break  # Execution is done
                elif node_id not in finished_nodes:
                    finished_nodes.add(node_id)
                    logger.debug("Progress: %s/%s Tasks done", len(finished_nodes), total_nodes)

    def generate_image_by_prompt(self, prompt: dict):
        """Run ``prompt`` on the server and return the produced images as a list of bytes."""
        # Open the socket before entering the try block so that ``finally`` never
        # refers to an unbound ``ws`` when the connection itself fails.
        ws, client_id = self.open_websocket_connection()
        try:
            prompt_id = self.queue_prompt(client_id, prompt)
            self.track_progress(prompt, ws, prompt_id)
            history = self.get_history(prompt_id)
            return [
                self.get_image(img["filename"], img["subfolder"], img["type"])
                for output in history["outputs"].values()
                for img in output.get("images", [])
            ]
        finally:
            ws.close()
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
identity:
name: txt2img workflow
name: txt2img
author: Qun
label:
en_US: Txt2Img Workflow
zh_Hans: Txt2Img Workflow
pt_BR: Txt2Img Workflow
en_US: Txt2Img
zh_Hans: Txt2Img
pt_BR: Txt2Img
description:
human:
en_US: a pre-defined comfyui workflow that can use one model and up to 3 loras to generate images. Supports SD1.5, SDXL, SD3 and FLUX, which contain text encoders/clip, but does not support models that require a triple clip loader.
Expand Down
32 changes: 32 additions & 0 deletions api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import json
from typing import Any

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

from .comfyui_client import ComfyUiClient


class ComfyUIWorkflowTool(BuiltinTool):
    """Tool that runs a user-supplied ComfyUI workflow and returns the generated images."""

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        """Execute the workflow given in ``tool_parameters`` and return one blob message per image.

        Reads ``positive_prompt``, ``negative_prompt`` and ``workflow_json`` from
        ``tool_parameters``. When ``workflow_json`` is missing, is not valid JSON, or
        does not decode to a JSON object, a text message describing the problem is
        returned instead of raising.
        """
        comfyui = ComfyUiClient(self.runtime.credentials["base_url"])

        # The prompts are optional; normalise a missing value to an empty string so
        # that ``None`` is never written into the workflow.
        positive_prompt = tool_parameters.get("positive_prompt") or ""
        negative_prompt = tool_parameters.get("negative_prompt") or ""
        workflow = tool_parameters.get("workflow_json")

        try:
            origin_prompt = json.loads(workflow)
        except (TypeError, ValueError):
            # TypeError: the parameter is missing or not a string;
            # ValueError: covers json.JSONDecodeError for malformed JSON.
            return self.create_text_message("the Workflow JSON is not correct")
        if not isinstance(origin_prompt, dict):
            # Valid JSON that is not an object cannot describe a node graph.
            return self.create_text_message("the Workflow JSON is not correct")

        prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt)
        images = comfyui.generate_image_by_prompt(prompt)
        return [
            self.create_blob_message(blob=img, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value)
            for img in images
        ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
identity:
name: workflow
author: hjlarry
label:
en_US: workflow
zh_Hans: 工作流
description:
human:
en_US: Run ComfyUI workflow.
zh_Hans: 运行ComfyUI工作流。
llm: Run ComfyUI workflow.
parameters:
- name: positive_prompt
type: string
label:
en_US: Prompt
zh_Hans: 提示词
llm_description: Image prompt, you should describe the image you want to generate as a list of words, as detailed as possible, the prompt must be written in English.
form: llm
- name: negative_prompt
type: string
label:
en_US: Negative Prompt
zh_Hans: 负面提示词
llm_description: Negative prompt, you should describe what you don't want to appear in the generated image as a list of words, as detailed as possible, the prompt must be written in English.
form: llm
- name: workflow_json
type: string
required: true
label:
en_US: Workflow JSON
human_description:
en_US: exported from ComfyUI workflow
zh_Hans: 从ComfyUI的工作流中导出
form: form