Cache workspace content (#2497)
## Changes
Loading workspace content is slow and bound by workspace API rate limits.
This PR introduces an LRU cache for workspace content, so repeated reads of the same path are served from memory instead of being downloaded again.
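
A minimal usage sketch (the workspace path below is hypothetical; `WorkspaceCache` and `get_path` are the new API added in this PR, and the cache keeps up to 2048 entries by default):

```python
from databricks.sdk import WorkspaceClient

from databricks.labs.ucx.mixins.cached_workspace_path import WorkspaceCache

ws = WorkspaceClient()  # assumes an already-authenticated workspace client

# LRU cache over workspace paths; the least recently used entry is evicted past max_entries
cache = WorkspaceCache(ws, max_entries=2048)

# hypothetical path, for illustration only
path = cache.get_path("/Users/someone@example.com/notebook.py")

first = path.read_text()   # downloads the content and caches it
second = path.read_text()  # served from the cache, no second download
assert first == second

# writes and unlink() invalidate the cached entry, so the next read downloads again
```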

### Linked issues
None

### Functionality
None

### Tests
- [x] added unit tests

---------

Co-authored-by: Eric Vergnaud <eric.vergnaud@databricks.com>
ericvergnaud authored Aug 28, 2024
1 parent e5e0562 commit fc554ae
Showing 6 changed files with 261 additions and 5 deletions.
139 changes: 139 additions & 0 deletions src/databricks/labs/ucx/mixins/cached_workspace_path.py
@@ -0,0 +1,139 @@
from __future__ import annotations

import os
from collections import OrderedDict
from collections.abc import Generator
from io import StringIO, BytesIO

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.workspace import ObjectInfo
from databricks.labs.blueprint.paths import WorkspacePath


class _CachedIO:

def __init__(self, content):
self._content = content
self._index = 0

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
return False

def read(self, *args, **_kwargs):
count = -1 if len(args) < 1 or args[0] < 1 else args[0]
if count == -1:
return self._content
start = self._index
end = self._index + count
if start >= len(self._content):
return None
self._index = self._index + count
return self._content[start:end]

def __iter__(self):
if isinstance(self._content, str):
yield from StringIO(self._content)
return
yield from self._as_string_io().__iter__()

def with_mode(self, mode: str):
if 'b' in mode:
return self._as_bytes_io()
return self._as_string_io()

def _as_bytes_io(self):
if isinstance(self._content, bytes):
return self
return BytesIO(self._content.encode("utf-8-sig"))

def _as_string_io(self):
if isinstance(self._content, str):
return self
return StringIO(self._content.decode("utf-8"))


# lru_cache won't let us invalidate cache entries
# so we provide our own custom lru_cache
class _PathLruCache:

def __init__(self, max_entries: int):
self._datas: OrderedDict[str, bytes | str] = OrderedDict()
self._max_entries = max_entries

def open(self, cached_path: _CachedPath, mode, buffering, encoding, errors, newline):
path = str(cached_path)
if path in self._datas:
self._datas.move_to_end(path)
return _CachedIO(self._datas[path]).with_mode(mode)
io_obj = WorkspacePath.open(cached_path, mode, buffering, encoding, errors, newline)
# can't read twice from an IO so need to cache data rather than the io object
data = io_obj.read()
self._datas[path] = data
result = _CachedIO(data).with_mode(mode)
if len(self._datas) > self._max_entries:
self._datas.popitem(last=False)
return result

def clear(self):
self._datas.clear()

def remove(self, path: str):
if path in self._datas:
self._datas.pop(path)


class _CachedPath(WorkspacePath):
def __init__(self, cache: _PathLruCache, ws: WorkspaceClient, *args: str | bytes | os.PathLike):
super().__init__(ws, *args)
self._cache = cache

def with_object_info(self, object_info: ObjectInfo):
self._cached_object_info = object_info
return self

def with_segments(self, *path_segments: bytes | str | os.PathLike) -> _CachedPath:
return type(self)(self._cache, self._ws, *path_segments)

def iterdir(self) -> Generator[_CachedPath, None, None]:
for object_info in self._ws.workspace.list(self.as_posix()):
path = object_info.path
if path is None:
msg = f"Cannot initialise without object path: {object_info}"
raise ValueError(msg)
child = _CachedPath(self._cache, self._ws, path)
yield child.with_object_info(object_info)

def open(
self,
mode: str = "r",
buffering: int = -1,
encoding: str | None = None,
errors: str | None = None,
newline: str | None = None,
):
# only cache reads
if 'r' in mode:
return self._cache.open(self, mode, buffering, encoding, errors, newline)
self._cache.remove(str(self))
return super().open(mode, buffering, encoding, errors, newline)

def _cached_open(self, mode: str, buffering: int, encoding: str | None, errors: str | None, newline: str | None):
return super().open(mode, buffering, encoding, errors, newline)

# _rename calls unlink so no need to override it
def unlink(self, missing_ok: bool = False) -> None:
self._cache.remove(str(self))
return super().unlink(missing_ok)


class WorkspaceCache:

def __init__(self, ws: WorkspaceClient, max_entries=2048):
self._ws = ws
self._cache = _PathLruCache(max_entries)

def get_path(self, path: str):
return _CachedPath(self._cache, self._ws, path)
10 changes: 6 additions & 4 deletions src/databricks/labs/ucx/source_code/jobs.py
@@ -11,14 +11,15 @@
from urllib import parse

from databricks.labs.blueprint.parallel import ManyError, Threads
from databricks.labs.blueprint.paths import DBFSPath, WorkspacePath
from databricks.labs.blueprint.paths import DBFSPath
from databricks.labs.lsql.backends import SqlBackend
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import NotFound
from databricks.sdk.service import compute, jobs

from databricks.labs.ucx.assessment.crawlers import runtime_version_tuple
from databricks.labs.ucx.hive_metastore.migration_status import MigrationIndex
from databricks.labs.ucx.mixins.cached_workspace_path import WorkspaceCache
from databricks.labs.ucx.source_code.base import CurrentSessionState, is_a_notebook, LocatedAdvice
from databricks.labs.ucx.source_code.graph import (
Dependency,
@@ -72,6 +73,7 @@ def __init__(self, ws: WorkspaceClient, task: jobs.Task, job: jobs.Job):
self._task = task
self._job = job
self._ws = ws
self._cache = WorkspaceCache(ws)
self._named_parameters: dict[str, str] | None = {}
self._parameters: list[str] | None = []
self._spark_conf: dict[str, str] | None = {}
@@ -123,7 +125,7 @@ def _as_path(self, path: str) -> Path:
parsed_path = parse.urlparse(path)
match parsed_path.scheme:
case "":
return WorkspacePath(self._ws, path)
return self._cache.get_path(path)
case "dbfs":
return DBFSPath(self._ws, parsed_path.path)
case other:
@@ -186,7 +188,7 @@ def _register_notebook(self, graph: DependencyGraph) -> Iterable[DependencyProbl
notebook_path = self._task.notebook_task.notebook_path
logger.info(f'Discovering {self._task.task_key} entrypoint: {notebook_path}')
# Notebooks can't be on DBFS.
path = WorkspacePath(self._ws, notebook_path)
path = self._cache.get_path(notebook_path)
return graph.register_notebook(path, False)

def _register_spark_python_task(self, graph: DependencyGraph):
@@ -261,7 +263,7 @@ def _register_pipeline_task(self, graph: DependencyGraph):
if library.notebook.path:
notebook_path = library.notebook.path
# Notebooks can't be on DBFS.
path = WorkspacePath(self._ws, notebook_path)
path = self._cache.get_path(notebook_path)
# the notebook is the root of the graph, so there's no context to inherit
yield from graph.register_notebook(path, inherit_context=False)
if library.jar:
4 changes: 3 additions & 1 deletion tests/unit/__init__.py
@@ -201,7 +201,9 @@ def mock_workspace_client(
]
),
}
ws.workspace.download.side_effect = lambda file_name: io.StringIO(download_yaml[os.path.basename(file_name)])
ws.workspace.download.side_effect = lambda file_name, *, format=None: io.StringIO(
download_yaml[os.path.basename(file_name)]
)
return ws


1 change: 1 addition & 0 deletions tests/unit/conftest.py
@@ -7,6 +7,7 @@
import pytest
from databricks.labs.blueprint.installation import MockInstallation
from databricks.labs.lsql.backends import MockBackend

from databricks.labs.ucx.source_code.graph import BaseNotebookResolver
from databricks.labs.ucx.source_code.path_lookup import PathLookup
from databricks.sdk import WorkspaceClient, AccountClient
Empty file added tests/unit/mixins/__init__.py
112 changes: 112 additions & 0 deletions tests/unit/mixins/test_cached_workspace_path.py
@@ -0,0 +1,112 @@
import io
from unittest.mock import create_autospec

import pytest

from tests.unit import mock_workspace_client

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.workspace import ObjectInfo, ObjectType

from databricks.labs.ucx.mixins.cached_workspace_path import WorkspaceCache
from databricks.labs.ucx.source_code.base import guess_encoding


class TestWorkspaceCache(WorkspaceCache):

@property
def data_cache(self):
return self._cache


def test_path_like_returns_cached_instance():
cache = TestWorkspaceCache(mock_workspace_client())
parent = cache.get_path("path")
child = parent / "child"
_cache = getattr(child, "_cache")
assert _cache == cache.data_cache


def test_iterdir_returns_cached_instances():
ws = create_autospec(WorkspaceClient)
ws.workspace.get_status.return_value = ObjectInfo(object_type=ObjectType.DIRECTORY)
ws.workspace.list.return_value = list(ObjectInfo(object_type=ObjectType.FILE, path=s) for s in ("a", "b", "c"))
cache = TestWorkspaceCache(ws)
parent = cache.get_path("dir")
assert parent.is_dir()
for child in parent.iterdir():
_cache = getattr(child, "_cache")
assert _cache == cache.data_cache


def test_download_is_only_called_once_per_instance():
ws = mock_workspace_client()
ws.workspace.download.side_effect = lambda _, *, format: io.BytesIO("abc".encode())
cache = WorkspaceCache(ws)
path = cache.get_path("path")
for _ in range(0, 4):
_ = path.read_text()
assert ws.workspace.download.call_count == 1


def test_download_is_only_called_once_across_instances():
ws = mock_workspace_client()
ws.workspace.download.side_effect = lambda _, *, format: io.BytesIO("abc".encode())
cache = WorkspaceCache(ws)
for _ in range(0, 4):
path = cache.get_path("path")
_ = path.read_text()
assert ws.workspace.download.call_count == 1


def test_download_is_called_again_after_unlink():
ws = mock_workspace_client()
ws.workspace.download.side_effect = lambda _, *, format: io.BytesIO("abc".encode())
cache = WorkspaceCache(ws)
path = cache.get_path("path")
_ = path.read_text()
path = cache.get_path("path")
path.unlink()
_ = path.read_text()
assert ws.workspace.download.call_count == 2


def test_download_is_called_again_after_rename():
ws = mock_workspace_client()
ws.workspace.download.side_effect = lambda _, *, format: io.BytesIO("abc".encode())
cache = WorkspaceCache(ws)
path = cache.get_path("path")
_ = path.read_text()
path.rename("abcd")
_ = path.read_text()
assert ws.workspace.download.call_count == 3 # rename reads the old content


def test_encoding_is_guessed_after_download():
ws = mock_workspace_client()
ws.workspace.download.side_effect = lambda _, *, format: io.BytesIO("abc".encode())
cache = WorkspaceCache(ws)
path = cache.get_path("path")
_ = path.read_text()
guess_encoding(path)


@pytest.mark.parametrize(
"mode, data",
[
("r", io.BytesIO("abc".encode("utf-8-sig"))),
("rb", io.BytesIO("abc".encode("utf-8-sig"))),
],
)
def test_sequential_read_completes(mode, data):
ws = mock_workspace_client()
ws.workspace.download.side_effect = lambda _, *, format: data
cache = WorkspaceCache(ws)
path = cache.get_path("path")
with path.open(mode) as file:
count = 0
while _ := file.read(1):
count = count + 1
if count > 10:
break
assert count < 10
