Skip to content

Commit

Permalink
add utility methods. (util_*)
Browse files Browse the repository at this point in the history
  • Loading branch information
mix1009 committed Jan 13, 2023
1 parent 56aedc7 commit a55770f
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 1 deletion.
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,18 @@ api.get_artist_categories()
api.get_artists()
```

### Utility methods
```
# save current model name
old_model = api.util_get_current_model()
# get list of available models
models = api.util_get_model_names()
# set model (use exact name)
api.util_set_model(models[0])
# set model (find closest match)
api.util_set_model('robodiffusion')
```
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

setup(
name="webuiapi",
version="0.0.3",
version="0.1.0",
description="Python API client for AUTOMATIC1111/stable-diffusion-webui",
url="https://github.com/mix1009/sdwebuiapi",
author="ChunKoo Park",
Expand Down
30 changes: 30 additions & 0 deletions webuiapi/webuiapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,36 @@ def custom_post(self, endpoint, payload={}, baseurl=True):
response = self.session.post(url=url, json=payload)
return self._to_api_result(response)

def util_get_model_names(self):
return sorted([x['title'] for x in self.get_sd_models()])
def util_set_model(self, name, find_closest=True):
name = name.lower()
models = self.util_get_model_names()
found_model = None
if name in models:
found_model = name
elif find_closest:
import difflib
def str_simularity(a, b):
return difflib.SequenceMatcher(None, a, b).ratio()
max_sim = 0.0
max_model = models[0]
for model in models:
sim = str_simularity(name, model)
if sim >= max_sim:
max_sim = sim
max_model = model
found_model = max_model
if found_model:
print(f'loading {found_model}')
options = {}
options['sd_model_checkpoint'] = found_model
self.set_options(options)
print(f'model changed to {found_model}')
else:
print('model not found')
def util_get_current_model(self):
return self.get_options()['sd_model_checkpoint']


class Upscaler(str, Enum):
Expand Down

0 comments on commit a55770f

Please sign in to comment.