Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: download files using multithreading TDE-822 #580

Merged
merged 19 commits into from
Aug 20, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions scripts/files/fs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List

import ulid
from linz_logger import get_log

from scripts.aws.aws_helper import is_s3
from scripts.files import fs_local, fs_s3

Expand Down Expand Up @@ -42,3 +49,55 @@ def exists(path: str) -> bool:
if is_s3(path):
return fs_s3.exists(path)
return fs_local.exists(path)


def _download_tiff_and_sidecar(target: str, file: str) -> str:
    """
    Download a tiff file and some of its sidecar files if they exist to the target dir.

    The tiff is written to `target` under a freshly generated ULID name with a `.tiff`
    extension. Sidecars (`.prj`, `.tfw`) that share the source file's base name are copied
    next to the downloaded tiff, using the same ULID base name, on a best-effort basis.

    Args:
        target (str): target folder to write to
        file (str): source file

    Returns:
        downloaded file path
    """
    # Tiff and sidecars must share one base name in the target folder so they stay associated.
    target_base_path = os.path.join(target, str(ulid.ULID()))
    download_path = f"{target_base_path}.tiff"
    get_log().info("Download File", path=file, target_path=download_path)
    write(download_path, read(file))

    # splitext only strips the final extension, so dots elsewhere in the path
    # (bucket names, "./relative" paths, dotted file names) are preserved.
    source_base_path = os.path.splitext(file)[0]
    for ext in [".prj", ".tfw"]:
        sidecar_source = f"{source_base_path}{ext}"
        sidecar_target = f"{target_base_path}{ext}"
        try:
            write(sidecar_target, read(sidecar_source))
            get_log().info("Download tiff sidecars", path=sidecar_source, target_path=sidecar_target)
        except Exception:  # pylint: disable=broad-except
            # Sidecars are optional: a missing or unreadable sidecar must not fail the tiff download.
            pass
    return download_path


def download_tiffs_parallel_multithreaded(inputs: List[str], target: str, concurrency: int = 10) -> List[str]:
    """
    Download list of tiffs to target destination using multithreading.

    Args:
        inputs (list): list of tiffs to download
        target (str): target folder to write to
        concurrency (int): maximum number of download threads. Defaults to 10.

    Returns:
        list of downloaded file paths, in the same order as `inputs`

    Raises:
        Exception: if at least one of the source files could not be downloaded
    """
    downloaded_tiffs: List[str] = []
    with ThreadPoolExecutor(max_workers=concurrency) as executor:
        # Map each future to its source path so a failure can be reported against the right file.
        futures = {executor.submit(_download_tiff_and_sidecar, target, source): source for source in inputs}

        # Report failures as soon as they happen rather than waiting for every download to finish.
        for future in as_completed(futures):
            error = future.exception()
            if error is not None:
                get_log().warn("Failed Download", path=futures[future], error=error)

        # dicts preserve insertion order, so iterating `futures` yields results in submission
        # (i.e. input) order instead of the non-deterministic completion order.
        for future in futures:
            if future.exception() is None:
                downloaded_tiffs.append(future.result())

    if len(inputs) != len(downloaded_tiffs):
        get_log().error("Missing Files", missing_file_count=len(inputs) - len(downloaded_tiffs))
        raise Exception("Not all source files were downloaded")
    return downloaded_tiffs
43 changes: 2 additions & 41 deletions scripts/standardising.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
from multiprocessing import Pool
from typing import List, Optional

import ulid
from linz_logger import get_log

from scripts.aws.aws_helper import is_s3
from scripts.cli.cli_helper import TileFiles
from scripts.files.file_tiff import FileTiff, FileTiffType
from scripts.files.fs import exists, read, write
from scripts.files.fs import download_tiffs_parallel_multithreaded, exists, read, write
from scripts.gdal.gdal_bands import get_gdal_band_offset
from scripts.gdal.gdal_helper import get_gdal_version, run_gdal
from scripts.gdal.gdal_preset import (
Expand Down Expand Up @@ -74,44 +73,6 @@ def run_standardising(
return standardized_tiffs


def download_tiffs(files: List[str], target: str) -> List[str]:
    """Download tiff files and some of their sidecar files if they exist to the target dir.

    Each tiff is written to `target` under a freshly generated ULID name with a `.tiff`
    extension. Sidecars (`.prj`, `.tfw`) sharing the source file's base name are copied
    next to it, using the same ULID base name, on a best-effort basis.

    Args:
        files: source tiff paths to download
        target: target folder to write to

    Returns:
        paths of the downloaded tiffs, in the same order as `files`
    """
    downloaded_files: List[str] = []
    for file in files:
        target_file_path = os.path.join(target, str(ulid.ULID()))
        input_file_path = target_file_path + ".tiff"
        get_log().info("download_tiff", path=file, target_path=input_file_path)

        write(input_file_path, read(file))
        downloaded_files.append(input_file_path)

        base_file_path = os.path.splitext(file)[0]
        # Attempt to download sidecar files too
        for ext in [".prj", ".tfw"]:
            try:
                write(target_file_path + ext, read(base_file_path + ext))
                get_log().info("download_tiff_sidecar", path=base_file_path + ext, target_path=target_file_path + ext)
            except Exception:  # pylint: disable=broad-except
                # Sidecars are optional: a missing or unreadable sidecar must not fail the tiff download.
                pass

    return downloaded_files


def create_vrt(source_tiffs: List[str], target_path: str, add_alpha: bool = False) -> str:
"""Create a VRT from a list of tiffs files

Expand Down Expand Up @@ -167,8 +128,8 @@ def standardising(
# Download any needed file from S3 ["/foo/bar.tiff", "s3://foo"] => "/tmp/bar.tiff", "/tmp/foo.tiff"
with tempfile.TemporaryDirectory() as tmp_path:
standardized_working_path = os.path.join(tmp_path, standardized_file_name)
source_tiffs = download_tiffs_parallel_multithreaded(files.input, tmp_path)

source_tiffs = download_tiffs(files.input, tmp_path)
vrt_add_alpha = True

for file in source_tiffs:
Expand Down