diff --git a/questionpy_sdk/commands/run.py b/questionpy_sdk/commands/run.py index 84d7a3b..e169700 100644 --- a/questionpy_sdk/commands/run.py +++ b/questionpy_sdk/commands/run.py @@ -2,17 +2,36 @@ # The QuestionPy SDK is free software released under terms of the MIT license. See LICENSE.md. # (c) Technische Universität Berlin, innoCampus +import asyncio from pathlib import Path +from typing import TYPE_CHECKING import click from questionpy_sdk.commands._helper import get_package_location +from questionpy_sdk.watcher import Watcher from questionpy_sdk.webserver.app import WebServer +from questionpy_server.worker.runtime.package_location import DirPackageLocation + +if TYPE_CHECKING: + from collections.abc import Coroutine + + +async def run_watcher(pkg_path: Path, pkg_location: DirPackageLocation, host: str, port: int) -> None: + async with Watcher(pkg_path, pkg_location, host, port) as watcher: + await watcher.run_forever() @click.command() @click.argument("package") -def run(package: str) -> None: +@click.option( + "--host", "-h", "host", default="localhost", show_default=True, type=click.STRING, help="Host to listen on." +) +@click.option( + "--port", "-p", "port", default=8080, show_default=True, type=click.IntRange(1024, 65535), help="Port to bind to." +) +@click.option("--watch", "-w", "watch", is_flag=True, help="Watch source directory and rebuild on changes.") +def run(package: str, host: str, port: int, *, watch: bool) -> None: """Run a package. \b @@ -22,5 +41,15 @@ def run(package: str) -> None: - a source directory (built on-the-fly). """ # noqa: D301 pkg_path = Path(package).resolve() - web_server = WebServer(get_package_location(package, pkg_path)) - web_server.start_server() + pkg_location = get_package_location(package, pkg_path) + coro: Coroutine + + if watch: + if not isinstance(pkg_location, DirPackageLocation) or pkg_path == pkg_location.path: + msg = "The --watch option only works with source directories." + raise click.BadParameter(msg) + coro = run_watcher(pkg_path, pkg_location, host=host, port=port) + else: + coro = WebServer(pkg_location, host=host, port=port).run_forever() + + asyncio.run(coro) diff --git a/questionpy_sdk/package/builder.py b/questionpy_sdk/package/builder.py index 74ebae4..1e90be8 100644 --- a/questionpy_sdk/package/builder.py +++ b/questionpy_sdk/package/builder.py @@ -23,7 +23,7 @@ from questionpy_sdk.package.errors import PackageBuildError from questionpy_sdk.package.source import PackageSource -log = logging.getLogger(__name__) +log = logging.getLogger("questionpy-sdk:builder") class PackageBuilderBase(AbstractContextManager): diff --git a/questionpy_sdk/watcher.py b/questionpy_sdk/watcher.py new file mode 100644 index 0000000..19b1a46 --- /dev/null +++ b/questionpy_sdk/watcher.py @@ -0,0 +1,174 @@ +# This file is part of the QuestionPy SDK. (https://questionpy.org) +# The QuestionPy SDK is free software released under terms of the MIT license. See LICENSE.md. +# (c) Technische Universität Berlin, innoCampus + +import asyncio +import logging +from collections.abc import Awaitable, Callable +from contextlib import AbstractAsyncContextManager +from pathlib import Path +from types import TracebackType +from typing import TYPE_CHECKING, Self + +from watchdog.events import ( + FileClosedEvent, + FileOpenedEvent, + FileSystemEvent, + FileSystemEventHandler, + FileSystemMovedEvent, +) +from watchdog.observers import Observer +from watchdog.utils.event_debouncer import EventDebouncer + +from questionpy_common.constants import DIST_DIR +from questionpy_sdk.package.builder import DirPackageBuilder +from questionpy_sdk.package.errors import PackageBuildError, PackageSourceValidationError +from questionpy_sdk.package.source import PackageSource +from questionpy_sdk.webserver.app import WebServer +from questionpy_server.worker.runtime.package_location import DirPackageLocation + +if TYPE_CHECKING: + from watchdog.observers.api import ObservedWatch + +log = logging.getLogger("questionpy-sdk:watcher") + +_DEBOUNCE_INTERVAL = 0.5 # seconds + + +class _EventHandler(FileSystemEventHandler): + """Debounces events for watchdog file monitoring, ignoring events in the `dist` directory.""" + + def __init__( + self, loop: asyncio.AbstractEventLoop, notify_callback: Callable[[], Awaitable[None]], watch_path: Path + ) -> None: + self._loop = loop + self._notify_callback = notify_callback + self._watch_path = watch_path + + self._event_debouncer = EventDebouncer(_DEBOUNCE_INTERVAL, self._on_file_changes) + + def start(self) -> None: + self._event_debouncer.start() + + def stop(self) -> None: + if self._event_debouncer.is_alive(): + self._event_debouncer.stop() + self._event_debouncer.join() + + def dispatch(self, event: FileSystemEvent) -> None: + # filter events and debounce + if not self._ignore_event(event): + self._event_debouncer.handle_event(event) + + def _on_file_changes(self, events: list[FileSystemEvent]) -> None: + # skip synchronization hassle by delegating this to the event loop in the main thread + asyncio.run_coroutine_threadsafe(self._notify_callback(), self._loop) + + def _ignore_event(self, event: FileSystemEvent) -> bool: + """Ignores events that should not trigger a rebuild. + + Args: + event: The event to check. + + Returns: + `True` if event should be ignored, otherwise `False`. + """ + if isinstance(event, FileOpenedEvent | FileClosedEvent): + return True + + # ignore events events in `dist` dir + relevant_path = event.dest_path if isinstance(event, FileSystemMovedEvent) else event.src_path + try: + return Path(relevant_path).relative_to(self._watch_path).parts[0] == DIST_DIR + except IndexError: + return False + + +class Watcher(AbstractAsyncContextManager): + """Watch a package source path and rebuild package/restart server on file changes.""" + + def __init__(self, source_path: Path, pkg_location: DirPackageLocation, host: str, port: int) -> None: + self._source_path = source_path + self._pkg_location = pkg_location + self._host = host + self._port = port + + self._event_handler = _EventHandler(asyncio.get_running_loop(), self._notify, self._source_path) + self._observer = Observer() + self._webserver = WebServer(self._pkg_location, host=self._host, port=self._port) + self._on_change_event = asyncio.Event() + self._watch: ObservedWatch | None = None + + async def __aenter__(self) -> Self: + self._event_handler.start() + self._observer.start() + log.info("Watching '%s' for changes...", self._source_path) + + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None + ) -> None: + if self._observer.is_alive(): + self._observer.stop() + self._event_handler.stop() + await self._webserver.stop_server() + + def _schedule(self) -> None: + if self._watch is None: + log.debug("Starting file watching...") + self._watch = self._observer.schedule(self._event_handler, self._source_path, recursive=True) + + def _unschedule(self) -> None: + if self._watch: + log.debug("Stopping file watching...") + self._observer.unschedule(self._watch) + self._watch = None + + async def _notify(self) -> None: + self._on_change_event.set() + + async def run_forever(self) -> None: + try: + await self._webserver.start_server() + except Exception: + log.exception("Failed to start webserver. The exception was:") + # When user messed up the their package on initial run, we just bail out. + return + + self._schedule() + + while True: + await self._on_change_event.wait() + + # Try to rebuild package and restart web server which might fail. + self._unschedule() + await self._rebuild_and_restart() + self._schedule() + + self._on_change_event.clear() + + async def _rebuild_and_restart(self) -> None: + log.info("File changes detected. Rebuilding package...") + + # Stop webserver. + try: + await self._webserver.stop_server() + except Exception: + log.exception("Failed to stop web server. The exception was:") + raise # Should not happen, thus we're propagating. + + # Build package. + try: + package_source = PackageSource(self._source_path) + with DirPackageBuilder(package_source) as builder: + builder.write_package() + except (PackageBuildError, PackageSourceValidationError): + log.exception("Failed to build package. The exception was:") + return + + # Start server. + try: + await self._webserver.start_server() + except Exception: + log.exception("Failed to start web server. The exception was:") diff --git a/questionpy_sdk/webserver/app.py b/questionpy_sdk/webserver/app.py index e1e9a89..e707f30 100644 --- a/questionpy_sdk/webserver/app.py +++ b/questionpy_sdk/webserver/app.py @@ -1,6 +1,8 @@ # This file is part of the QuestionPy SDK. (https://questionpy.org) # The QuestionPy SDK is free software released under terms of the MIT license. See LICENSE.md. # (c) Technische Universität Berlin, innoCampus +import asyncio +import logging import traceback from enum import StrEnum from functools import cached_property @@ -22,6 +24,8 @@ if TYPE_CHECKING: from questionpy_server.worker.worker import Worker +log = logging.getLogger("questionpy-sdk:web-server") + async def _extract_manifest(app: web.Application) -> None: webserver = app[SDK_WEBSERVER_APP_KEY] @@ -56,29 +60,37 @@ def __init__( self, package_location: PackageLocation, state_storage_path: Path = Path(__file__).parent / "question_state_storage", + host: str = "localhost", + port: int = 8080, ) -> None: - # We import here, so we don't have to work around circular imports. - from questionpy_sdk.webserver.routes.attempt import routes as attempt_routes # noqa: PLC0415 - from questionpy_sdk.webserver.routes.options import routes as options_routes # noqa: PLC0415 - from questionpy_sdk.webserver.routes.worker import routes as worker_routes # noqa: PLC0415 - self.package_location = package_location self._state_storage_root = state_storage_path + self._host = host + self._port = port - self.web_app = web.Application() - self.web_app[SDK_WEBSERVER_APP_KEY] = self + self._web_app: web.Application | None = None + self._runner: web.AppRunner | None = None + self.worker_pool: WorkerPool = WorkerPool(1, 500 * MiB, worker_type=ThreadWorker) - self.web_app.add_routes(attempt_routes) - self.web_app.add_routes(options_routes) - self.web_app.add_routes(worker_routes) - self.web_app.router.add_static("/static", Path(__file__).parent / "static") + async def start_server(self) -> None: + if self._web_app: + msg = "Web app is already running" + raise RuntimeError(msg) - self.web_app.on_startup.append(_extract_manifest) - self.web_app.middlewares.append(_invalid_question_state_middleware) + self._web_app = self._create_webapp() + self._runner = web.AppRunner(self._web_app) + await self._runner.setup() + await web.TCPSite(self._runner, self._host, self._port).start() - jinja2_extensions = ["jinja2.ext.do"] - aiohttp_jinja2.setup(self.web_app, loader=PackageLoader(__package__), extensions=jinja2_extensions) - self.worker_pool: WorkerPool = WorkerPool(1, 500 * MiB, worker_type=ThreadWorker) + async def stop_server(self) -> None: + if self._runner: + await self._runner.cleanup() + self._web_app = None + self._runner = None + + async def run_forever(self) -> None: + await self.start_server() + await asyncio.Event().wait() # run forever def read_state_file(self, filename: StateFilename) -> str | None: try: @@ -97,12 +109,35 @@ def delete_state_files(self, filename_1: StateFilename, *filenames: StateFilenam # Remove package state dir if it's now empty. self._package_state_dir.rmdir() - def start_server(self) -> None: - web.run_app(self.web_app) + def _create_webapp(self) -> web.Application: + # We import here, so we don't have to work around circular imports. + from questionpy_sdk.webserver.routes.attempt import routes as attempt_routes # noqa: PLC0415 + from questionpy_sdk.webserver.routes.options import routes as options_routes # noqa: PLC0415 + from questionpy_sdk.webserver.routes.worker import routes as worker_routes # noqa: PLC0415 + + app = web.Application() + app[SDK_WEBSERVER_APP_KEY] = self + + app.add_routes(attempt_routes) + app.add_routes(options_routes) + app.add_routes(worker_routes) + app.router.add_static("/static", Path(__file__).parent / "static") + + app.on_startup.append(_extract_manifest) + app.middlewares.append(_invalid_question_state_middleware) + + jinja2_extensions = ["jinja2.ext.do"] + aiohttp_jinja2.setup(app, loader=PackageLoader(__package__), extensions=jinja2_extensions) + + return app @cached_property def _package_state_dir(self) -> Path: - manifest = self.web_app[MANIFEST_APP_KEY] + if self._web_app is None: + msg = "Web app not initialized" + raise RuntimeError(msg) + + manifest = self._web_app[MANIFEST_APP_KEY] return self._state_storage_root / f"{manifest.namespace}-{manifest.short_name}-{manifest.version}" diff --git a/tests/conftest.py b/tests/conftest.py index 1da1497..aea5ae5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ # The QuestionPy SDK is free software released under terms of the MIT license. See LICENSE.md. # (c) Technische Universität Berlin, innoCampus +from collections.abc import Callable from pathlib import Path from shutil import copytree @@ -20,3 +21,8 @@ def source_path(request: pytest.FixtureRequest, tmp_path: Path) -> Path: copytree(src_path, dest_path, ignore=lambda src, names: (DIST_DIR,)) return dest_path + + +@pytest.fixture +def port(unused_tcp_port_factory: Callable) -> int: + return unused_tcp_port_factory() diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index f8761fe..719b36e 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -4,25 +4,19 @@ import asyncio import threading -from collections.abc import Callable, Iterator +from collections.abc import Iterator from pathlib import Path import pytest -from aiohttp import web from selenium import webdriver from questionpy_sdk.webserver.app import WebServer @pytest.fixture -def sdk_web_server(tmp_path: Path, request: pytest.FixtureRequest) -> WebServer: +def sdk_web_server(tmp_path: Path, request: pytest.FixtureRequest, port: int) -> WebServer: # We DON'T want state files to persist between tests, so we use a temp dir which is removed after each test. - return WebServer(request.function.qpy_package_location, state_storage_path=tmp_path) - - -@pytest.fixture -def port(unused_tcp_port_factory: Callable) -> int: - return unused_tcp_port_factory() + return WebServer(request.function.qpy_package_location, state_storage_path=tmp_path, port=port) @pytest.fixture @@ -38,18 +32,12 @@ def driver() -> Iterator[webdriver.Chrome]: yield chrome_driver -def start_runner(web_app: web.Application, unused_port: int) -> None: - runner = web.AppRunner(web_app) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(runner.setup()) - site = web.TCPSite(runner, "localhost", unused_port) - loop.run_until_complete(site.start()) - loop.run_forever() +def start_runner(web_app: WebServer) -> None: + asyncio.run(web_app.run_forever()) @pytest.fixture -def _start_runner_thread(sdk_web_server: WebServer, port: int) -> None: - app_thread = threading.Thread(target=start_runner, args=(sdk_web_server.web_app, port)) +def _start_runner_thread(sdk_web_server: WebServer) -> None: + app_thread = threading.Thread(target=start_runner, args=(sdk_web_server,)) app_thread.daemon = True # Set the thread as a daemon to automatically stop when main thread exits app_thread.start() diff --git a/tests/questionpy_sdk/commands/conftest.py b/tests/questionpy_sdk/commands/conftest.py index d9d78d0..f89eb23 100644 --- a/tests/questionpy_sdk/commands/conftest.py +++ b/tests/questionpy_sdk/commands/conftest.py @@ -75,25 +75,29 @@ async def long_running_cmd(args: Iterable[str], timeout: float = 5) -> AsyncIter popen_args = [sys.executable, "-m", "questionpy_sdk", "--", *args] proc = await asyncio.create_subprocess_exec(*popen_args, stdin=PIPE, stdout=PIPE, stderr=PIPE) + def terminate() -> None: + with contextlib.suppress(ProcessLookupError): + proc.send_signal(signal.SIGTERM) + # ensure tests don't hang indefinitely - async def kill_after_timeout() -> None: + async def terminate_after_timeout() -> None: await asyncio.sleep(timeout) - proc.send_signal(signal.SIGINT) + terminate() - kill_task = asyncio.create_task(kill_after_timeout()) + kill_task = asyncio.create_task(terminate_after_timeout()) yield proc finally: if kill_task: kill_task.cancel() - proc.send_signal(signal.SIGINT) + terminate() await proc.wait() -async def assert_webserver_is_up(session: aiohttp.ClientSession, url: str = "http://localhost:8080/") -> None: +async def assert_webserver_is_up(session: aiohttp.ClientSession, port: int) -> None: for _ in range(50): # allow 5 sec to come up try: - async with session.get(url) as response: + async with session.get(f"http://localhost:{port}/") as response: assert response.status == 200 return except aiohttp.ClientConnectionError: diff --git a/tests/questionpy_sdk/commands/test_run.py b/tests/questionpy_sdk/commands/test_run.py index 222ab1c..9aeb2b7 100644 --- a/tests/questionpy_sdk/commands/test_run.py +++ b/tests/questionpy_sdk/commands/test_run.py @@ -9,7 +9,7 @@ from questionpy_common.constants import DIST_DIR, MANIFEST_FILENAME from questionpy_sdk.commands.run import run -from questionpy_sdk.package.builder import DirPackageBuilder +from questionpy_sdk.package.builder import DirPackageBuilder, ZipPackageBuilder from questionpy_sdk.package.source import PackageSource from tests.questionpy_sdk.commands.conftest import assert_webserver_is_up, long_running_cmd @@ -33,18 +33,46 @@ def test_run_non_zip_file(runner: CliRunner, cwd: Path) -> None: assert "'README.md' doesn't look like a QPy package file, source directory, or dist directory." in result.stdout -async def test_run_source_dir_builds_package(source_path: Path, client_session: ClientSession) -> None: - async with long_running_cmd(("run", str(source_path))) as proc: +async def test_run_source_dir_builds_package(source_path: Path, client_session: ClientSession, port: int) -> None: + async with long_running_cmd(("run", "--port", str(port), str(source_path))) as proc: assert proc.stdout first_line = (await proc.stdout.readline()).decode("utf-8") assert f"Successfully built package '{source_path}'" in first_line assert (source_path / DIST_DIR / MANIFEST_FILENAME).exists() - await assert_webserver_is_up(client_session) + await assert_webserver_is_up(client_session, port) -async def test_run_dist_dir(source_path: Path, client_session: ClientSession) -> None: +async def test_run_dist_dir(source_path: Path, client_session: ClientSession, port: int) -> None: with DirPackageBuilder(PackageSource(source_path)) as builder: builder.write_package() - async with long_running_cmd(("run", str(source_path / DIST_DIR))): - await assert_webserver_is_up(client_session) + async with long_running_cmd(("run", "--port", str(port), str(source_path / DIST_DIR))): + await assert_webserver_is_up(client_session, port) + + +async def test_run_watch_with_source_dir(source_path: Path, client_session: ClientSession, port: int) -> None: + async with long_running_cmd(("run", "--watch", "--port", str(port), str(source_path))): + await assert_webserver_is_up(client_session, port) + + +async def test_run_watch_with_dist_dir(source_path: Path, port: int) -> None: + with DirPackageBuilder(PackageSource(source_path)) as builder: + builder.write_package() + + async with long_running_cmd(("run", "--watch", "--port", str(port), str(source_path / DIST_DIR))) as proc: + assert proc.stderr + assert await proc.wait() != 0 + stderr = (await proc.stderr.read()).decode("utf-8") + assert "The --watch option only works with source directories." in stderr + + +async def test_run_watch_with_qpy_file(cwd: Path, source_path: Path, port: int) -> None: + qpy_path = cwd / "test.qpy" + with ZipPackageBuilder(qpy_path, PackageSource(source_path)) as builder: + builder.write_package() + + async with long_running_cmd(("run", "--watch", "--port", str(port), str(qpy_path))) as proc: + assert proc.stderr + assert await proc.wait() != 0 + stderr = (await proc.stderr.read()).decode("utf-8") + assert "The --watch option only works with source directories." in stderr diff --git a/tests/questionpy_sdk/test_watcher.py b/tests/questionpy_sdk/test_watcher.py new file mode 100644 index 0000000..d81d219 --- /dev/null +++ b/tests/questionpy_sdk/test_watcher.py @@ -0,0 +1,76 @@ +# This file is part of the QuestionPy SDK. (https://questionpy.org) +# The QuestionPy SDK is free software released under terms of the MIT license. See LICENSE.md. +# (c) Technische Universität Berlin, innoCampus + +import asyncio +from pathlib import Path +from typing import cast + +import pytest +from watchdog.events import ( + DirCreatedEvent, + DirDeletedEvent, + DirModifiedEvent, + DirMovedEvent, + FileClosedEvent, + FileCreatedEvent, + FileDeletedEvent, + FileModifiedEvent, + FileMovedEvent, + FileOpenedEvent, + FileSystemEvent, +) + +from questionpy_common.constants import DIST_DIR +from questionpy_sdk.watcher import _EventHandler + +some_path = Path("/", "path", "to") + + +@pytest.fixture +def event_handler() -> _EventHandler: + async def notify() -> None: + pass + + mock_loop = cast(asyncio.AbstractEventLoop, None) + return _EventHandler(mock_loop, notify, some_path) + + +@pytest.mark.parametrize( + "event", + [ + DirCreatedEvent(src_path=str(some_path / "foo")), + DirDeletedEvent(src_path=str(some_path / "foo")), + DirModifiedEvent(src_path=str(some_path / "foo")), + DirMovedEvent(src_path=str(some_path / DIST_DIR / "foo"), dest_path=str(some_path / "foo")), + FileCreatedEvent(src_path=str(some_path / "foo")), + FileCreatedEvent(src_path=str(some_path / "python" / "foo" / "bar" / "module.py")), + FileDeletedEvent(src_path=str(some_path / "foo")), + FileDeletedEvent(src_path=str(some_path / "python" / "foo" / "bar" / "module.py")), + FileModifiedEvent(src_path=str(some_path)), + FileModifiedEvent(src_path=str(some_path / "python" / "foo" / "bar" / "module.py")), + FileMovedEvent(src_path=str(some_path / DIST_DIR / "foo"), dest_path=str(some_path / "foo")), + ], +) +def test_should_not_ignore_events(event: FileSystemEvent, event_handler: _EventHandler) -> None: + assert not event_handler._ignore_event(event) + + +# test that the watcher is ignoring certain events, like moving a file into the `dist` folder +@pytest.mark.parametrize( + "event", + [ + DirCreatedEvent(src_path=str(some_path / DIST_DIR / "foo")), + DirDeletedEvent(src_path=str(some_path / DIST_DIR / "foo")), + DirModifiedEvent(src_path=str(some_path / DIST_DIR / "foo")), + FileClosedEvent(src_path=str(some_path / "foo")), + DirMovedEvent(src_path=str(some_path / "foo"), dest_path=str(some_path / DIST_DIR / "foo")), + FileCreatedEvent(src_path=str(some_path / DIST_DIR / "foo")), + FileDeletedEvent(src_path=str(some_path / DIST_DIR / "foo")), + FileModifiedEvent(src_path=str(some_path / DIST_DIR)), + FileMovedEvent(src_path=str(some_path / "foo"), dest_path=str(some_path / DIST_DIR / "foo")), + FileOpenedEvent(src_path=str(some_path / "foo")), + ], +) +def test_should_ignore_events(event: FileSystemEvent, event_handler: _EventHandler) -> None: + assert event_handler._ignore_event(event)