Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev bbs patch 2 #229

Merged
merged 6 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 88 additions & 3 deletions js/model_apply.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import { hideWidget } from './subassembly/tools.js'
app.registerExtension({
name: "bizyair.siliconcloud.share.lora.loader.new",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
if (nodeData.name === "BizyAir_LoraLoaderNew") {
if (nodeData.name === "BizyAir_LoraLoader") {

function setWigetCallback() {

}
Expand All @@ -22,7 +23,7 @@ app.registerExtension({


async nodeCreated(node) {
if (node?.comfyClass === "BizyAir_LoraLoaderNew") {
if (node?.comfyClass === "BizyAir_LoraLoader") {
const original_onMouseDown = node.onMouseDown;

let lastClickTime = 0;
Expand All @@ -32,7 +33,7 @@ app.registerExtension({

node.onMouseDown = function( e, pos, canvas ) {
const lora_name = this.widgets.find(widget => widget.name === "lora_name")
const model_widget = this.widgets.find(widget => widget.name === "model_version_id")
const model_widget = this.widgets.find(widget => widget.name === "model_version_id") // hidden
if (pos[1] - lora_name.last_y > 0 && pos[1] - lora_name.last_y < 20) {
const litecontextmenu = document.querySelector('.litegraph.litecontextmenu')
if (litecontextmenu) {
Expand Down Expand Up @@ -66,3 +67,87 @@ app.registerExtension({
}
}
})




app.registerExtension({
name: "bizyair.siliconcloud.share.controlnet.loader.new",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
if (nodeData.name === "BizyAir_ControlNetLoader") {

function setWigetCallback() {
console.log(this)
}
const onNodeCreated = nodeType.prototype.onNodeCreated
nodeType.prototype.onNodeCreated = function () {
onNodeCreated?.apply(this, arguments);
setWigetCallback.call(this, arguments);
};
}
},



async nodeCreated(node) {
if (node?.comfyClass === "BizyAir_ControlNetLoader") {
const original_onMouseDown = node.onMouseDown;

let lastClickTime = 0;
const DEBOUNCE_DELAY = 300; // 300ms防抖延迟

hideWidget(node, "model_version_id");

node.onMouseDown = function( e, pos, canvas ) {
console.log(this.size, this.widgets)
const lora_name = this.widgets.find(widget => widget.name === "control_net_name")
const model_widget = this.widgets.find(widget => widget.name === "model_version_id") // hidden
if (pos[1] - lora_name.last_y > 0 && pos[1] - lora_name.last_y < 20) {
const litecontextmenu = document.querySelector('.litegraph.litecontextmenu')
if (litecontextmenu) {
litecontextmenu.style.display = 'none'
}
e.stopImmediatePropagation();
e.preventDefault();
if (e.button !== 0) {
return false;
}
const currentTime = new Date().getTime();
if (currentTime - lastClickTime < DEBOUNCE_DELAY) {
return false;
}
lastClickTime = currentTime;
bizyAirLib.showModelSelect({
modelType:["Controlnet"],
selectedBaseModels:["Flux.1 D","SDXL"],
onApply: (version,model) => {
if(model && model_widget && lora_name && version){
lora_name.value = model
model_widget.value = version.id
}
}
})
// const aasd = dialog({
// content: $el('div', {
// style: {
// width: '1000px',
// height: '500px'
// },
// onclick: () => {
// lora_name.value = '123243'
// aasd.close()
// }
// }, ["123"]),
// noText: 'Close',
// onClose: () => {
// console.log('closed')
// }
// })
return false; // 确保事件结束
} else {
return original_onMouseDown?.apply(this, arguments);
}
}
}
}
})
69 changes: 60 additions & 9 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import comfy

from bizyair import BizyAirBaseNode, BizyAirNodeIO, create_node_data, data_types
from bizyair.configs.conf import config_manager
from bizyair.path_utils import path_manager as folder_paths

LOGO = "☁️"
PREFIX = f"{LOGO}BizyAir"

MAX_RESOLUTION = 16384 # https://github.com/comfyanonymous/ComfyUI/blob/7390ff3b1ec2e15017ba4a52d6eaabc4aa4636e3/nodes.py#L45


Expand Down Expand Up @@ -241,7 +241,7 @@ def decode(self, vae, samples):
return new_vae.send_request()


class BizyAir_LoraLoader(BizyAirBaseNode):
class BizyAir_LoraLoader_Legacy(BizyAirBaseNode):
@classmethod
def INPUT_TYPES(s):
return {
Expand All @@ -262,7 +262,7 @@ def INPUT_TYPES(s):

RETURN_TYPES = (data_types.MODEL, data_types.CLIP)
RETURN_NAMES = ("MODEL", "CLIP")

DEPRECATED = True
FUNCTION = "load_lora"

CATEGORY = f"{PREFIX}/loaders"
Expand Down Expand Up @@ -345,7 +345,7 @@ def encode(self, vae, pixels, mask, grow_mask_by=6):
return new_vae.send_request()


class BizyAir_ControlNetLoader(BizyAirBaseNode):
class BizyAir_ControlNetLoader_Legacy(BizyAirBaseNode):
@classmethod
def INPUT_TYPES(s):
return {
Expand Down Expand Up @@ -374,6 +374,52 @@ def load_controlnet(self, control_net_name):
return (node,)


class BizyAir_ControlNetLoader(BizyAirBaseNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"control_net_name": (
[
"to choose",
],
),
"model_version_id": ("STRING", {"default": "", "multiline": False}),
}
}

RETURN_TYPES = (data_types.CONTROL_NET,)
RETURN_NAMES = ("CONTROL_NET",)
FUNCTION = "load_controlnet"

CATEGORY = f"{PREFIX}/loaders"

@classmethod
def VALIDATE_INPUTS(cls, control_net_name, model_version_id):
if control_net_name == "to choose":
return False
if model_version_id is not None and model_version_id != "":
return True
return True

def load_controlnet(self, control_net_name, model_version_id):
if model_version_id is not None and model_version_id != "":
control_net_name = (
f"{config_manager.get_model_version_id_prefix()}{model_version_id}"
)

node_data = create_node_data(
class_type="ControlNetLoader",
inputs={
"control_net_name": control_net_name,
},
outputs={"slot_index": 0},
)
assigned_id = self.assigned_id
node = BizyAirNodeIO(assigned_id, {assigned_id: node_data})
return (node,)


class BizyAir_ControlNetApplyAdvanced(BizyAirBaseNode):
@classmethod
def INPUT_TYPES(s):
Expand Down Expand Up @@ -771,7 +817,7 @@ def INPUT_TYPES(s):
CATEGORY = "conditioning/inpaint"


class SharedLoraLoader(BizyAir_LoraLoader):
class SharedLoraLoader(BizyAir_LoraLoader_Legacy):
@classmethod
def INPUT_TYPES(s):
return {
Expand Down Expand Up @@ -958,7 +1004,7 @@ def INPUT_TYPES(s):
CATEGORY = "conditioning"


class BizyAir_LoraLoaderNew(BizyAirBaseNode):
class BizyAir_LoraLoader(BizyAirBaseNode):
@classmethod
def INPUT_TYPES(s):
return {
Expand Down Expand Up @@ -1000,21 +1046,26 @@ def load_lora(
lora_name,
strength_model,
strength_clip,
model_version_id=None,
model_version_id: str = None,
):
assigned_id = self.assigned_id
new_model: BizyAirNodeIO = model.copy(assigned_id)
new_clip: BizyAirNodeIO = clip.copy(assigned_id)
instances: List[BizyAirNodeIO] = [new_model, new_clip]
BIZYAIR_MODEL_VERSION_ID_PREFIX = "BIZYAIR_MODEL_VERSION_ID:"

if model_version_id is not None and model_version_id != "":
# use model version id as lora name
lora_name = (
f"{config_manager.get_model_version_id_prefix()}{model_version_id}"
)

for slot_index, ins in zip(range(2), instances):
ins.add_node_data(
class_type="LoraLoader",
inputs={
"model": model,
"clip": clip,
"lora_name": f"{BIZYAIR_MODEL_VERSION_ID_PREFIX}{model_version_id}",
"lora_name": lora_name,
"strength_model": strength_model,
"strength_clip": strength_clip,
},
Expand Down
21 changes: 19 additions & 2 deletions src/bizyair/commands/processors/prompt_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
BIZYAIR_DEV_REQUEST_URL,
BIZYAIR_SERVER_ADDRESS,
)
from bizyair.configs.conf import ModelRule
from bizyair.configs.conf import ModelRule, config_manager
from bizyair.path_utils import (
convert_prompt_label_path_to_real_path,
guess_url_from_node,
Expand Down Expand Up @@ -83,13 +83,30 @@ def validate_input(


class PromptProcessor(Processor):
def _exec_info(self, prompt: Dict[str, Dict[str, Any]]):
exec_info = {
"model_version_ids": [],
}
model_version_id_prefix = config_manager.get_model_version_id_prefix()
for node_id, node_data in prompt.items():
for k, v in node_data.get("inputs", {}).items():
if isinstance(v, str) and v.startswith(model_version_id_prefix):
exec_info["model_version_ids"].append(
v[len(model_version_id_prefix) :]
)
return exec_info

def process(
self, url: str, prompt: Dict[str, Dict[str, Any]], last_node_ids: List[str]
):
return client.send_request(
url=url,
data=json.dumps(
{"prompt": prompt, "last_node_id": last_node_ids[0]}
{
"prompt": prompt,
"last_node_id": last_node_ids[0],
"exec_info": self._exec_info(prompt),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

收集 exec_info 信息发往服务端, 比如:exec_info = {'model_version_ids': ['5068']}

}
).encode("utf-8"),
)

Expand Down
1 change: 1 addition & 0 deletions src/bizyair/commands/servers/prompt_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def execute(
*args,
**kwargs,
):

prompt = encode_data(prompt)
if BIZYAIR_DEBUG:
debug_info = {
Expand Down
6 changes: 5 additions & 1 deletion src/bizyair/configs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ def get_filenames(self, folder_name: str) -> List[str]:
class ConfigManager:
def __init__(self, model_path_config: str, model_rule_config: str):
self.model_path_manager = ModelPathManager(config_path=model_path_config)
self.model_rule_config = load_config_file(model_rule_config)
self.model_rules = ModelRuleManager(
model_rules=load_config_file(model_rule_config)["model_rules"]
model_rules=self.model_rule_config["model_rules"]
)

def get_filenames(self, folder_name: str) -> List[str]:
Expand All @@ -89,6 +90,9 @@ def get_rules(self, class_type: str) -> List[ModelRule]:
class_type = class_type[8:]
return self.model_rules.find_rules(class_type)

def get_model_version_id_prefix(self):
return self.model_rule_config["model_version_config"]["model_version_id_prefix"]


model_path_config = os.path.join(os.path.dirname(__file__), "models.json")
model_rule_config = os.path.join(os.path.dirname(__file__), "models.yaml")
Expand Down
4 changes: 4 additions & 0 deletions src/bizyair/configs/models.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# Common configuration
model_version_config:
model_version_id_prefix: "BIZYAIR_MODEL_VERSION_ID:"


model_hub:
find_model:
route: /models/files
Expand Down