Skip to content

Commit

Permalink
Add a script to manage models
Browse files Browse the repository at this point in the history
  • Loading branch information
Unit1208 committed Oct 6, 2024
1 parent 51d6104 commit 31a2921
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,4 @@ cython_debug/
*.webp
*.png
*.jpg
tmp/
163 changes: 163 additions & 0 deletions scripts/modify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
from pathlib import Path
from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter
import json
import os
import hashlib
import requests
import hashlib
from tqdm import tqdm
import re

STYLE_OPTIONS = ["anime", "artistic", "furry", "generalist", "other", "realistic"]
ACTION_OPTIONS = ["add", "update", "remove"]
style_completer = WordCompleter(STYLE_OPTIONS, ignore_case=True)
baseline_completer = WordCompleter(
[
"stable diffusion 1",
"stable diffusion 2",
"stable_diffusion_xl",
"stable_cascade",
"flux_1",
],
ignore_case=True,
)
action_completer = WordCompleter(ACTION_OPTIONS, ignore_case=True)


def load_models(json_file):
if os.path.exists(json_file):
with open(json_file, "r") as f:
return json.load(f)
return {}


def save_models(json_file, models):
with open(json_file, "w") as f:
json.dump(models, f, indent=4)


def download_and_get_size(url):
response = requests.get(url, stream=True)
total_size = int(
response.headers.get("Content-Length", 0)
) # Total size of the file
content_disposition = response.headers.get("Content-Disposition")

if content_disposition:
match = re.findall('filename="(.+)"', content_disposition)
if match:
file_name = match[0]
else:
file_name = url.split("/")[-1]
else:
file_name = url.split("/")[-1]

file_size = 0
sha256 = hashlib.sha256()
file_path: Path = Path.cwd() / "tmp" / file_name
file_path.parent.mkdir(parents=True, exist_ok=True)

chunk_size = 8192
with open(file_path, "wb") as f, tqdm(
total=total_size, unit="B", unit_scale=True, desc=file_name
) as progress_bar:
for chunk in response.iter_content(chunk_size=chunk_size):
if chunk: # Filter out keep-alive chunks
file_size += len(chunk)
sha256.update(chunk)
f.write(chunk)
progress_bar.update(len(chunk))

return file_name, file_size, sha256.hexdigest().upper(), file_path


def get_model_info():
name = prompt("Model Name: ")
baseline = prompt("Baseline: ", completer=baseline_completer)
inpainting = prompt("Inpainting (t/f): ", default="false").lower()[0] == "t"
description = prompt("Description: ")
version = prompt("Version: ")
style = prompt("Style (realistic, artistic, etc.): ", completer=style_completer)
homepage = prompt("Homepage URL: ")
nsfw = prompt("NSFW (t/f): ").lower()[0] == "t"

url = prompt("Download URL: ")
file_name, size_on_disk, sha256, file_path = download_and_get_size(url)
if prompt("Delete downloaded model (t/f): ", default="t").lower()[0] == "t":
os.remove(file_path)
config = {
"files": [{"path": file_name, "sha256sum": sha256}],
"download": [{"file_name": file_name, "file_path": "", "file_url": url}],
}

return {
"name": name,
"baseline": baseline,
"type": "ckpt",
"inpainting": inpainting,
"description": description,
"version": version,
"style": style,
"homepage": homepage,
"nsfw": nsfw,
"download_all": False,
"config": config,
"size_on_disk_bytes": size_on_disk,
}


def add_model(json_file):
models = load_models(json_file)
model = get_model_info()

models[model["name"]] = model
save_models(json_file, models)
print(f"Model '{model['name']}' added successfully!")


def update_model(json_file):
models = load_models(json_file)
name = prompt("Model Name to update: ")

if name not in models:
print(f"Model '{name}' not found.")
return

model = get_model_info()
models[name] = model
save_models(json_file, models)
print(f"Model '{name}' updated successfully!")


def remove_model(json_file):
models = load_models(json_file)
name = prompt("Model Name to remove: ")

if name in models:
del models[name]
save_models(json_file, models)
print(f"Model '{name}' removed successfully!")
else:
print(f"Model '{name}' not found.")


def main():
json_file = "stable_diffusion.json"

action = prompt(
"Choose action (add/update/remove): ", completer=action_completer
).lower()[0]

if action == "a":
add_model(json_file)
elif action == "u":
update_model(json_file)
elif action == "r":
remove_model(json_file)
else:
print("Invalid action. Please choose from add, update, or remove.")


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions scripts/requirements.modify.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
prompt_toolkit~=3.0.48
tqdm~=4.66.5

0 comments on commit 31a2921

Please sign in to comment.