Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement use_shell as parameter #2297

Merged
merged 1 commit into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions config example.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Copy this file and name it config.toml
# Edit the values to suit your needs

[settings]
use_shell = false # Use shell furing process run of sd-scripts oython code. Most secure is false but some systems may require it to be true to properly run sd-scripts.

# Default folders location
[model]
models_dir = "./models" # Pretrained model name or path
Expand Down
40 changes: 31 additions & 9 deletions kohya_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from kohya_gui.custom_logging import setup_logging
from kohya_gui.localization_ext import add_javascript


def UI(**kwargs):
add_javascript(kwargs.get("language"))
css = ""
Expand All @@ -35,23 +36,36 @@ def UI(**kwargs):
interface = gr.Blocks(
css=css, title=f"Kohya_ss GUI {release}", theme=gr.themes.Default()
)

config = KohyaSSGUIConfig(config_file_path=kwargs.get("config"))

if config.is_config_loaded():
log.info(f"Loaded default GUI values from '{kwargs.get('config')}'...")

use_shell_flag = kwargs.get("use_shell", False)
if use_shell_flag == False:
use_shell_flag = config.get("settings.use_shell", False)
if use_shell_flag:
log.info("Using shell=True when running external commands...")

with interface:
with gr.Tab("Dreambooth"):
(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
) = dreambooth_tab(headless=headless, config=config)
) = dreambooth_tab(
headless=headless, config=config, use_shell_flag=use_shell_flag
)
with gr.Tab("LoRA"):
lora_tab(headless=headless, config=config)
lora_tab(headless=headless, config=config, use_shell_flag=use_shell_flag)
with gr.Tab("Textual Inversion"):
ti_tab(headless=headless, config=config)
ti_tab(headless=headless, config=config, use_shell_flag=use_shell_flag)
with gr.Tab("Finetuning"):
finetune_tab(headless=headless, config=config)
finetune_tab(
headless=headless, config=config, use_shell_flag=use_shell_flag
)
with gr.Tab("Utilities"):
utilities_tab(
train_data_dir_input=train_data_dir_input,
Expand All @@ -61,9 +75,10 @@ def UI(**kwargs):
enable_copy_info_button=True,
headless=headless,
config=config,
use_shell_flag=use_shell_flag,
)
with gr.Tab("LoRA"):
_ = LoRATools(headless=headless)
_ = LoRATools(headless=headless, use_shell_flag=use_shell_flag)
with gr.Tab("About"):
gr.Markdown(f"kohya_ss GUI release {release}")
with gr.Tab("README"):
Expand Down Expand Up @@ -102,6 +117,7 @@ def UI(**kwargs):
launch_kwargs["debug"] = True
interface.launch(**launch_kwargs)


if __name__ == "__main__":
# torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -141,11 +157,17 @@ def UI(**kwargs):

parser.add_argument("--use-ipex", action="store_true", help="Use IPEX environment")
parser.add_argument("--use-rocm", action="store_true", help="Use ROCm environment")

parser.add_argument("--do_not_share", action="store_true", help="Do not share the gradio UI")

parser.add_argument(
"--use_shell", action="store_true", help="Use shell environment"
)

parser.add_argument(
"--do_not_share", action="store_true", help="Do not share the gradio UI"
)

args = parser.parse_args()

# Set up logging
log = setup_logging(debug=args.debug)

Expand Down
6 changes: 4 additions & 2 deletions kohya_gui/blip_caption_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def caption_images(
beam_search: bool,
prefix: str = "",
postfix: str = "",
use_shell: bool = False,
) -> None:
"""
Automatically generates captions for images in the specified directory using the BLIP model.
Expand Down Expand Up @@ -96,7 +97,7 @@ def caption_images(
env["TF_ENABLE_ONEDNN_OPTS"] = "0"

# Run the command in the sd-scripts folder context
subprocess.run(run_cmd, env=env, cwd=f"{scriptdir}/sd-scripts")
subprocess.run(run_cmd, env=env, shell=use_shell, cwd=f"{scriptdir}/sd-scripts")


# Add prefix and postfix
Expand All @@ -115,7 +116,7 @@ def caption_images(
###


def gradio_blip_caption_gui_tab(headless=False, default_train_dir=None):
def gradio_blip_caption_gui_tab(headless=False, default_train_dir=None, use_shell: bool = False):
from .common_gui import create_refresh_button

default_train_dir = (
Expand Down Expand Up @@ -205,6 +206,7 @@ def list_train_dirs(path):
beam_search,
prefix,
postfix,
gr.Checkbox(value=use_shell, visible=False),
],
show_progress=False,
)
Expand Down
17 changes: 9 additions & 8 deletions kohya_gui/class_command_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import psutil
import time
import gradio as gr
import shlex

from .custom_logging import setup_logging

# Set up logging
Expand All @@ -21,7 +21,7 @@ def __init__(self):
self.process = None
self.run_state = gr.Textbox(value="", visible=False)

def execute_command(self, run_cmd: str, **kwargs):
def execute_command(self, run_cmd: str, use_shell: bool = False, **kwargs):
"""
Execute a command if no other command is currently running.

Expand All @@ -36,11 +36,12 @@ def execute_command(self, run_cmd: str, **kwargs):
# log.info(f"{i}: {item}")

# Reconstruct the safe command string for display
command_to_run = ' '.join(run_cmd)
log.info(f"Executing command: {command_to_run}")
command_to_run = " ".join(run_cmd)
log.info(f"Executing command: {command_to_run} with shell={use_shell}")

# Execute the command securely
self.process = subprocess.Popen(run_cmd, **kwargs)
self.process = subprocess.Popen(run_cmd, **kwargs, shell=use_shell)
log.info("Command executed.")

def kill_command(self):
"""
Expand All @@ -64,9 +65,9 @@ def kill_command(self):
log.info(f"Error when terminating process: {e}")
else:
log.info("There is no running process to kill.")

return gr.Button(visible=True), gr.Button(visible=False)

def wait_for_training_to_end(self):
while self.is_running():
time.sleep(1)
Expand All @@ -81,4 +82,4 @@ def is_running(self):
Returns:
- bool: True if the command is running, False otherwise.
"""
return self.process and self.process.poll() is None
return self.process and self.process.poll() is None
11 changes: 11 additions & 0 deletions kohya_gui/class_gui_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,14 @@ def get(self, key: str, default=None):
# Return the final value
log.debug(f"Returned {data}")
return data

def is_config_loaded(self) -> bool:
"""
Checks if the configuration was loaded from a file.

Returns:
bool: True if the configuration was loaded from a file, False otherwise.
"""
is_loaded = self.config != {}
log.debug(f"Configuration was loaded from file: {is_loaded}")
return is_loaded
26 changes: 14 additions & 12 deletions kohya_gui/class_lora_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@


class LoRATools:
def __init__(self, headless: bool = False):
self.headless = headless

def __init__(
self,
headless: bool = False,
use_shell_flag: bool = False,
):
gr.Markdown("This section provide various LoRA tools...")
gradio_extract_dylora_tab(headless=headless)
gradio_convert_lcm_tab(headless=headless)
gradio_extract_lora_tab(headless=headless)
gradio_extract_lycoris_locon_tab(headless=headless)
gradio_merge_lora_tab = GradioMergeLoRaTab()
gradio_merge_lycoris_tab(headless=headless)
gradio_svd_merge_lora_tab(headless=headless)
gradio_resize_lora_tab(headless=headless)
gradio_verify_lora_tab(headless=headless)
gradio_extract_dylora_tab(headless=headless, use_shell=use_shell_flag)
gradio_convert_lcm_tab(headless=headless, use_shell=use_shell_flag)
gradio_extract_lora_tab(headless=headless, use_shell=use_shell_flag)
gradio_extract_lycoris_locon_tab(headless=headless, use_shell=use_shell_flag)
gradio_merge_lora_tab = GradioMergeLoRaTab(use_shell=use_shell_flag)
gradio_merge_lycoris_tab(headless=headless, use_shell=use_shell_flag)
gradio_svd_merge_lora_tab(headless=headless, use_shell=use_shell_flag)
gradio_resize_lora_tab(headless=headless, use_shell=use_shell_flag)
gradio_verify_lora_tab(headless=headless, use_shell=use_shell_flag)
3 changes: 1 addition & 2 deletions kohya_gui/common_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from easygui import msgbox, ynbox
from typing import Optional
from .custom_logging import setup_logging
from .class_command_executor import CommandExecutor

import os
import re
Expand All @@ -12,7 +11,6 @@
import json
import math
import shutil
import time

# Set up logging
log = setup_logging()
Expand All @@ -23,6 +21,7 @@
document_symbol = "\U0001F4C4" # 📄

scriptdir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))

if os.name == "nt":
scriptdir = scriptdir.replace("\\", "/")

Expand Down
22 changes: 17 additions & 5 deletions kohya_gui/convert_lcm_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
PYTHON = sys.executable


def convert_lcm(name, model_path, lora_scale, model_type):
def convert_lcm(
name,
model_path,
lora_scale,
model_type,
use_shell: bool = False,
):
run_cmd = rf'"{PYTHON}" "{scriptdir}/tools/lcm_convert.py"'

# Check if source model exist
Expand Down Expand Up @@ -62,7 +68,7 @@ def convert_lcm(name, model_path, lora_scale, model_type):
run_cmd.append("--ssd-1b")

# Log the command
log.info(' '.join(run_cmd))
log.info(" ".join(run_cmd))

# Set up the environment
env = os.environ.copy()
Expand All @@ -72,13 +78,13 @@ def convert_lcm(name, model_path, lora_scale, model_type):
env["TF_ENABLE_ONEDNN_OPTS"] = "0"

# Run the command
subprocess.run(run_cmd, env=env)
subprocess.run(run_cmd, env=env, shell=use_shell)

# Return a success message
log.info("Done extracting...")


def gradio_convert_lcm_tab(headless=False):
def gradio_convert_lcm_tab(headless=False, use_shell: bool = False):
current_model_dir = os.path.join(scriptdir, "outputs")
current_save_dir = os.path.join(scriptdir, "outputs")

Expand Down Expand Up @@ -183,6 +189,12 @@ def list_save_to(path):

extract_button.click(
convert_lcm,
inputs=[name, model_path, lora_scale, model_type],
inputs=[
name,
model_path,
lora_scale,
model_type,
gr.Checkbox(value=use_shell, visible=False),
],
show_progress=False,
)
6 changes: 4 additions & 2 deletions kohya_gui/convert_model_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def convert_model(
target_model_type,
target_save_precision_type,
unet_use_linear_projection,
use_shell: bool = False,
):
# Check for caption_text_input
if source_model_type == "":
Expand Down Expand Up @@ -107,7 +108,7 @@ def convert_model(
env["TF_ENABLE_ONEDNN_OPTS"] = "0"

# Run the command
subprocess.run(run_cmd, env=env)
subprocess.run(run_cmd, env=env, shell=use_shell)



Expand All @@ -116,7 +117,7 @@ def convert_model(
###


def gradio_convert_model_tab(headless=False):
def gradio_convert_model_tab(headless=False, use_shell: bool = False):
from .common_gui import create_refresh_button

default_source_model = os.path.join(scriptdir, "outputs")
Expand Down Expand Up @@ -276,6 +277,7 @@ def list_target_folder(path):
target_model_type,
target_save_precision_type,
unet_use_linear_projection,
gr.Checkbox(value=use_shell, visible=False),
],
show_progress=False,
)
7 changes: 6 additions & 1 deletion kohya_gui/dreambooth_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

# Setup huggingface
huggingface = None
use_shell = False

PYTHON = sys.executable

Expand Down Expand Up @@ -843,7 +844,7 @@ def train_model(

# Run the command

executor.execute_command(run_cmd=run_cmd, env=env)
executor.execute_command(run_cmd=run_cmd, use_shell=use_shell, env=env)

return (
gr.Button(visible=False),
Expand All @@ -859,10 +860,14 @@ def dreambooth_tab(
# logging_dir=gr.Textbox(),
headless=False,
config: KohyaSSGUIConfig = {},
use_shell_flag: bool = False,
):
dummy_db_true = gr.Checkbox(value=True, visible=False)
dummy_db_false = gr.Checkbox(value=False, visible=False)
dummy_headless = gr.Checkbox(value=headless, visible=False)

global use_shell
use_shell = use_shell_flag

with gr.Tab("Training"), gr.Column(variant="compact"):
gr.Markdown("Train a custom model using kohya dreambooth python code...")
Expand Down
Loading