Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run scripts from API #6469

Merged
merged 5 commits into from
Jan 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 42 additions & 5 deletions modules/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from secrets import compare_digest

import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack, images
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.extras import run_extras
Expand All @@ -28,8 +28,13 @@ def upscaler_to_index(name: str):
try:
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
except:
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")

def script_name_to_index(name, scripts):
try:
return [script.title().lower() for script in scripts].index(name.lower())
except:
raise HTTPException(status_code=422, detail=f"Script '{name}' not found")

def validate_sampler_name(name):
config = sd_samplers.all_samplers_map.get(name, None)
Expand Down Expand Up @@ -144,6 +149,14 @@ def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})

def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
if txt2imgreq.script_name is not None:
if scripts.scripts_txt2img.scripts == []:
scripts.scripts_txt2img.initialize_scripts(False)
ui.create_ui()

script_idx = script_name_to_index(txt2imgreq.script_name, scripts.scripts_txt2img.selectable_scripts)
script = scripts.scripts_txt2img.selectable_scripts[script_idx]

populate = txt2imgreq.copy(update={ # Override __init__ params
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
"do_not_save_samples": True,
Expand All @@ -153,11 +166,20 @@ def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on

args = vars(populate)
args.pop('script_name', None)

with self.queue_lock:
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)

shared.state.begin()
processed = process_images(p)
if 'script' in locals():
p.outpath_grids = opts.outdir_txt2img_grids
p.outpath_samples = opts.outdir_txt2img_samples
p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
processed = scripts.scripts_txt2img.run(p, *p.script_args)
else:
processed = process_images(p)
shared.state.end()


Expand All @@ -170,6 +192,14 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")

if img2imgreq.script_name is not None:
if scripts.scripts_img2img.scripts == []:
scripts.scripts_img2img.initialize_scripts(True)
ui.create_ui()

script_idx = script_name_to_index(img2imgreq.script_name, scripts.scripts_img2img.selectable_scripts)
script = scripts.scripts_img2img.selectable_scripts[script_idx]

mask = img2imgreq.mask
if mask:
mask = decode_base64_to_image(mask)
Expand All @@ -186,13 +216,20 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):

args = vars(populate)
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
args.pop('script_name', None)

with self.queue_lock:
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
p.init_images = [decode_base64_to_image(x) for x in init_images]

shared.state.begin()
processed = process_images(p)
if 'script' in locals():
p.outpath_grids = opts.outdir_img2img_grids
p.outpath_samples = opts.outdir_img2img_samples
p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
processed = scripts.scripts_img2img.run(p, *p.script_args)
else:
processed = process_images(p)
shared.state.end()

b64images = list(map(encode_pil_to_base64, processed.images))
Expand Down
4 changes: 2 additions & 2 deletions modules/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ def generate_model(self):
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}]
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
).generate_model()

StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
).generate_model()

class TextToImageResponse(BaseModel):
Expand Down
4 changes: 2 additions & 2 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class StableDiffusionProcessing():
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None):
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)

Expand Down Expand Up @@ -149,7 +149,7 @@ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prom
self.seed_resize_from_w = 0

self.scripts = None
self.script_args = None
self.script_args = script_args
self.all_prompts = None
self.all_negative_prompts = None
self.all_seeds = None
Expand Down
2 changes: 2 additions & 0 deletions scripts/sd_upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def ui(self, is_img2img):
return [info, overlap, upscaler_index, scale_factor]

def run(self, p, _, overlap, upscaler_index, scale_factor):
if isinstance(upscaler_index, str):
upscaler_index = [x.name.lower() for x in shared.sd_upscalers].index(upscaler_index.lower())
processing.fix_seed(p)
upscaler = shared.sd_upscalers[upscaler_index]

Expand Down