Skip to content

Commit

Permalink
Add FrontendManager to manage non-default front-end impl (comfyanon…
Browse files Browse the repository at this point in the history
…ymous#3897)

* Add frontend manager

* Add tests

* nit

* Add unit test to github CI

* Fix path

* nit

* ignore

* Add logging

* Install test deps

* Remove 'stable' keyword support

* Update test

* Add web-root arg

* Rename web-root to front-end-root

* Add test on non-exist version number

* Use repo owner/name to replace hard coded provider list

* Inline cmd args

* nit

* Fix unit test
  • Loading branch information
huchenlei authored Jul 16, 2024
1 parent 33346fd commit 99458e8
Show file tree
Hide file tree
Showing 11 changed files with 350 additions and 6 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/test-ui.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@ jobs:
npm run test:generate
npm test -- --verbose
working-directory: ./tests-ui
- name: Run Unit Tests
run: |
pip install -r tests-unit/requirements.txt
python -m pytest tests-unit
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ venv/
!/web/extensions/core/
/tests-ui/data/object_info.json
/user/
*.log
*.log
web_custom_versions/
Empty file added app/__init__.py
Empty file.
187 changes: 187 additions & 0 deletions app/frontend_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import argparse
import logging
import os
import re
import tempfile
import zipfile
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TypedDict

import requests
from typing_extensions import NotRequired
from comfy.cli_args import DEFAULT_VERSION_STRING


REQUEST_TIMEOUT = 10 # seconds


class Asset(TypedDict):
url: str


class Release(TypedDict):
id: int
tag_name: str
name: str
prerelease: bool
created_at: str
published_at: str
body: str
assets: NotRequired[list[Asset]]


@dataclass
class FrontEndProvider:
owner: str
repo: str

@property
def folder_name(self) -> str:
return f"{self.owner}_{self.repo}"

@property
def release_url(self) -> str:
return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases"

@cached_property
def all_releases(self) -> list[Release]:
releases = []
api_url = self.release_url
while api_url:
response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
response.raise_for_status() # Raises an HTTPError if the response was an error
releases.extend(response.json())
# GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly.
if "next" in response.links:
api_url = response.links["next"]["url"]
else:
api_url = None
return releases

@cached_property
def latest_release(self) -> Release:
latest_release_url = f"{self.release_url}/latest"
response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT)
response.raise_for_status() # Raises an HTTPError if the response was an error
return response.json()

def get_release(self, version: str) -> Release:
if version == "latest":
return self.latest_release
else:
for release in self.all_releases:
if release["tag_name"] in [version, f"v{version}"]:
return release
raise ValueError(f"Version {version} not found in releases")


def download_release_asset_zip(release: Release, destination_path: str) -> None:
"""Download dist.zip from github release."""
asset_url = None
for asset in release.get("assets", []):
if asset["name"] == "dist.zip":
asset_url = asset["url"]
break

if not asset_url:
raise ValueError("dist.zip not found in the release assets")

# Use a temporary file to download the zip content
with tempfile.TemporaryFile() as tmp_file:
headers = {"Accept": "application/octet-stream"}
response = requests.get(
asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT
)
response.raise_for_status() # Ensure we got a successful response

# Write the content to the temporary file
tmp_file.write(response.content)

# Go back to the beginning of the temporary file
tmp_file.seek(0)

# Extract the zip file content to the destination path
with zipfile.ZipFile(tmp_file, "r") as zip_ref:
zip_ref.extractall(destination_path)


class FrontendManager:
DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web")
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")

@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""
Args:
value (str): The version string to parse.
Returns:
tuple[str, str]: A tuple containing provider name and version.
Raises:
argparse.ArgumentTypeError: If the version string is invalid.
"""
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(\d+\.\d+\.\d+|latest)$"
match_result = re.match(VERSION_PATTERN, value)
if match_result is None:
raise argparse.ArgumentTypeError(f"Invalid version string: {value}")

return match_result.group(1), match_result.group(2), match_result.group(3)

@classmethod
def init_frontend_unsafe(cls, version_string: str) -> str:
"""
Initializes the frontend for the specified version.
Args:
version_string (str): The version string.
Returns:
str: The path to the initialized frontend.
Raises:
Exception: If there is an error during the initialization process.
main error source might be request timeout or invalid URL.
"""
if version_string == DEFAULT_VERSION_STRING:
return cls.DEFAULT_FRONTEND_PATH

repo_owner, repo_name, version = cls.parse_version_string(version_string)
provider = FrontEndProvider(repo_owner, repo_name)
release = provider.get_release(version)

semantic_version = release["tag_name"].lstrip("v")
web_root = str(
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
)
if not os.path.exists(web_root):
os.makedirs(web_root, exist_ok=True)
logging.info(
"Downloading frontend(%s) version(%s) to (%s)",
provider.folder_name,
semantic_version,
web_root,
)
logging.debug(release)
download_release_asset_zip(release, destination_path=web_root)
return web_root

@classmethod
def init_frontend(cls, version_string: str) -> str:
"""
Initializes the frontend with the specified version string.
Args:
version_string (str): The version string to initialize the frontend with.
Returns:
str: The path of the initialized frontend.
"""
try:
return cls.init_frontend_unsafe(version_string)
except Exception as e:
logging.error("Failed to initialize frontend: %s", e)
logging.info("Falling back to the default frontend.")
return cls.DEFAULT_FRONTEND_PATH
35 changes: 35 additions & 0 deletions comfy/cli_args.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import argparse
import enum
import os
from typing import Optional
import comfy.options


class EnumAction(argparse.Action):
"""
Argparse action for handling Enums
Expand Down Expand Up @@ -124,6 +127,38 @@ class LatentPreviewMethod(enum.Enum):

parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")

# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"

parser.add_argument(
"--front-end-version",
type=str,
default=DEFAULT_VERSION_STRING,
help="""
Specifies the version of the frontend to be used. This command needs internet connectivity to query and
download available frontend implementations from GitHub releases.
The version string should be in the format of:
[repoOwner]/[repoName]@[version]
where version is one of: "latest" or a valid version number (e.g. "1.0.0")
""",
)

def is_valid_directory(path: Optional[str]) -> Optional[str]:
"""Validate if the given path is a directory."""
if path is None:
return None

if not os.path.isdir(path):
raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
return path

parser.add_argument(
"--front-end-root",
type=is_valid_directory,
default=None,
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
)

if comfy.options.args_parsing:
args = parser.parse_args()
Expand Down
7 changes: 5 additions & 2 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
[pytest]
markers =
inference: mark as inference test (deselect with '-m "not inference"')
testpaths = tests
addopts = -s
testpaths =
tests
tests-unit
addopts = -s
pythonpath = .
11 changes: 8 additions & 3 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
from comfy.cli_args import args
import comfy.utils
import comfy.model_management

from app.frontend_management import FrontendManager
from app.user_manager import UserManager


class BinaryEventTypes:
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
Expand Down Expand Up @@ -83,8 +84,12 @@ def __init__(self, loop):
max_upload_size = round(args.max_upload_size * 1024 * 1024)
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
self.sockets = dict()
self.web_root = os.path.join(os.path.dirname(
os.path.realpath(__file__)), "web")
self.web_root = (
FrontendManager.init_frontend(args.front_end_version)
if args.front_end_root is None
else args.front_end_root
)
logging.info(f"[Prompt Server] web root: {self.web_root}")
routes = web.RouteTableDef()
self.routes = routes
self.last_node_id = None
Expand Down
8 changes: 8 additions & 0 deletions tests-unit/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Pytest Unit Tests

## Install test dependencies

`pip install -r tests-units/requirements.txt`

## Run tests
`pytest tests-units/`
Empty file added tests-unit/app_test/__init__.py
Empty file.
Loading

0 comments on commit 99458e8

Please sign in to comment.