Skip to content

Commit

Permalink
zenodo: add allow_overwrite as parameter
Browse files Browse the repository at this point in the history
Currently, when there is a hash mismatch between the file on Zenodo and the one on disk, your only recourse is either to use `force_download=True`, which re-downloads all the files, or to delete the file manually (which can be tricky) and then rerun the function. This commit adds an argument `allow_overwrite` that allows overwriting only those files where there is a hash mismatch.
  • Loading branch information
JoepVanlier committed Jun 6, 2024
1 parent 86d6a6b commit ceb3d93
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 13 deletions.
6 changes: 6 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## v1.6.0 | t.b.d.

#### New features

* Added parameter `allow_overwrite` to [`lk.download_from_doi()`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.download_from_doi.html#lumicks.pylake.download_from_doi) to allow re-downloading only those files where the checksum does not match.

## v1.5.1 | 2024-06-03

* Fixed bug that prevented loading an `h5` file where only a subset of the photon channels are available. This bug was introduced in Pylake `1.4.0`.
Expand Down
24 changes: 15 additions & 9 deletions lumicks/pylake/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ def verify_hash(file_name, algorithm, reference_hash, chunk_size=65536):
return m.hexdigest() == reference_hash


def download_from_doi(doi, target_path="", force_download=False, show_progress=True):
def download_from_doi(
doi, target_path="", force_download=False, show_progress=True, allow_overwrite=False
):
"""Download files from a Zenodo DOI (i.e. 10.5281/zenodo.#######)
Note
Expand All @@ -100,6 +102,9 @@ def download_from_doi(doi, target_path="", force_download=False, show_progress=T
a freshly downloaded copy.
show_progress : bool
Show a progress bar while downloading.
allow_overwrite : bool
Re-download files for which the hash does not match the expected hash from Zenodo. Note that
this will overwrite the existing file with a freshly downloaded copy.
Returns
-------
Expand All @@ -122,22 +127,23 @@ def download_from_doi(doi, target_path="", force_download=False, show_progress=T
file_name, url = file["key"], file["links"]["self"]
full_path = os.path.join(target_path, file_name)

# If the file doesn't exist, we can't skip it
download = not os.path.exists(full_path)
# If the file doesn't exist, or we are forcing a download of all files, we can't skip it
download = not os.path.exists(full_path) or force_download

# If a file with the requested filename exists but does not match the data from Zenodo,
# throw an error.
# Handle the case where a file with the requested filename exists, we are not forcing
# all files to download, but the file we have does not match the checksum from Zenodo.
hash_algorithm, checksum = file["checksum"].split(":")
if not download and not verify_hash(full_path, hash_algorithm, checksum):
if not force_download:
if allow_overwrite:
download = True
else:
raise RuntimeError(
f"File {file_name} does not match file from Zenodo. Set force_download=True "
f"File {file_name} does not match file from Zenodo. Set allow_overwrite=True "
f"if you wish to overwrite the existing file on disk with the version from "
f"Zenodo."
)

# Only download what we don't have yet.
if download or force_download:
if download:
download_file(url, target_path, file_name, show_progress)
if not verify_hash(full_path, hash_algorithm, checksum):
raise RuntimeError("Download failed. Invalid checksum after download.")
Expand Down
9 changes: 5 additions & 4 deletions lumicks/pylake/tests/test_file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def test_download_record_metadata():


@pytest.mark.preflight
def test_download_from_doi(tmpdir_factory, capsys):
@pytest.mark.parametrize("force_arg", [{"force_download": True}, {"allow_overwrite": True}])
def test_download_from_doi(tmpdir_factory, capsys, force_arg):
tmpdir = tmpdir_factory.mktemp("download_testing")
record = download_record_metadata("4247279")

Expand All @@ -48,12 +49,12 @@ def test_download_from_doi(tmpdir_factory, capsys):

with pytest.raises(
RuntimeError,
match="Set force_download=True if you wish to overwrite the existing file on disk with the "
"version from Zenodo",
match="Set allow_overwrite=True if you wish to overwrite the existing file on disk with "
"the version from Zenodo",
):
download_from_doi("10.5281/zenodo.4247279", tmpdir, show_progress=False)

download_from_doi("10.5281/zenodo.4247279", tmpdir, force_download=True, show_progress=False)
download_from_doi("10.5281/zenodo.4247279", tmpdir, **force_arg, show_progress=False)

captured = capsys.readouterr()
assert not captured.out
Expand Down

0 comments on commit ceb3d93

Please sign in to comment.