Skip to content

Commit

Permalink
Merge pull request #14475 from Learwin/negative_prompt
Browse files Browse the repository at this point in the history
Adding negative prompts to Loras in extra networks
  • Loading branch information
AUTOMATIC1111 authored Dec 31, 2023
2 parents a84e842 + b6f74e9 commit f3af8c8
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 13 deletions.
9 changes: 7 additions & 2 deletions extensions-builtin/Lora/ui_edit_user_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,13 @@ def __init__(self, ui, tabname, page):
self.slider_preferred_weight = None
self.edit_notes = None

def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes):
def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes):
user_metadata = self.get_user_metadata(name)
user_metadata["description"] = desc
user_metadata["sd version"] = sd_version
user_metadata["activation text"] = activation_text
user_metadata["preferred weight"] = preferred_weight
user_metadata["negative text"] = negative_text
user_metadata["notes"] = notes

self.write_user_metadata(name, user_metadata)
Expand Down Expand Up @@ -127,6 +128,7 @@ def put_values_into_components(self, name):
gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
user_metadata.get('activation text', ''),
float(user_metadata.get('preferred weight', 0.0)),
user_metadata.get('negative text', ''),
gr.update(visible=True if tags else False),
gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
]
Expand Down Expand Up @@ -162,7 +164,7 @@ def create_editor(self):
self.taginfo = gr.HighlightedText(label="Training dataset tags")
self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)

self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts")
with gr.Row() as row_random_prompt:
with gr.Column(scale=8):
random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
Expand Down Expand Up @@ -198,6 +200,7 @@ def select_tag(activation_text, evt: gr.SelectData):
self.taginfo,
self.edit_activation_text,
self.slider_preferred_weight,
self.edit_negative_text,
row_random_prompt,
random_prompt,
]
Expand All @@ -211,7 +214,9 @@ def select_tag(activation_text, evt: gr.SelectData):
self.select_sd_version,
self.edit_activation_text,
self.slider_preferred_weight,
self.edit_negative_text,
self.edit_notes,
]


self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components)
5 changes: 5 additions & 0 deletions extensions-builtin/Lora/ui_extra_networks_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def create_item(self, name, index=None, enable_filter=True):
if activation_text:
item["prompt"] += " + " + quote_js(" " + activation_text)

negative_prompt = item["user_metadata"].get("negative text")
item["negative_prompt"] = quote_js("")
if negative_prompt:
item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)')

sd_version = item["user_metadata"].get("sd version")
if sd_version in network.SdVersion.__members__:
item["sd_version"] = sd_version
Expand Down
31 changes: 21 additions & 10 deletions javascript/extraNetworks.js
Original file line number Diff line number Diff line change
Expand Up @@ -185,17 +185,19 @@ onUiLoaded(setupExtraNetworks);
var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/;
var re_extranet_g = /<([^:^>]+:[^:]+):[\d.]+>/g;

function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
var m = text.match(re_extranet);
var re_extranet_neg = /\(([^:^>]+:[\d.]+)\)/;
var re_extranet_g_neg = /\(([^:^>]+:[\d.]+)\)/g;
function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) {
var m = text.match(isNeg ? re_extranet_neg : re_extranet);
var replaced = false;
var newTextareaText;
if (m) {
var extraTextBeforeNet = opts.extra_networks_add_text_separator;
var extraTextAfterNet = m[2];
var partToSearch = m[1];
var foundAtPosition = -1;
newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, net, pos) {
m = found.match(re_extranet);
newTextareaText = textarea.value.replaceAll(isNeg ? re_extranet_g_neg : re_extranet_g, function(found, net, pos) {
m = found.match(isNeg ? re_extranet_neg : re_extranet);
if (m[1] == partToSearch) {
replaced = true;
foundAtPosition = pos;
Expand All @@ -205,7 +207,7 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
});

if (foundAtPosition >= 0) {
if (newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
if (extraTextAfterNet && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length);
}
if (newTextareaText.substr(foundAtPosition - extraTextBeforeNet.length, extraTextBeforeNet.length) == extraTextBeforeNet) {
Expand All @@ -230,14 +232,23 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
return false;
}

function cardClicked(tabname, textToAdd, allowNegativePrompt) {
var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea");
function updatePromptArea(text, textArea, isNeg) {

if (!tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)) {
textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd;
if (!tryToRemoveExtraNetworkFromPrompt(textArea, text, isNeg)) {
textArea.value = textArea.value + opts.extra_networks_add_text_separator + text;
}

updateInput(textarea);
updateInput(textArea);
}

function cardClicked(tabname, textToAdd, textToAddNegative, allowNegativePrompt) {
if (textToAddNegative.length > 0) {
updatePromptArea(textToAdd, gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"));
updatePromptArea(textToAddNegative, gradioApp().querySelector("#" + tabname + "_neg_prompt > label > textarea"), true);
} else {
var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea");
updatePromptArea(textToAdd, textarea);
}
}

function saveCardPreview(event, tabname, filename) {
Expand Down
5 changes: 4 additions & 1 deletion modules/ui_extra_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,10 @@ def create_html_for_item(self, item, tabname):

onclick = item.get("onclick", None)
if onclick is None:
onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
if "negative_prompt" in item:
onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {item["negative_prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
else:
onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {'""'}, {"true" if self.allow_negative_prompt else "false"})""") + '"'

height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
Expand Down

0 comments on commit f3af8c8

Please sign in to comment.