diff --git a/.github/workflows/test-ui.yaml b/.github/workflows/test-ui.yaml index 4b8b9793479..d947e9d5fc6 100644 --- a/.github/workflows/test-ui.yaml +++ b/.github/workflows/test-ui.yaml @@ -24,3 +24,7 @@ jobs: npm run test:generate npm test -- --verbose working-directory: ./tests-ui + - name: Run Unit Tests + run: | + pip install -r tests-unit/requirements.txt + python -m pytest tests-unit diff --git a/.gitignore b/.gitignore index a9beebe73a7..5092c98f463 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,5 @@ venv/ !/web/extensions/core/ /tests-ui/data/object_info.json /user/ -*.log \ No newline at end of file +*.log +web_custom_versions/ \ No newline at end of file diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/app/frontend_management.py b/app/frontend_management.py new file mode 100644 index 00000000000..0855f4160f0 --- /dev/null +++ b/app/frontend_management.py @@ -0,0 +1,187 @@ +import argparse +import logging +import os +import re +import tempfile +import zipfile +from dataclasses import dataclass +from functools import cached_property +from pathlib import Path +from typing import TypedDict + +import requests +from typing_extensions import NotRequired +from comfy.cli_args import DEFAULT_VERSION_STRING + + +REQUEST_TIMEOUT = 10 # seconds + + +class Asset(TypedDict): + url: str + + +class Release(TypedDict): + id: int + tag_name: str + name: str + prerelease: bool + created_at: str + published_at: str + body: str + assets: NotRequired[list[Asset]] + + +@dataclass +class FrontEndProvider: + owner: str + repo: str + + @property + def folder_name(self) -> str: + return f"{self.owner}_{self.repo}" + + @property + def release_url(self) -> str: + return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases" + + @cached_property + def all_releases(self) -> list[Release]: + releases = [] + api_url = self.release_url + while api_url: + response = requests.get(api_url, timeout=REQUEST_TIMEOUT) + response.raise_for_status() # Raises an HTTPError if the response was an error + releases.extend(response.json()) + # GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly. + if "next" in response.links: + api_url = response.links["next"]["url"] + else: + api_url = None + return releases + + @cached_property + def latest_release(self) -> Release: + latest_release_url = f"{self.release_url}/latest" + response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT) + response.raise_for_status() # Raises an HTTPError if the response was an error + return response.json() + + def get_release(self, version: str) -> Release: + if version == "latest": + return self.latest_release + else: + for release in self.all_releases: + if release["tag_name"] in [version, f"v{version}"]: + return release + raise ValueError(f"Version {version} not found in releases") + + +def download_release_asset_zip(release: Release, destination_path: str) -> None: + """Download dist.zip from github release.""" + asset_url = None + for asset in release.get("assets", []): + if asset["name"] == "dist.zip": + asset_url = asset["url"] + break + + if not asset_url: + raise ValueError("dist.zip not found in the release assets") + + # Use a temporary file to download the zip content + with tempfile.TemporaryFile() as tmp_file: + headers = {"Accept": "application/octet-stream"} + response = requests.get( + asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT + ) + response.raise_for_status() # Ensure we got a successful response + + # Write the content to the temporary file + tmp_file.write(response.content) + + # Go back to the beginning of the temporary file + tmp_file.seek(0) + + # Extract the zip file content to the destination path + with zipfile.ZipFile(tmp_file, "r") as zip_ref: + zip_ref.extractall(destination_path) + + +class FrontendManager: + DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web") + CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions") + + @classmethod + def parse_version_string(cls, value: str) -> tuple[str, str, str]: + """ + Args: + value (str): The version string to parse. + + Returns: + tuple[str, str]: A tuple containing provider name and version. + + Raises: + argparse.ArgumentTypeError: If the version string is invalid. + """ + VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(\d+\.\d+\.\d+|latest)$" + match_result = re.match(VERSION_PATTERN, value) + if match_result is None: + raise argparse.ArgumentTypeError(f"Invalid version string: {value}") + + return match_result.group(1), match_result.group(2), match_result.group(3) + + @classmethod + def init_frontend_unsafe(cls, version_string: str) -> str: + """ + Initializes the frontend for the specified version. + + Args: + version_string (str): The version string. + + Returns: + str: The path to the initialized frontend. + + Raises: + Exception: If there is an error during the initialization process. + main error source might be request timeout or invalid URL. + """ + if version_string == DEFAULT_VERSION_STRING: + return cls.DEFAULT_FRONTEND_PATH + + repo_owner, repo_name, version = cls.parse_version_string(version_string) + provider = FrontEndProvider(repo_owner, repo_name) + release = provider.get_release(version) + + semantic_version = release["tag_name"].lstrip("v") + web_root = str( + Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version + ) + if not os.path.exists(web_root): + os.makedirs(web_root, exist_ok=True) + logging.info( + "Downloading frontend(%s) version(%s) to (%s)", + provider.folder_name, + semantic_version, + web_root, + ) + logging.debug(release) + download_release_asset_zip(release, destination_path=web_root) + return web_root + + @classmethod + def init_frontend(cls, version_string: str) -> str: + """ + Initializes the frontend with the specified version string. + + Args: + version_string (str): The version string to initialize the frontend with. + + Returns: + str: The path of the initialized frontend. + """ + try: + return cls.init_frontend_unsafe(version_string) + except Exception as e: + logging.error("Failed to initialize frontend: %s", e) + logging.info("Falling back to the default frontend.") + return cls.DEFAULT_FRONTEND_PATH diff --git a/comfy/cli_args.py b/comfy/cli_args.py index b72bf3998ae..6251e3d28fc 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -1,7 +1,10 @@ import argparse import enum +import os +from typing import Optional import comfy.options + class EnumAction(argparse.Action): """ Argparse action for handling Enums @@ -124,6 +127,38 @@ class LatentPreviewMethod(enum.Enum): parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.") +# The default built-in provider hosted under web/ +DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest" + +parser.add_argument( + "--front-end-version", + type=str, + default=DEFAULT_VERSION_STRING, + help=""" + Specifies the version of the frontend to be used. This command needs internet connectivity to query and + download available frontend implementations from GitHub releases. + + The version string should be in the format of: + [repoOwner]/[repoName]@[version] + where version is one of: "latest" or a valid version number (e.g. "1.0.0") + """, +) + +def is_valid_directory(path: Optional[str]) -> Optional[str]: + """Validate if the given path is a directory.""" + if path is None: + return None + + if not os.path.isdir(path): + raise argparse.ArgumentTypeError(f"{path} is not a valid directory.") + return path + +parser.add_argument( + "--front-end-root", + type=is_valid_directory, + default=None, + help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.", +) if comfy.options.args_parsing: args = parser.parse_args() diff --git a/pytest.ini b/pytest.ini index b5a68e0f12f..8b7a747e760 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,5 +1,8 @@ [pytest] markers = inference: mark as inference test (deselect with '-m "not inference"') -testpaths = tests -addopts = -s \ No newline at end of file +testpaths = + tests + tests-unit +addopts = -s +pythonpath = . diff --git a/server.py b/server.py index ce7c7532dab..03c4c421906 100644 --- a/server.py +++ b/server.py @@ -25,9 +25,10 @@ from comfy.cli_args import args import comfy.utils import comfy.model_management - +from app.frontend_management import FrontendManager from app.user_manager import UserManager + class BinaryEventTypes: PREVIEW_IMAGE = 1 UNENCODED_PREVIEW_IMAGE = 2 @@ -83,8 +84,12 @@ def __init__(self, loop): max_upload_size = round(args.max_upload_size * 1024 * 1024) self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares) self.sockets = dict() - self.web_root = os.path.join(os.path.dirname( - os.path.realpath(__file__)), "web") + self.web_root = ( + FrontendManager.init_frontend(args.front_end_version) + if args.front_end_root is None + else args.front_end_root + ) + logging.info(f"[Prompt Server] web root: {self.web_root}") routes = web.RouteTableDef() self.routes = routes self.last_node_id = None diff --git a/tests-unit/README.md b/tests-unit/README.md new file mode 100644 index 00000000000..94abd985346 --- /dev/null +++ b/tests-unit/README.md @@ -0,0 +1,8 @@ +# Pytest Unit Tests + +## Install test dependencies + +`pip install -r tests-units/requirements.txt` + +## Run tests +`pytest tests-units/` diff --git a/tests-unit/app_test/__init__.py b/tests-unit/app_test/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests-unit/app_test/frontend_manager_test.py b/tests-unit/app_test/frontend_manager_test.py new file mode 100644 index 00000000000..637869cfbf5 --- /dev/null +++ b/tests-unit/app_test/frontend_manager_test.py @@ -0,0 +1,100 @@ +import argparse +import pytest +from requests.exceptions import HTTPError + +from app.frontend_management import ( + FrontendManager, + FrontEndProvider, + Release, +) +from comfy.cli_args import DEFAULT_VERSION_STRING + + +@pytest.fixture +def mock_releases(): + return [ + Release( + id=1, + tag_name="1.0.0", + name="Release 1.0.0", + prerelease=False, + created_at="2022-01-01T00:00:00Z", + published_at="2022-01-01T00:00:00Z", + body="Release notes for 1.0.0", + assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}], + ), + Release( + id=2, + tag_name="2.0.0", + name="Release 2.0.0", + prerelease=False, + created_at="2022-02-01T00:00:00Z", + published_at="2022-02-01T00:00:00Z", + body="Release notes for 2.0.0", + assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}], + ), + ] + + +@pytest.fixture +def mock_provider(mock_releases): + provider = FrontEndProvider( + owner="test-owner", + repo="test-repo", + ) + provider.all_releases = mock_releases + provider.latest_release = mock_releases[1] + FrontendManager.PROVIDERS = [provider] + return provider + + +def test_get_release(mock_provider, mock_releases): + version = "1.0.0" + release = mock_provider.get_release(version) + assert release == mock_releases[0] + + +def test_get_release_latest(mock_provider, mock_releases): + version = "latest" + release = mock_provider.get_release(version) + assert release == mock_releases[1] + + +def test_get_release_invalid_version(mock_provider): + version = "invalid" + with pytest.raises(ValueError): + mock_provider.get_release(version) + + +def test_init_frontend_default(): + version_string = DEFAULT_VERSION_STRING + frontend_path = FrontendManager.init_frontend(version_string) + assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH + + +def test_init_frontend_invalid_version(): + version_string = "test-owner/test-repo@1.100.99" + with pytest.raises(HTTPError): + FrontendManager.init_frontend_unsafe(version_string) + + +def test_init_frontend_invalid_provider(): + version_string = "invalid/invalid@latest" + with pytest.raises(HTTPError): + FrontendManager.init_frontend_unsafe(version_string) + + +def test_parse_version_string(): + version_string = "owner/repo@1.0.0" + repo_owner, repo_name, version = FrontendManager.parse_version_string( + version_string + ) + assert repo_owner == "owner" + assert repo_name == "repo" + assert version == "1.0.0" + + +def test_parse_version_string_invalid(): + version_string = "invalid" + with pytest.raises(argparse.ArgumentTypeError): + FrontendManager.parse_version_string(version_string) diff --git a/tests-unit/requirements.txt b/tests-unit/requirements.txt new file mode 100644 index 00000000000..0587502f877 --- /dev/null +++ b/tests-unit/requirements.txt @@ -0,0 +1 @@ +pytest>=7.8.0