Skip to content

Commit

Permalink
Add support to run scripts from API
Browse files Browse the repository at this point in the history
For more context see this pr AUTOMATIC1111/stable-diffusion-webui#6469
  • Loading branch information
pablogpz committed Jan 17, 2023
1 parent a2cee02 commit fe269dc
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 44 deletions.
77 changes: 77 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,83 @@ result4.images[1]
```
![extra_batch_images_2](https://user-images.githubusercontent.com/1288793/200459542-aa8547a0-f6db-436b-bec1-031a93a7b1d4.jpg)

### Scripts support
Scripts from AUTOMATIC1111's Web UI are supported, but there aren't official models that define a script's interface.

To find out the list of arguments that are accepted by a particular script look up the associated python file from
AUTOMATIC1111's repo `scripts/[script_name].py`. Search for its `run(p, **args)` function and the arguments that come
after 'p' is the list of accepted arguments

#### Example for X/Y Plot script:
```
(scripts/xy_grid.py file from AUTOMATIC1111's repo)
def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds):
...
```
List of accepted arguments:
* _x_type_: Index of the axis for X axis. Indexes start from [0: Nothing]
* _x_values_: String of comma-separated values for the X axis
* _y_type_: Index of the axis type for Y axis. As the X axis, indexes start from [0: Nothing]
* _y_values_: String of comma-separated values for the Y axis
* _draw_legend_: "True" or "False". IMPORTANT: It needs to be a string and not a Boolean value
* _include_lone_images_: "True" or "False". IMPORTANT: It needs to be a string and not a Boolean value
* _no_fixed_seeds_: "True" or "False". IMPORTANT: It needs to be a string and not a Boolean value
```
# Available Axis options
XYPlotAvailableScripts = [
"Nothing",
"Seed",
"Var. seed",
"Var. strength",
"Steps",
"CFG Scale",
"Prompt S/R",
"Prompt order",
"Sampler",
"Checkpoint Name",
"Hypernetwork",
"Hypernet str.",
"Sigma Churn",
"Sigma min",
"Sigma max",
"Sigma noise",
"Eta",
"Clip skip",
"Denoising",
"Hires upscaler",
"Cond. Image Mask Weight",
"VAE",
"Styles"
]
# Example call
XAxisType = "Steps"
XAxisValues = "8,16,32,64"
YAxisType = "Sampler"
YAxisValues = "k_euler_a, k_euler, k_lms, plms, k_heun, ddim, k_dpm_2, k_dpm_2_a"
drawLegend = "True"
includeSeparateImages = "False"
keepRandomSeed = "False"
result = api.txt2img(
prompt="cute squirrel",
negative_prompt="ugly, out of frame",
seed=1003,
styles=["anime"],
cfg_scale=7,
script_name="X/Y Plot",
script_args=[
XYPlotAvailableScripts.index(XAxisType),
XAxisValues,
XYPlotAvailableScripts.index(YAxisType),
YAxisValues,
drawLegend,
includeSeparateImages,
keepRandomSeed
]
)
```

### Configuration APIs
```
Expand Down
87 changes: 49 additions & 38 deletions webuiapi/webuiapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class WebUIApiResult:
images: list
parameters: dict
info: dict

@property
def image(self):
return self.images[0]
Expand All @@ -35,28 +35,28 @@ def __init__(self,
baseurl = f'https://{host}:{port}/sdapi/v1'
else:
baseurl = f'http://{host}:{port}/sdapi/v1'

self.baseurl = baseurl
self.default_sampler = sampler
self.default_steps = steps

self.session = requests.Session()

def set_auth(self, username, password):
self.session.auth = (username, password)

def _to_api_result(self, response):

if response.status_code != 200:
raise RuntimeError(response.status_code, response.text)

r = response.json()
images = []
if 'images' in r.keys():
images = [Image.open(io.BytesIO(base64.b64decode(i))) for i in r['images']]
elif 'image' in r.keys():
images = [Image.open(io.BytesIO(base64.b64decode(r['image'])))]

info = ''
if 'info' in r.keys():
try:
Expand All @@ -71,8 +71,8 @@ def _to_api_result(self, response):
parameters = r['parameters']

return WebUIApiResult(images, parameters, info)
def txt2img(self,

def txt2img(self,
enable_hr=False,
denoising_strength=0.0,
firstphase_width=0,
Expand All @@ -99,18 +99,22 @@ def txt2img(self,
s_noise=1,
override_settings={},
override_settings_restore_afterwards=True,
sampler_name=None, # use this instead of sampler_index
sampler_name=None, # use this instead of sampler_index
sampler_index=None,
steps=None,
):
script_name=None,
script_args=None # List of arguments for the script "script_name"
):
if sampler_index is None:
sampler_index = self.default_sampler
if sampler_name is None:
sampler_name = self.default_sampler
if steps is None:
steps = self.default_steps
if script_args is None:
script_args = []

payload = {
payload = {
"enable_hr": enable_hr,
"denoising_strength": denoising_strength,
"firstphase_width": firstphase_width,
Expand Down Expand Up @@ -140,14 +144,15 @@ def txt2img(self,
"override_settings_restore_afterwards": override_settings_restore_afterwards,
"sampler_name": sampler_name,
"sampler_index": sampler_index,
"script_name": script_name,
"script_args": script_args
}
response = self.session.post(url=f'{self.baseurl}/txt2img', json=payload)
return self._to_api_result(response)


def img2img(self,
images=[], # list of PIL Image
mask_image=None, # PIL Image mask
images=[], # list of PIL Image
mask_image=None, # PIL Image mask
resize_mode=0,
denoising_strength=0.75,
mask_blur=4,
Expand Down Expand Up @@ -180,15 +185,19 @@ def img2img(self,
override_settings_restore_afterwards=True,
include_init_images=False,
steps=None,
sampler_name=None, # use this instead of sampler_index
sampler_name=None, # use this instead of sampler_index
sampler_index=None,
):
script_name=None,
script_args=None # List of arguments for the script "script_name"
):
if sampler_name is None:
sampler_name = self.default_sampler
if sampler_index is None:
sampler_index = self.default_sampler
if steps is None:
steps = self.default_steps
if script_args is None:
script_args = []

payload = {
"init_images": [b64_img(x) for x in images],
Expand Down Expand Up @@ -226,15 +235,17 @@ def img2img(self,
"sampler_name": sampler_name,
"sampler_index": sampler_index,
"include_init_images": include_init_images,
"script_name": script_name,
"script_args": script_args
}
if mask_image is not None:
payload['mask']= b64_img(mask_image)
payload['mask'] = b64_img(mask_image)

response = self.session.post(url=f'{self.baseurl}/img2img', json=payload)
return self._to_api_result(response)

def extra_single_image(self,
image, # PIL Image
image, # PIL Image
resize_mode=0,
show_extras_results=True,
gfpgan_visibility=0,
Expand All @@ -248,7 +259,7 @@ def extra_single_image(self,
upscaler_2="None",
extras_upscaler_2_visibility=0,
upscale_first=False,
):
):
payload = {
"resize_mode": resize_mode,
"show_extras_results": show_extras_results,
Expand All @@ -265,13 +276,13 @@ def extra_single_image(self,
"upscale_first": upscale_first,
"image": b64_img(image),
}

response = self.session.post(url=f'{self.baseurl}/extra-single-image', json=payload)
return self._to_api_result(response)

def extra_batch_images(self,
images, # list of PIL images
name_list=None, # list of image names
images, # list of PIL images
name_list=None, # list of image names
resize_mode=0,
show_extras_results=True,
gfpgan_visibility=0,
Expand All @@ -285,21 +296,21 @@ def extra_batch_images(self,
upscaler_2="None",
extras_upscaler_2_visibility=0,
upscale_first=False,
):
):
if name_list is not None:
if len(name_list) != len(images):
raise RuntimeError('len(images) != len(name_list)')
else:
name_list = [f'image{i+1:05}' for i in range(len(images))]
name_list = [f'image{i + 1:05}' for i in range(len(images))]
images = [b64_img(x) for x in images]

image_list = []
for name, image in zip(name_list, images):
image_list.append({
"data": image,
"name": name
})

payload = {
"resize_mode": resize_mode,
"show_extras_results": show_extras_results,
Expand All @@ -316,16 +327,16 @@ def extra_batch_images(self,
"upscale_first": upscale_first,
"imageList": image_list,
}

response = self.session.post(url=f'{self.baseurl}/extra-batch-images', json=payload)
return self._to_api_result(response)

# XXX 500 error (2022/12/26)
def png_info(self, image):
payload = {
"image": b64_img(image),
}

response = self.session.post(url=f'{self.baseurl}/png-info', json=payload)
return self._to_api_result(response)

Expand All @@ -334,20 +345,20 @@ def interrogate(self, image):
payload = {
"image": b64_img(image),
}

response = self.session.post(url=f'{self.baseurl}/interrogate', json=payload)
return self._to_api_result(response)

def get_options(self):
def get_options(self):
response = self.session.get(url=f'{self.baseurl}/options')
return response.json()

# working (2022/11/21)
def set_options(self, options):
def set_options(self, options):
response = self.session.post(url=f'{self.baseurl}/options', json=options)
return response.json()

def get_cmd_flags(self):
def get_cmd_flags(self):
response = self.session.get(url=f'{self.baseurl}/cmd-flags')
return response.json()
def get_samplers(self):
Expand Down Expand Up @@ -380,7 +391,7 @@ def get_artists(self):
def refresh_checkpoints(self):
response = self.session.post(url=f'{self.baseurl}/refresh-checkpoints')
return response.json()

def get_endpoint(self, endpoint, baseurl):
if baseurl:
return f'{self.baseurl}/{endpoint}'
Expand Down Expand Up @@ -431,7 +442,7 @@ def util_get_current_model(self):
return self.get_options()['sd_model_checkpoint']


class Upscaler(str, Enum):
class Upscaler(str, Enum):
none = 'None'
Lanczos = 'Lanczos'
Nearest = 'Nearest'
Expand Down
83 changes: 77 additions & 6 deletions webuiapi_demo.ipynb

Large diffs are not rendered by default.

0 comments on commit fe269dc

Please sign in to comment.