Skip to content

Commit

Permalink
Merge pull request #9 from cocktail-collective/add-path-setup-wizard
Browse files Browse the repository at this point in the history
implement path setup wizard
  • Loading branch information
rfletchr committed Nov 23, 2023
2 parents deec63c + cda0d18 commit c217115
Show file tree
Hide file tree
Showing 6 changed files with 220 additions and 23 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def run(self):

setup(
name="cocktail",
version="0.2.0",
version="0.3.0",
description="Cocktail",
package_dir={"": "src"},
packages=find_namespace_packages(where="src"),
Expand Down
1 change: 1 addition & 0 deletions src/cocktail/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def start():

startup_controller = StartupController()
startup_controller.complete.connect(start)
startup_controller.canceled.connect(app.quit)
startup_controller.start()

app.exec()
Expand Down
4 changes: 2 additions & 2 deletions src/cocktail/core/database/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,13 @@ def get_update_period(db):
return data_classes.Period.AllTime


def get_db_path():
def get_database_path():
dirname = platformdirs.user_cache_dir("cocktail", "cocktail")
return os.path.join(dirname, "cocktail.sqlite3")


def get_connection(filepath=None):
filepath = filepath or get_db_path()
filepath = filepath or get_database_path()
os.makedirs(os.path.dirname(filepath), exist_ok=True)

logger.info(f"Connecting to database at {filepath}")
Expand Down
31 changes: 30 additions & 1 deletion src/cocktail/ui/settings/controller.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
__all__ = ["SettingsController"]
__all__ = ["SettingsController", "detect_tool"]
import os
import logging
from typing import Any
from PySide6 import QtCore, QtWidgets, QtGui, QtSql
Expand Down Expand Up @@ -29,6 +30,34 @@
}


def detect_tool(directory):
"""
Detects the type of diffusion tool in the given directory.
"""
is_automatic = all(
[
os.path.isdir(os.path.join(directory, "models", "Stable-diffusion")),
os.path.isdir(os.path.join(directory, "extensions-builtin")),
os.path.isfile(os.path.join(directory, "webui.bat")),
]
)

if is_automatic:
return "Automatic-1111"

comfy_ui = all(
[
os.path.isdir(os.path.join(directory, "comfy")),
os.path.isdir(os.path.join(directory, "models", "checkpoints")),
]
)

if comfy_ui:
return "ComfyUI"

return None


def walk_namespaces(namespace: str) -> list[str]:
"""
Walks a namespace and returns a list of all the namespaces and keys.
Expand Down
61 changes: 43 additions & 18 deletions src/cocktail/ui/startup/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import zipfile
import platformdirs
from PySide6 import QtCore, QtNetwork
from cocktail.ui.startup.view import CocktailSplashScreen
from cocktail.ui.startup.view import CocktailSplashScreen, SetupWizard
from cocktail.core.database.api import get_database_path


def get_db_url(data):
Expand Down Expand Up @@ -87,13 +88,15 @@ class StartupController(QtCore.QObject):
"""

complete = QtCore.Signal()
canceled = QtCore.Signal()

api_url = "https://api.github.com/repos/cocktail-collective/cocktail/releases"

def __init__(self, parent=None):
super().__init__(parent)
self.view = CocktailSplashScreen()
self.cache_dir = platformdirs.user_cache_dir("cocktail", "cocktail")
self.splash = CocktailSplashScreen()
self.wizard = SetupWizard()
self.database_path = get_database_path()
self.network_manager = QtNetwork.QNetworkAccessManager()

self.get_releases_step = DownloadStep(self.network_manager)
Expand All @@ -104,26 +107,27 @@ def __init__(self, parent=None):
self.unzip_thread = QtCore.QThread()
self.unzip_db_step.moveToThread(self.unzip_thread)

self.get_releases_step.progress.connect(self.view.setProgress)
self.download_db_step.progress.connect(self.view.setProgress)
self.unzip_db_step.progress.connect(self.view.setProgress)
self.get_releases_step.progress.connect(self.splash.setProgress)
self.download_db_step.progress.connect(self.splash.setProgress)
self.unzip_db_step.progress.connect(self.splash.setProgress)

self.get_releases_step.complete.connect(self.onReleasesReady)
self.download_db_step.complete.connect(self.onZipDownloaded)
self.unzip_db_step.complete.connect(self.onCompleted)
self.unzip_db_step.complete.connect(self.onZipExtracted)
self.wizard.rejected.connect(self.onCanceled)
self.wizard.accepted.connect(self.onCompleted)

def start(self):
"""
begin the startup process.
"""
db_path = os.path.join(self.cache_dir, "cocktail.sqlite3")
if os.path.exists(db_path):
if os.path.exists(self.database_path):
self.onCompleted()
return

self.view.show()
self.view.setText("Getting database...")
self.view.setProgress(0, 0)
self.splash.show()
self.splash.setText("Getting database...")
self.splash.setProgress(0, 0)
self.get_releases_step.download(self.api_url)

def onReleasesReady(self, reply: QtNetwork.QNetworkReply):
Expand All @@ -138,26 +142,41 @@ def onReleasesReady(self, reply: QtNetwork.QNetworkReply):
if url is None:
return

self.view.setText("Downloading database...")
self.view.setProgress(0, 0)
self.splash.setText("Downloading database...")
self.splash.setProgress(0, 0)
self.download_db_step.download(url)

def onZipDownloaded(self, reply: QtNetwork.QNetworkReply):
"""
after downloading the database, we need to extract it.
"""
self.view.setText("Extracting database...")
self.view.setProgress(0, 0)
self.unzip_db_step.extract(reply, self.cache_dir)
self.splash.setText("Extracting database...")
self.splash.setProgress(0, 0)
self.unzip_db_step.extract(reply, os.path.dirname(self.database_path))

def onZipExtracted(self):
"""
after extracting the database, we need to show the setup wizard.
"""
self.splash.close()
self.wizard.show()

def onCompleted(self):
"""
cleanup and signal completion.
"""
self.unzip_thread.quit()
self.view.close()
self.splash.close()
self.complete.emit()

def onCanceled(self):
"""
cleanup and signal completion.
"""
self.unzip_thread.quit()
self.splash.close()
self.canceled.emit()


if __name__ == "__main__":
from PySide6 import QtWidgets
Expand All @@ -169,8 +188,14 @@ def onCompleted():
app.closeAllWindows()
app.quit()

def onCanceled():
print("canceled")
app.closeAllWindows()
app.quit()

controller = StartupController()
controller.complete.connect(onCompleted)
controller.canceled.connect(onCanceled)

controller.start()

Expand Down
144 changes: 143 additions & 1 deletion src/cocktail/ui/startup/view.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,150 @@
import os
import functools
from typing import Any
from PySide6 import QtCore, QtWidgets, QtGui, QtSql, QtNetwork
from cocktail import resources
import qtawesome as qta
from cocktail.ui.settings.controller import PRESETS, detect_tool

from cocktail.ui.settings.view import DirectoryPicker

SELECT_DIRECTORY_DESCRIPTION = """
Please select the location where your diffusion tool is installed.
<br><br>
e.g. the location where you put Automatic1111 or ComfyUI...
"""


class SelectToolDirectoryStep(QtWidgets.QWizardPage):
def __init__(self, parent=None):
super().__init__(parent)
self.setTitle("Where is your diffusion tool installed?")
self.description_box = QtWidgets.QTextEdit()
self.description_box.setReadOnly(True)
self.description_box.setHtml(SELECT_DIRECTORY_DESCRIPTION)

self.directory_edit = QtWidgets.QLineEdit()
self.browser_button = QtWidgets.QPushButton()
self.browser_button.setIcon(qta.icon("mdi.folder-open"))

edit_layout = QtWidgets.QHBoxLayout()
edit_layout.addWidget(self.directory_edit)
edit_layout.addWidget(self.browser_button)

layout = QtWidgets.QVBoxLayout(self)
layout.addWidget(self.description_box)
layout.addLayout(edit_layout)

self.registerField("paths/root", self.directory_edit)

self.browser_button.clicked.connect(self.onBrowseClicked)

def onBrowseClicked(self):
directory = QtWidgets.QFileDialog.getExistingDirectory(
self, "Select Diffusion Tool Directory"
)
if directory:
self.directory_edit.setText(directory)

def validatePage(self) -> bool:
return super().validatePage() and os.path.isdir(self.directory_edit.text())


class PathsTool(QtWidgets.QWizardPage):
def __init__(self, parent=None):
super().__init__(parent)
self.setTitle("Select Diffusion Tool")
self._label = QtWidgets.QLabel()

self.paths_layout = QtWidgets.QFormLayout()

main_layout = QtWidgets.QVBoxLayout(self)
main_layout.addWidget(self._label)
main_layout.addLayout(self.paths_layout)

self._path_keys = []

def initializePage(self) -> None:
directory = self.field("paths/root")
tool = detect_tool(directory)
preset = PRESETS.get(tool)

if tool:
self._label.setText(f"Detected {tool}")
else:
self._label.setText("Uknown tool, please select paths manually")

if preset:
for key, value in preset.items():
name = key.partition("/")[2]
self.addPath(key, name, value)
else:
preset = list(PRESETS.keys())[0]

for k, _ in PRESETS[preset].items():
name = k.partition("/")[2]
self.addPath(k, name, "")

def addPath(self, key, name, value):
self._path_keys.append(key)
layout = QtWidgets.QHBoxLayout()
label = QtWidgets.QLabel(name)
edit = QtWidgets.QLineEdit(value)
browse_button = QtWidgets.QPushButton()
browse_button.setIcon(qta.icon("mdi.folder-open"))

layout.addWidget(label)
layout.addWidget(edit)
layout.addWidget(browse_button)
self.paths_layout.addRow(layout)

callback = functools.partial(self.browse, edit)

self.registerField(key, edit)
browse_button.clicked.connect(callback)

def browse(self, editor):
root = self.field("paths/root")
directory = QtWidgets.QFileDialog.getExistingDirectory(
self, "Select Diffusion Tool Directory", dir=root
)
if directory:
if directory.startswith(root):
directory = os.path.relpath(directory, self.field("paths/root"))
editor.setText(directory)

def validatePage(self) -> bool:
all_filled = [self.field(k) for k in self._path_keys]
if not all(all_filled):
success = (
QtWidgets.QMessageBox.question(
self,
"Missing Paths",
"Some paths are missing, do you want to continue?",
QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No,
)
== QtWidgets.QMessageBox.Yes
)

else:
success = True

if success:
settings = QtCore.QSettings("cocktail", "cocktail")
settings.setValue("paths/root", self.field("paths/root"))
for k in self._path_keys:
settings.setValue(k, self.field(k))

return success


class SetupWizard(QtWidgets.QWizard):
def __init__(self, parent=None):
super().__init__(parent)
self.setWindowTitle("Diffusion Tool Setup")
self.setWizardStyle(QtWidgets.QWizard.ModernStyle)

self.addPage(SelectToolDirectoryStep(self))
self.addPage(PathsTool(self))


class CocktailSplashScreen(QtWidgets.QDialog):
Expand Down

0 comments on commit c217115

Please sign in to comment.