Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added common install state primitives with strong typing #27

Merged
merged 33 commits into from
Jan 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
396 changes: 368 additions & 28 deletions README.md

Large diffs are not rendered by default.

Binary file added docs/pytest-installation-asserts.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 7 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ classifiers = [
]
dependencies = ["databricks-sdk>=0.16.0"]

[project.optional-dependencies]
yaml = ["PyYAML>=6.0.0,<7.0.0"]

[project.urls]
Issues = "https://github.com/databrickslabs/blueprint/issues"
Source = "https://github.com/databrickslabs/blueprint"
Expand All @@ -33,6 +36,7 @@ path = "src/databricks/labs/blueprint/__about__.py"

[tool.hatch.envs.default]
dependencies = [
"databricks-labs-blueprint[yaml]",
"coverage[toml]>=6.5",
"pytest",
"pytest-xdist",
Expand All @@ -52,8 +56,8 @@ python="3.10"
path = ".venv"

[tool.hatch.envs.default.scripts]
test = "pytest -n auto --cov src --cov-report=xml --timeout 30 tests/unit --durations 20"
coverage = "pytest -n auto --cov src tests/unit --timeout 30 --cov-report=html --durations 20"
test = "pytest -n 2 --cov src --cov-report=xml --timeout 30 tests/unit --durations 20"
coverage = "pytest -n 2 --cov src tests/unit --timeout 30 --cov-report=html --durations 20"
integration = "pytest -n 10 --cov src tests/integration --durations 20"
fmt = ["isort .",
"ruff format",
Expand All @@ -68,7 +72,7 @@ verify = ["black --check .",
profile = "black"

[tool.pytest.ini_options]
addopts = "-s -p no:warnings -vv --cache-clear"
addopts = "--no-header"
cache_dir = ".venv/pytest-cache"

[tool.black]
Expand Down
632 changes: 632 additions & 0 deletions src/databricks/labs/blueprint/installation.py

Large diffs are not rendered by default.

87 changes: 27 additions & 60 deletions src/databricks/labs/blueprint/installer.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
import json
import logging
import threading
from json import JSONDecodeError
from typing import TypedDict
from dataclasses import dataclass, field
from datetime import timedelta
from typing import Any

from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import NotFound
from databricks.sdk.service.workspace import ImportFormat
from databricks.sdk.retries import retried

from databricks.labs.blueprint.installation import IllegalState, Installation

logger = logging.getLogger(__name__)

Resources = dict[str, str]
Json = dict[str, Any]


@dataclass
class RawState:
__file__ = "state.json"
__version__ = 1

class RawState(TypedDict):
resources: dict[str, Resources]
resources: dict[str, dict[str, str]] = field(default_factory=dict)


class IllegalState(ValueError):
class StateError(IllegalState):
pass


Expand All @@ -26,63 +31,25 @@

_state: RawState | None = None

def __init__(
self, ws: WorkspaceClient, product: str, config_version: int = 1, *, install_folder: str | None = None
):
self._ws = ws
self._product = product
self._install_folder = install_folder
self._config_version = config_version
def __init__(self, ws: WorkspaceClient, product: str, *, install_folder: str | None = None):
self._installation = Installation(ws, product, install_folder=install_folder)
self._lock = threading.Lock()

def product(self) -> str:
return self._product

def install_folder(self) -> str:
if self._install_folder:
return self._install_folder
me = self._ws.current_user.me()
self._install_folder = f"/Users/{me.user_name}/.{self._product}"
return self._install_folder
def install_folder(self):
return self._installation.install_folder()

def __getattr__(self, item: str) -> Resources:
@retried(on=[StateError], timeout=timedelta(seconds=5))
def __getattr__(self, item: str) -> dict[str, str]:
with self._lock:
if not self._state:
self._state = self._load()
if item not in self._state["resources"]:
self._state["resources"][item] = {}
return self._state["resources"][item]

def _state_file(self) -> str:
return f"{self.install_folder()}/state.json"

def _load(self) -> RawState:
"""Loads remote state"""
default_state: RawState = {"resources": {}}
try:
raw = json.load(self._ws.workspace.download(self._state_file()))
version = raw.pop("$version", None)
if version != self._config_version:
msg = f"expected state $version={self._config_version}, got={version}"
raise IllegalState(msg)
return raw
except NotFound:
return default_state
except JSONDecodeError:
logger.warning(f"JSON state file corrupt: {self._state_file}")
return default_state
self._state = self._installation.load(RawState)
if not self._state:
raise StateError("Failed to load raw state")

Check warning on line 47 in src/databricks/labs/blueprint/installer.py

View check run for this annotation

Codecov / codecov/patch

src/databricks/labs/blueprint/installer.py#L47

Added line #L47 was not covered by tests
if item not in self._state.resources:
self._state.resources[item] = {}
return self._state.resources[item]

def save(self) -> None:
"""Saves remote state"""
with self._lock:
state: dict = {}
if self._state:
state = self._state.copy() # type: ignore[assignment]
state["$version"] = self._config_version
state_dump = json.dumps(state, indent=2).encode("utf8")
self._ws.workspace.upload(
self._state_file(),
state_dump, # type: ignore[arg-type]
format=ImportFormat.AUTO,
overwrite=True,
)
self._installation.save(self._state)
10 changes: 5 additions & 5 deletions src/databricks/labs/blueprint/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import os
import re
import threading
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import Generic, TypeVar
from typing import Collection, Generic, TypeVar

MIN_THREADS = 8

Expand Down Expand Up @@ -39,7 +39,7 @@ def __init__(self, name, tasks: Sequence[Task[Result]], num_threads: int):
@classmethod
def gather(
cls, name: str, tasks: Sequence[Task[Result]], num_threads: int | None = None
) -> tuple[Iterable[Result], list[Exception]]:
) -> tuple[Collection[Result], list[Exception]]:
if num_threads is None:
num_cpus = os.cpu_count()
if num_cpus is None:
Expand All @@ -50,13 +50,13 @@ def gather(
return cls(name, tasks, num_threads=num_threads)._run()

@classmethod
def strict(cls, name: str, tasks: Sequence[Task[Result]]) -> Iterable[Result]:
def strict(cls, name: str, tasks: Sequence[Task[Result]]) -> Collection[Result]:
collected, errs = cls.gather(name, tasks)
if errs:
raise ManyError(errs)
return collected

def _run(self) -> tuple[Iterable[Result], list[Exception]]:
def _run(self) -> tuple[Collection[Result], list[Exception]]:
given_cnt = len(self._tasks)
if given_cnt == 0:
return [], []
Expand Down
45 changes: 26 additions & 19 deletions src/databricks/labs/blueprint/wheels.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import datetime
import logging
import os
import shutil
import subprocess
import sys
import tempfile
from contextlib import AbstractContextManager
from dataclasses import dataclass
from pathlib import Path

from databricks.sdk import WorkspaceClient
from databricks.sdk.mixins.compute import SemVer
from databricks.sdk.service.workspace import ImportFormat

from databricks.labs.blueprint.entrypoint import find_project_root
from databricks.labs.blueprint.installation import Installation
from databricks.labs.blueprint.installer import InstallState

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -123,38 +123,35 @@
return version_data["__version__"]


class Wheels(AbstractContextManager):
@dataclass
class Version:
version: str
wheel: str


class WheelsV2(AbstractContextManager):
"""Wheel builder"""

__version: str | None = None

def __init__(
self, ws: WorkspaceClient, install_state: InstallState, product_info: ProductInfo, *, verbose: bool = False
):
self._ws = ws
self._install_state = install_state
def __init__(self, installation: Installation, product_info: ProductInfo, *, verbose: bool = False):
self._installation = installation
self._product_info = product_info
self._verbose = verbose

def upload_to_dbfs(self) -> str:
with self._local_wheel.open("rb") as f:
self._ws.dbfs.mkdirs(self._remote_dir_name)
logger.info(f"Uploading wheel to dbfs:{self._remote_wheel}")
self._ws.dbfs.upload(self._remote_wheel, f, overwrite=True)
return self._remote_wheel
return self._installation.upload_dbfs(f"wheels/{self._local_wheel.name}", f.read())

def upload_to_wsfs(self) -> str:
with self._local_wheel.open("rb") as f:
self._ws.workspace.mkdirs(self._remote_dir_name)
logger.info(f"Uploading wheel to /Workspace{self._remote_wheel}")
self._ws.workspace.upload(self._remote_wheel, f, overwrite=True, format=ImportFormat.AUTO)
return self._remote_wheel
remote_wheel = self._installation.upload(f"wheels/{self._local_wheel.name}", f.read())
self._installation.save(Version(version=self._product_info.version(), wheel=remote_wheel))
return remote_wheel

def __enter__(self) -> "Wheels":
def __enter__(self) -> "WheelsV2":
self._tmp_dir = tempfile.TemporaryDirectory()
self._local_wheel = self._build_wheel(self._tmp_dir.name, verbose=self._verbose)
self._remote_wheel = f"{self._install_state.install_folder()}/wheels/{self._local_wheel.name}"
self._remote_dir_name = os.path.dirname(self._remote_wheel)
return self

def __exit__(self, __exc_type, __exc_value, __traceback):
Expand Down Expand Up @@ -210,3 +207,13 @@

shutil.copytree(project_root, tmp_dir_path, ignore=copy_ignore)
return tmp_dir_path


class Wheels(WheelsV2):
"""Wheel builder"""

def __init__(
self, ws: WorkspaceClient, install_state: InstallState, product_info: ProductInfo, *, verbose: bool = False
):
installation = Installation(ws, product_info.product_name(), install_folder=install_state.install_folder())
super().__init__(installation, product_info, verbose=verbose)

Check warning on line 219 in src/databricks/labs/blueprint/wheels.py

View check run for this annotation

Codecov / codecov/patch

src/databricks/labs/blueprint/wheels.py#L218-L219

Added lines #L218 - L219 were not covered by tests
3 changes: 3 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import pytest

pytest.register_assert_rewrite("databricks.labs.blueprint.installation")
Loading
Loading