remote: support legacy cache push/fetch
pmrowla authored and efiop committed Jun 9, 2023
1 parent 2cd33e1 commit 80f414c
Showing 3 changed files with 160 additions and 42 deletions.
124 changes: 112 additions & 12 deletions dvc/data_cloud.py
@@ -1,18 +1,18 @@
 """Manages dvc remotes that user can use with push/pull/status commands."""

 import logging
-from typing import TYPE_CHECKING, Iterable, Optional
+from typing import TYPE_CHECKING, Iterable, Optional, Set, Tuple

 from dvc.config import NoRemoteError, RemoteConfigError
 from dvc.utils.objects import cached_property
 from dvc_data.hashfile.db import get_index
+from dvc_data.hashfile.transfer import TransferResult

 if TYPE_CHECKING:
     from dvc.fs import FileSystem
     from dvc_data.hashfile.db import HashFileDB
     from dvc_data.hashfile.hash_info import HashInfo
     from dvc_data.hashfile.status import CompareStatusResult
-    from dvc_data.hashfile.transfer import TransferResult

 logger = logging.getLogger(__name__)

@@ -50,6 +50,21 @@ def legacy_odb(self) -> "HashFileDB":
         return get_odb(self.fs, path, hash_name="md5-dos2unix", **self.config)


+def _split_legacy_hash_infos(
+    hash_infos: Iterable["HashInfo"],
+) -> Tuple[Set["HashInfo"], Set["HashInfo"]]:
+    from dvc.cachemgr import LEGACY_HASH_NAMES
+
+    legacy = set()
+    default = set()
+    for hi in hash_infos:
+        if hi.name in LEGACY_HASH_NAMES:
+            legacy.add(hi)
+        else:
+            default.add(hi)
+    return legacy, default
+
+
 class DataCloud:
     """Class that manages dvc remotes.
@@ -167,14 +182,40 @@ def push(
                 By default remote from core.remote config option is used.
             odb: optional ODB to push to. Overrides remote.
         """
-        odb = odb or self.get_remote_odb(remote, "push")
+        if odb is not None:
+            return self._push(objs, jobs=jobs, odb=odb)
+        legacy_objs, default_objs = _split_legacy_hash_infos(objs)
+        result = TransferResult(set(), set())
+        if legacy_objs:
+            odb = self.get_remote_odb(remote, "push", hash_name="md5-dos2unix")
+            t, f = self._push(legacy_objs, jobs=jobs, odb=odb)
+            result.transferred.update(t)
+            result.failed.update(f)
+        if default_objs:
+            odb = self.get_remote_odb(remote, "push")
+            t, f = self._push(default_objs, jobs=jobs, odb=odb)
+            result.transferred.update(t)
+            result.failed.update(f)
+        return result
+
+    def _push(
+        self,
+        objs: Iterable["HashInfo"],
+        *,
+        jobs: Optional[int] = None,
+        odb: "HashFileDB",
+    ) -> "TransferResult":
+        if odb.hash_name == "md5-dos2unix":
+            cache = self.repo.cache.legacy
+        else:
+            cache = self.repo.cache.local
         return self.transfer(
-            self.repo.cache.local,
+            cache,
             odb,
             objs,
             jobs=jobs,
             dest_index=get_index(odb),
-            cache_odb=self.repo.cache.local,
+            cache_odb=cache,
             validate_status=self._log_missing,
         )

@@ -194,14 +235,41 @@ def pull(
                 By default remote from core.remote config option is used.
             odb: optional ODB to pull from. Overrides remote.
         """
-        odb = odb or self.get_remote_odb(remote, "pull")
+        if odb is not None:
+            return self._pull(objs, jobs=jobs, odb=odb)
+        legacy_objs, default_objs = _split_legacy_hash_infos(objs)
+        result = TransferResult(set(), set())
+        if legacy_objs:
+            odb = self.get_remote_odb(remote, "pull", hash_name="md5-dos2unix")
+            assert odb.hash_name == "md5-dos2unix"
+            t, f = self._pull(legacy_objs, jobs=jobs, odb=odb)
+            result.transferred.update(t)
+            result.failed.update(f)
+        if default_objs:
+            odb = self.get_remote_odb(remote, "pull")
+            t, f = self._pull(default_objs, jobs=jobs, odb=odb)
+            result.transferred.update(t)
+            result.failed.update(f)
+        return result
+
+    def _pull(
+        self,
+        objs: Iterable["HashInfo"],
+        *,
+        jobs: Optional[int] = None,
+        odb: "HashFileDB",
+    ) -> "TransferResult":
+        if odb.hash_name == "md5-dos2unix":
+            cache = self.repo.cache.legacy
+        else:
+            cache = self.repo.cache.local
         return self.transfer(
             odb,
-            self.repo.cache.local,
+            cache,
             objs,
             jobs=jobs,
             src_index=get_index(odb),
-            cache_odb=self.repo.cache.local,
+            cache_odb=cache,
             verify=odb.verify,
             validate_status=self._log_missing,
         )
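
Illustrative sketch (not part of this commit): after this change, push() and pull() resolve two remote ODBs for the same configured remote, one per hash flavor, and merge the per-flavor TransferResults. The snippet assumes it runs inside a DVC project with a default remote configured and that the DataCloud instance is reachable as repo.cloud.

    from dvc.repo import Repo

    repo = Repo()  # assumes the working directory is a DVC repository
    cloud = repo.cloud
    legacy_odb = cloud.get_remote_odb(None, "push", hash_name="md5-dos2unix")  # legacy objects
    default_odb = cloud.get_remote_odb(None, "push")                           # md5 objects
    # Same configured remote, but the two ODBs hash/address objects differently.
    print(legacy_odb.hash_name, default_odb.hash_name)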
@@ -223,17 +291,49 @@ def status(
                 is used.
             odb: optional ODB to check status from. Overrides remote.
         """
+        from dvc_data.hashfile.status import CompareStatusResult
+
+        if odb is not None:
+            return self._status(objs, jobs=jobs, odb=odb)
+        result = CompareStatusResult(set(), set(), set(), set())
+        legacy_objs, default_objs = _split_legacy_hash_infos(objs)
+        if legacy_objs:
+            odb = self.get_remote_odb(remote, "status", hash_name="md5-dos2unix")
+            assert odb.hash_name == "md5-dos2unix"
+            o, m, n, d = self._status(legacy_objs, jobs=jobs, odb=odb)
+            result.ok.update(o)
+            result.missing.update(m)
+            result.new.update(n)
+            result.deleted.update(d)
+        if default_objs:
+            odb = self.get_remote_odb(remote, "status")
+            o, m, n, d = self._status(default_objs, jobs=jobs, odb=odb)
+            result.ok.update(o)
+            result.missing.update(m)
+            result.new.update(n)
+            result.deleted.update(d)
+        return result
+
+    def _status(
+        self,
+        objs: Iterable["HashInfo"],
+        *,
+        jobs: Optional[int] = None,
+        odb: "HashFileDB",
+    ):
         from dvc_data.hashfile.status import compare_status

-        if not odb:
-            odb = self.get_remote_odb(remote, "status")
+        if odb.hash_name == "md5-dos2unix":
+            cache = self.repo.cache.legacy
+        else:
+            cache = self.repo.cache.local
         return compare_status(
-            self.repo.cache.local,
+            cache,
             odb,
             objs,
             jobs=jobs,
             dest_index=get_index(odb),
-            cache_odb=self.repo.cache.local,
+            cache_odb=cache,
         )

     def get_url_for(self, remote, checksum):
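
Illustrative sketch (not part of this commit): status() above merges two CompareStatusResult values, one per hash flavor, field by field; in the order used above the fields are ok, missing, new and deleted. A minimal stand-alone version of that merge:

    from dvc_data.hashfile.status import CompareStatusResult

    result = CompareStatusResult(set(), set(), set(), set())
    partial_results = [
        CompareStatusResult({"ok-1"}, set(), {"new-1"}, set()),    # e.g. legacy objects
        CompareStatusResult(set(), {"missing-2"}, set(), set()),   # e.g. default objects
    ]
    for o, m, n, d in partial_results:
        result.ok.update(o)
        result.missing.update(m)
        result.new.update(n)
        result.deleted.update(d)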
72 changes: 43 additions & 29 deletions dvc/repo/imports.py
@@ -1,5 +1,6 @@
 import logging
 import os
+from functools import partial
 from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, List, Set, Tuple, Union

@@ -17,17 +18,18 @@

 def unfetched_view(
     index: "Index", targets: "TargetType", unpartial: bool = False, **kwargs
-) -> Tuple["IndexView", List["Dependency"]]:
+) -> Tuple["IndexView", "IndexView", List["Dependency"]]:
     """Return index view of imports which have not been fetched.

     Returns:
-        Tuple in the form (view, changed_deps) where changed_imports is a list
-        of import dependencies that cannot be fetched due to changed data
-        source.
+        Tuple in the form (legacy_view, view, changed_deps) where changed_imports is a
+        list of import dependencies that cannot be fetched due to changed data source.
     """
+    from dvc.cachemgr import LEGACY_HASH_NAMES
+
     changed_deps: List["Dependency"] = []

-    def need_fetch(stage: "Stage") -> bool:
+    def need_fetch(stage: "Stage", legacy: bool = False) -> bool:
         if not stage.is_import or (stage.is_partial_import and not unpartial):
             return False

@@ -40,10 +42,19 @@ def need_fetch(stage: "Stage") -> bool:
             changed_deps.append(dep)
             return False

-        return True
+        if out.hash_name in LEGACY_HASH_NAMES and legacy:
+            return True
+        if out.hash_name not in LEGACY_HASH_NAMES and not legacy:
+            return True
+        return False

+    legacy_unfetched = index.targets_view(
+        targets,
+        stage_filter=partial(need_fetch, legacy=True),
+        **kwargs,
+    )
     unfetched = index.targets_view(targets, stage_filter=need_fetch, **kwargs)
-    return unfetched, changed_deps
+    return legacy_unfetched, unfetched, changed_deps


 def partial_view(index: "Index", targets: "TargetType", **kwargs) -> "IndexView":
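
Illustrative sketch (not part of this commit): the stage_filter wiring above reuses one need_fetch predicate for both views by binding legacy=True with functools.partial. A stand-alone reduction of the pattern, with a plain dict standing in for a real Stage:

    from functools import partial

    def need_fetch(stage: dict, legacy: bool = False) -> bool:
        is_legacy = stage["hash_name"] == "md5-dos2unix"
        return is_legacy if legacy else not is_legacy

    stages = [{"hash_name": "md5"}, {"hash_name": "md5-dos2unix"}]
    legacy_stages = list(filter(partial(need_fetch, legacy=True), stages))
    default_stages = list(filter(need_fetch, stages))
    assert len(legacy_stages) == 1 and len(default_stages) == 1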
@@ -94,33 +105,36 @@ def save_imports(

     downloaded: Set["HashInfo"] = set()

-    unfetched, changed = unfetched_view(
+    legacy_unfetched, unfetched, changed = unfetched_view(
         repo.index, targets, unpartial=unpartial, **kwargs
     )
     for dep in changed:
         logger.warning(str(DataSourceChanged(f"{dep.stage} ({dep})")))

-    data_view = unfetched.data["repo"]
-    if len(data_view):
-        cache = repo.cache.local
-        if not cache.fs.exists(cache.path):
-            os.makedirs(cache.path)
-        with TemporaryDirectory(dir=cache.path) as tmpdir:
-            with Callback.as_tqdm_callback(
-                desc="Downloading imports from source",
-                unit="files",
-            ) as cb:
-                checkout(data_view, tmpdir, cache.fs, callback=cb, storage="remote")
-            md5(data_view)
-            save(data_view, odb=cache, hardlink=True)
-
-        downloaded.update(
-            entry.hash_info
-            for _, entry in data_view.iteritems()
-            if entry.meta is not None
-            and not entry.meta.isdir
-            and entry.hash_info is not None
-        )
+    for view, cache in [
+        (legacy_unfetched, repo.cache.legacy),
+        (unfetched, repo.cache.local),
+    ]:
+        data_view = view.data["repo"]
+        if len(data_view):
+            if not cache.fs.exists(cache.path):
+                os.makedirs(cache.path)
+            with TemporaryDirectory(dir=cache.path) as tmpdir:
+                with Callback.as_tqdm_callback(
+                    desc="Downloading imports from source",
+                    unit="files",
+                ) as cb:
+                    checkout(data_view, tmpdir, cache.fs, callback=cb, storage="remote")
+                md5(data_view, name=cache.hash_name)
+                save(data_view, odb=cache, hardlink=True)
+
+            downloaded.update(
+                entry.hash_info
+                for _, entry in data_view.iteritems()
+                if entry.meta is not None
+                and not entry.meta.isdir
+                and entry.hash_info is not None
+            )

     if unpartial:
         unpartial_imports(partial_view(repo.index, targets, **kwargs))
6 changes: 5 additions & 1 deletion dvc/repo/index.py
@@ -150,6 +150,7 @@ def _load_data_from_outs(index, prefix, outs):


 def _load_storage_from_out(storage_map, key, out):
+    from dvc.cachemgr import LEGACY_HASH_NAMES
     from dvc.config import NoRemoteError
     from dvc_data.index import FileStorage, ObjectStorage

@@ -168,7 +169,10 @@ def _load_storage_from_out(storage_map, key, out):
                 )
             )
         else:
-            storage_map.add_remote(ObjectStorage(key, remote.odb, index=remote.index))
+            odb = (
+                remote.legacy_odb if out.hash_name in LEGACY_HASH_NAMES else remote.odb
+            )
+            storage_map.add_remote(ObjectStorage(key, odb, index=remote.index))
     except NoRemoteError:
         pass

