From 068adf95a0f2d19dcca2c995bd67c53b468230f2 Mon Sep 17 00:00:00 2001 From: rfletchr Date: Thu, 23 Nov 2023 01:48:15 +0000 Subject: [PATCH 1/3] implement path setup wizard --- src/cocktail/__main__.py | 1 + src/cocktail/core/database/api.py | 4 +- src/cocktail/ui/settings/controller.py | 31 ++++- src/cocktail/ui/startup/controller.py | 61 +++++++--- src/cocktail/ui/startup/view.py | 157 ++++++++++++++++++++++++- 5 files changed, 232 insertions(+), 22 deletions(-) diff --git a/src/cocktail/__main__.py b/src/cocktail/__main__.py index 1c750a0..8f8a09f 100644 --- a/src/cocktail/__main__.py +++ b/src/cocktail/__main__.py @@ -64,6 +64,7 @@ def start(): startup_controller = StartupController() startup_controller.complete.connect(start) + startup_controller.canceled.connect(app.quit) startup_controller.start() app.exec() diff --git a/src/cocktail/core/database/api.py b/src/cocktail/core/database/api.py index 438bf03..8de07a7 100644 --- a/src/cocktail/core/database/api.py +++ b/src/cocktail/core/database/api.py @@ -129,13 +129,13 @@ def get_update_period(db): return data_classes.Period.AllTime -def get_db_path(): +def get_database_path(): dirname = platformdirs.user_cache_dir("cocktail", "cocktail") return os.path.join(dirname, "cocktail.sqlite3") def get_connection(filepath=None): - filepath = filepath or get_db_path() + filepath = filepath or get_database_path() os.makedirs(os.path.dirname(filepath), exist_ok=True) logger.info(f"Connecting to database at {filepath}") diff --git a/src/cocktail/ui/settings/controller.py b/src/cocktail/ui/settings/controller.py index c22c0ce..9993262 100644 --- a/src/cocktail/ui/settings/controller.py +++ b/src/cocktail/ui/settings/controller.py @@ -1,4 +1,5 @@ -__all__ = ["SettingsController"] +__all__ = ["SettingsController", "detect_tool"] +import os import logging from typing import Any from PySide6 import QtCore, QtWidgets, QtGui, QtSql @@ -29,6 +30,34 @@ } +def detect_tool(directory): + """ + Detects the type of diffusion tool in the given directory. + """ + is_automatic = all( + [ + os.path.isdir(os.path.join(directory, "models", "Stable-diffusion")), + os.path.isdir(os.path.join(directory, "extensions-builtin")), + os.path.isfile(os.path.join(directory, "webui.bat")), + ] + ) + + if is_automatic: + return "Automatic-1111" + + comfy_ui = all( + [ + os.path.isdir(os.path.join(directory, "comfy")), + os.path.isdir(os.path.join(directory, "models", "checkpoints")), + ] + ) + + if comfy_ui: + return "ComfyUI" + + return None + + def walk_namespaces(namespace: str) -> list[str]: """ Walks a namespace and returns a list of all the namespaces and keys. diff --git a/src/cocktail/ui/startup/controller.py b/src/cocktail/ui/startup/controller.py index 6dfa15f..33848e1 100644 --- a/src/cocktail/ui/startup/controller.py +++ b/src/cocktail/ui/startup/controller.py @@ -4,7 +4,8 @@ import zipfile import platformdirs from PySide6 import QtCore, QtNetwork -from cocktail.ui.startup.view import CocktailSplashScreen +from cocktail.ui.startup.view import CocktailSplashScreen, SetupWizard +from cocktail.core.database.api import get_database_path def get_db_url(data): @@ -87,13 +88,15 @@ class StartupController(QtCore.QObject): """ complete = QtCore.Signal() + canceled = QtCore.Signal() api_url = "https://api.github.com/repos/cocktail-collective/cocktail/releases" def __init__(self, parent=None): super().__init__(parent) - self.view = CocktailSplashScreen() - self.cache_dir = platformdirs.user_cache_dir("cocktail", "cocktail") + self.splash = CocktailSplashScreen() + self.wizard = SetupWizard() + self.database_path = get_database_path() self.network_manager = QtNetwork.QNetworkAccessManager() self.get_releases_step = DownloadStep(self.network_manager) @@ -104,26 +107,27 @@ def __init__(self, parent=None): self.unzip_thread = QtCore.QThread() self.unzip_db_step.moveToThread(self.unzip_thread) - self.get_releases_step.progress.connect(self.view.setProgress) - self.download_db_step.progress.connect(self.view.setProgress) - self.unzip_db_step.progress.connect(self.view.setProgress) + self.get_releases_step.progress.connect(self.splash.setProgress) + self.download_db_step.progress.connect(self.splash.setProgress) + self.unzip_db_step.progress.connect(self.splash.setProgress) self.get_releases_step.complete.connect(self.onReleasesReady) self.download_db_step.complete.connect(self.onZipDownloaded) - self.unzip_db_step.complete.connect(self.onCompleted) + self.unzip_db_step.complete.connect(self.onZipExtracted) + self.wizard.rejected.connect(self.onCanceled) + self.wizard.accepted.connect(self.onCompleted) def start(self): """ begin the startup process. """ - db_path = os.path.join(self.cache_dir, "cocktail.sqlite3") - if os.path.exists(db_path): + if os.path.exists(self.database_path): self.onCompleted() return - self.view.show() - self.view.setText("Getting database...") - self.view.setProgress(0, 0) + self.splash.show() + self.splash.setText("Getting database...") + self.splash.setProgress(0, 0) self.get_releases_step.download(self.api_url) def onReleasesReady(self, reply: QtNetwork.QNetworkReply): @@ -138,26 +142,41 @@ def onReleasesReady(self, reply: QtNetwork.QNetworkReply): if url is None: return - self.view.setText("Downloading database...") - self.view.setProgress(0, 0) + self.splash.setText("Downloading database...") + self.splash.setProgress(0, 0) self.download_db_step.download(url) def onZipDownloaded(self, reply: QtNetwork.QNetworkReply): """ after downloading the database, we need to extract it. """ - self.view.setText("Extracting database...") - self.view.setProgress(0, 0) - self.unzip_db_step.extract(reply, self.cache_dir) + self.splash.setText("Extracting database...") + self.splash.setProgress(0, 0) + self.unzip_db_step.extract(reply, os.path.dirname(self.database_path)) + + def onZipExtracted(self): + """ + after extracting the database, we need to show the setup wizard. + """ + self.splash.close() + self.wizard.show() def onCompleted(self): """ cleanup and signal completion. """ self.unzip_thread.quit() - self.view.close() + self.splash.close() self.complete.emit() + def onCanceled(self): + """ + cleanup and signal completion. + """ + self.unzip_thread.quit() + self.splash.close() + self.canceled.emit() + if __name__ == "__main__": from PySide6 import QtWidgets @@ -169,8 +188,14 @@ def onCompleted(): app.closeAllWindows() app.quit() + def onCanceled(): + print("canceled") + app.closeAllWindows() + app.quit() + controller = StartupController() controller.complete.connect(onCompleted) + controller.canceled.connect(onCanceled) controller.start() diff --git a/src/cocktail/ui/startup/view.py b/src/cocktail/ui/startup/view.py index 4c404ff..f8d1f01 100644 --- a/src/cocktail/ui/startup/view.py +++ b/src/cocktail/ui/startup/view.py @@ -1,8 +1,163 @@ import os +import functools +from typing import Any from PySide6 import QtCore, QtWidgets, QtGui, QtSql, QtNetwork from cocktail import resources +import qtawesome as qta +from cocktail.ui.settings.controller import PRESETS, detect_tool -from cocktail.ui.settings.view import DirectoryPicker + +SELECT_DIRECTORY_DESCRIPTION = """ +Please select the location where your diffusion tool is installed. +

+e.g. the location where you put Automatic1111 or ComfyUI... +""" + + +class PageBase(QtWidgets.QWizardPage): + def __init__(self, parent=None): + super().__init__(parent) + self._fields = [] + + def registerField(self, name, widget, property="text"): + super().registerField(name, widget, property) + self._fields.append(name) + + def fields(self): + return {k: self.field(k) for k in self._fields} + + +class SelectToolDirectoryStep(PageBase): + def __init__(self, parent=None): + super().__init__(parent) + self.setTitle("Where is your diffusion tool installed?") + self.description_box = QtWidgets.QTextEdit() + self.description_box.setReadOnly(True) + self.description_box.setHtml(SELECT_DIRECTORY_DESCRIPTION) + + self.directory_edit = QtWidgets.QLineEdit() + self.browser_button = QtWidgets.QPushButton() + self.browser_button.setIcon(qta.icon("mdi.folder-open")) + + edit_layout = QtWidgets.QHBoxLayout() + edit_layout.addWidget(self.directory_edit) + edit_layout.addWidget(self.browser_button) + + layout = QtWidgets.QVBoxLayout(self) + layout.addWidget(self.description_box) + layout.addLayout(edit_layout) + + self.registerField("paths/root", self.directory_edit) + + self.browser_button.clicked.connect(self.onBrowseClicked) + + def onBrowseClicked(self): + directory = QtWidgets.QFileDialog.getExistingDirectory( + self, "Select Diffusion Tool Directory" + ) + if directory: + self.directory_edit.setText(directory) + + def validatePage(self) -> bool: + return super().validatePage() and os.path.isdir(self.directory_edit.text()) + + +class PathsTool(PageBase): + def __init__(self, parent=None): + super().__init__(parent) + self.setTitle("Select Diffusion Tool") + self._label = QtWidgets.QLabel() + + self.paths_layout = QtWidgets.QFormLayout() + + main_layout = QtWidgets.QVBoxLayout(self) + main_layout.addWidget(self._label) + main_layout.addLayout(self.paths_layout) + + self._path_keys = [] + + def initializePage(self) -> None: + directory = self.field("paths/root") + tool = detect_tool(directory) + preset = PRESETS.get(tool) + + if tool: + self._label.setText(f"Detected {tool}") + else: + self._label.setText("Uknown tool, please select paths manually") + + if preset: + for key, value in preset.items(): + name = key.partition("/")[2] + self.addPath(key, name, value) + else: + preset = list(PRESETS.keys())[0] + + for k, _ in PRESETS[preset].items(): + name = k.partition("/")[2] + self.addPath(k, name, "") + + def addPath(self, key, name, value): + self._path_keys.append(key) + layout = QtWidgets.QHBoxLayout() + label = QtWidgets.QLabel(name) + edit = QtWidgets.QLineEdit(value) + browse_button = QtWidgets.QPushButton() + browse_button.setIcon(qta.icon("mdi.folder-open")) + + layout.addWidget(label) + layout.addWidget(edit) + layout.addWidget(browse_button) + self.paths_layout.addRow(layout) + + callback = functools.partial(self.browse, edit) + + self.registerField(key, edit) + browse_button.clicked.connect(callback) + + def browse(self, editor): + root = self.field("paths/root") + directory = QtWidgets.QFileDialog.getExistingDirectory( + self, "Select Diffusion Tool Directory", dir=root + ) + if directory: + if directory.startswith(root): + directory = os.path.relpath(directory, self.field("paths/root")) + editor.setText(directory) + + def validatePage(self) -> bool: + all_filled = [self.field(k) for k in self._path_keys] + if not all(all_filled): + success = ( + QtWidgets.QMessageBox.question( + self, + "Missing Paths", + "Some paths are missing, do you want to continue?", + QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No, + ) + == QtWidgets.QMessageBox.Yes + ) + + else: + success = True + + if success: + settings = QtCore.QSettings("cocktail", "cocktail") + settings.setValue("paths/root", self.field("paths/root")) + for k in self._path_keys: + settings.setValue(k, self.field(k)) + + return success + + +class SetupWizard(QtWidgets.QWizard): + def __init__(self, parent=None): + super().__init__(parent) + self.setWindowTitle("Diffusion Tool Setup") + self.setWizardStyle(QtWidgets.QWizard.ModernStyle) + + self.addPage(SelectToolDirectoryStep(self)) + self.addPage(PathsTool(self)) class CocktailSplashScreen(QtWidgets.QDialog): From ffa05f5db63d27ab16e6e3e03bc68eede0da8e9f Mon Sep 17 00:00:00 2001 From: rfletchr Date: Thu, 23 Nov 2023 01:52:13 +0000 Subject: [PATCH 2/3] remove un-used base class --- src/cocktail/ui/startup/view.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/src/cocktail/ui/startup/view.py b/src/cocktail/ui/startup/view.py index f8d1f01..3478b6c 100644 --- a/src/cocktail/ui/startup/view.py +++ b/src/cocktail/ui/startup/view.py @@ -14,20 +14,7 @@ """ -class PageBase(QtWidgets.QWizardPage): - def __init__(self, parent=None): - super().__init__(parent) - self._fields = [] - - def registerField(self, name, widget, property="text"): - super().registerField(name, widget, property) - self._fields.append(name) - - def fields(self): - return {k: self.field(k) for k in self._fields} - - -class SelectToolDirectoryStep(PageBase): +class SelectToolDirectoryStep(QtWidgets.QWizardPage): def __init__(self, parent=None): super().__init__(parent) self.setTitle("Where is your diffusion tool installed?") @@ -62,7 +49,7 @@ def validatePage(self) -> bool: return super().validatePage() and os.path.isdir(self.directory_edit.text()) -class PathsTool(PageBase): +class PathsTool(QtWidgets.QWizardPage): def __init__(self, parent=None): super().__init__(parent) self.setTitle("Select Diffusion Tool") From cda0d1806ef26f2995c6738c935bcd2a62ae62ec Mon Sep 17 00:00:00 2001 From: rfletchr Date: Thu, 23 Nov 2023 01:53:17 +0000 Subject: [PATCH 3/3] version bump --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 6969529..4a54441 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ def run(self): setup( name="cocktail", - version="0.2.0", + version="0.3.0", description="Cocktail", package_dir={"": "src"}, packages=find_namespace_packages(where="src"),