This repository has been archived by the owner on Jul 24, 2024. It is now read-only.
forked from d8ahazard/sd_dreambooth_extension
-
Notifications
You must be signed in to change notification settings - Fork 0
/
module_dreambooth.py
92 lines (77 loc) · 3.3 KB
/
module_dreambooth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import asyncio
import logging
import os
import shutil
from core.handlers.models import ModelHandler
from core.handlers.status import StatusHandler
from core.handlers.websocket import SocketHandler
from core.modules.base.module_base import BaseModule
from fastapi import FastAPI
import scripts.api
from dreambooth.dataclasses.db_config import DreamboothConfig
from dreambooth.sd_to_diff import extract_checkpoint
# Per-module logger; every handler and helper in this file logs through it.
logger = logging.getLogger(name=__name__)
class DreamboothModule(BaseModule):
    """Host-application module that plugs Dreambooth into the server.

    ``initialize`` hands the FastAPI app to ``scripts.api.dreambooth_api``
    (presumably registering the REST routes) and binds the
    ``train_dreambooth`` / ``create_dreambooth`` websocket events to the
    module-level coroutine handlers.
    """

    def __init__(self):
        module_dir = os.path.dirname(__file__)
        self.name: str = "Dreambooth"
        self.path = os.path.abspath(module_dir)
        # The handler is created before the base initializer runs; keep this order.
        self.model_handler = ModelHandler()
        super().__init__(self.name, self.path)

    def initialize(self, app: FastAPI, handler: SocketHandler):
        """Set up the HTTP API on *app* first, then the websocket events on *handler*."""
        self._initialize_api(app)
        self._initialize_websocket(handler)

    def _initialize_api(self, app: FastAPI):
        """Delegate to ``scripts.api.dreambooth_api`` and return whatever it returns."""
        api = scripts.api.dreambooth_api(None, app)
        return api

    def _initialize_websocket(self, handler: SocketHandler):
        """Register one coroutine callback per websocket event name, in a fixed order."""
        events = (
            ("train_dreambooth", _train_dreambooth),
            ("create_dreambooth", _create_model),
        )
        for event_name, callback in events:
            handler.register(event_name, callback)
async def _train_dreambooth(data):
logger.debug(f"Train dreambooth called: {data}")
await asyncio.sleep(1)
return {"status": "Training started."}
# Strong references to fire-and-forget extraction tasks. The event loop keeps
# only weak references to tasks, so an unreferenced task can be garbage
# collected before it finishes.
_BACKGROUND_TASKS = set()

# Request keys whose values must never reach the logs.
_SECRET_KEYS = frozenset({"new_model_token"})


def _redact_secrets(value):
    """Return a copy of *value* with secret dict entries masked (for logging only)."""
    if isinstance(value, dict):
        return {
            key: "***" if key in _SECRET_KEYS and item else _redact_secrets(item)
            for key, item in value.items()
        }
    return value


def _on_background_task_done(task):
    """Release the finished task's reference and log any exception it raised."""
    _BACKGROUND_TASKS.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("Background model extraction failed.", exc_info=exc)


async def _create_model(data):
    """Websocket handler for ``create_dreambooth``: create a new Dreambooth model.

    Args:
        data: Websocket message. ``id`` is required. ``user`` is optional. The
            ``data`` payload carries the model options (``new_model_name``,
            ``new_model_src.path``, ``create_from_hub``, ``512_model``, ...).

    Returns:
        An acknowledgement dict echoing the message ``id``.

    When a local source path is given and ``create_from_hub`` is falsy, the
    weights are copied synchronously via ``copy_model``. Otherwise
    ``extract_checkpoint`` is scheduled on the running loop. Its outcome is
    only logged, never reported back through this return value.
    """
    user_name = data.get("user")
    mh = ModelHandler(user_name=user_name)
    sh = StatusHandler(user_name=user_name)
    msg_id = data["id"]
    # The payload may contain a hub token; never log it in clear text.
    logger.debug(f"Full message: {_redact_secrets(data)}")
    data = data.get("data") or {}
    logger.debug(f"Create model called: {_redact_secrets(data)}")
    model_name = data.get("new_model_name")
    src = (data.get("new_model_src") or {}).get("path")
    shared_src = (data.get("new_model_shared_src") or {}).get("path")
    from_hub = data.get("create_from_hub", False)
    logger.debug(f"SRC - {src} and {from_hub}")
    if src and not from_hub:
        sh.start(1, "Copying source weights.")
        # NOTE(review): copy_model is synchronous and blocks the event loop
        # for the duration of the copy.
        try:
            copy_model(model_name, src, data["512_model"], mh)
        except Exception:
            # Close out the status so the client is not left on "Copying...".
            sh.end("Model creation failed.")
            raise
        sh.step()
        sh.end("Model created.")
    else:
        loop = asyncio.get_running_loop()
        # create_task needs a coroutine, so extract_checkpoint is presumably
        # async in this fork (confirm).
        task = loop.create_task(extract_checkpoint(
            model_name,
            src,
            shared_src,
            True,
            data["new_model_url"],
            data["new_model_token"],
            data["new_model_extract_ema"],
            data["train_unfrozen"],
            data["512_model"]
        ))
        # Keep the task alive until it finishes and surface its failures.
        _BACKGROUND_TASKS.add(task)
        task.add_done_callback(_on_background_task_done)
    return {"name": "create_model", "message": "Creating model.", "id": msg_id}
def _resolve_model_dir(dreambooth_models_path, model_name):
    """Return the normalized directory for *model_name* under *dreambooth_models_path*.

    The check is lexical (symlinks are not resolved). It rejects names that
    would point outside, or at, the dreambooth folder itself.

    Raises:
        ValueError: If ``model_name`` is empty, absolute, or uses ``..``/``.``
            segments so that it does not name a sub-directory of
            ``dreambooth_models_path``.
    """
    if not model_name:
        raise ValueError("A non-empty model name is required.")
    base = os.path.abspath(dreambooth_models_path)
    candidate = os.path.normpath(os.path.join(base, model_name))
    try:
        inside = os.path.commonpath([base, candidate]) == base
    except ValueError:
        # commonpath raises when the paths share no root (e.g. different Windows drives).
        inside = False
    if not inside or candidate == base:
        raise ValueError(
            f"Invalid model name {model_name!r}: it must name a directory inside {base!r}."
        )
    return candidate


def copy_model(model_name: str, src: str, is_512: bool, mh: ModelHandler):
    """Copy local weights into ``<models>/dreambooth/<model_name>/working`` and save a config.

    An existing ``working`` directory is removed first (best effort). If it
    cannot be removed, the copy is skipped with a warning and no config is
    written.

    Args:
        model_name: Name of the new model. It comes from a websocket request,
            so it is validated before any filesystem change.
        src: Directory whose contents are copied.
        is_512: Forwarded to ``DreamboothConfig`` as ``resolution``.
        mh: Model handler. The first entry of ``mh.models_path`` is the models root.

    Raises:
        ValueError: If ``model_name`` would resolve outside the dreambooth folder.
    """
    models_path = mh.models_path
    # Only the first configured models directory is used as the destination root.
    model_dir = models_path[0]
    dreambooth_models_path = os.path.join(model_dir, "dreambooth")
    # Validate before touching the disk: dest_dir is deleted below.
    _resolve_model_dir(dreambooth_models_path, model_name)
    logger.debug("Copying model!")
    logger.debug(f"Models paths: {models_path}")
    dest_dir = os.path.join(dreambooth_models_path, model_name, "working")
    if os.path.exists(dest_dir):
        shutil.rmtree(dest_dir, True)
    if os.path.exists(dest_dir):
        # Only reachable when the best-effort removal above failed.
        logger.warning(f"Could not remove existing directory '{dest_dir}'; skipping copy.")
        return
    shutil.copytree(src, dest_dir)
    # NOTE(review): a bool is passed as ``resolution``. If DreamboothConfig
    # expects a pixel size this is likely wrong (e.g. 512 vs 768). Confirm.
    cfg = DreamboothConfig(model_name=model_name, src=src, resolution=is_512, models_path=dreambooth_models_path)
    cfg.save()