Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: download files using multithreading TDE-822 #580

Merged
merged 19 commits into from
Aug 20, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions scripts/files/fs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Optional

from linz_logger import get_log

from scripts.aws.aws_helper import is_s3
from scripts.files import fs_local, fs_s3

Expand Down Expand Up @@ -42,3 +48,64 @@ def exists(path: str) -> bool:
if is_s3(path):
return fs_s3.exists(path)
return fs_local.exists(path)


def _read_write_file(target: str, file: str) -> str:
    """Read a file and write it into the specified target directory.

    The written file keeps the source file's base name (last path segment).

    Args:
        target: target directory path
        file: path of the file to read and write (local path or s3 URI)

    Returns:
        path of the written file
    """
    # Use the last "/" segment as the file name so both local paths and s3 URIs work.
    download_path = os.path.join(target, file.split("/")[-1])
    get_log().info("Read-Write File", path=file, target_path=download_path)
    write(download_path, read(file))
    return download_path


def write_all(inputs: List[str], target: str, concurrency: Optional[int] = 10) -> List[str]:
    """Write a list of files to the target destination using multithreading.

    Args:
        inputs: list of file paths to read
        target: target folder to write to
        concurrency: number of worker threads to use (default 10)

    Raises:
        Exception: if any of the source files could not be written

    Returns:
        list of written file paths
    """
    written_tiffs: List[str] = []
    with ThreadPoolExecutor(max_workers=concurrency) as executor:
        # Map each future back to its source path so failures can be traced.
        futures = {executor.submit(_read_write_file, target, file_path): file_path for file_path in inputs}
        for future in as_completed(futures):
            if future.exception():
                get_log().warn("Failed Read-Write", error=future.exception())
            else:
                written_tiffs.append(future.result())

    # Every input must have been written; a partial download would silently
    # corrupt downstream processing, so fail loudly instead.
    if len(inputs) != len(written_tiffs):
        get_log().error("Missing Files", missing_file_count=len(inputs) - len(written_tiffs))
        raise Exception("Not all source files were written")
    return written_tiffs


def find_sidecars(inputs: List[str], extensions: List[str]) -> List[str]:
    """Search for sidecar files that exist alongside the input files.

    A sidecar file is a file with the same base name as the input file but
    with a different extension.

    Args:
        inputs: list of input files to search for sidecars of
        extensions: the sidecar file extensions, including the leading dot (e.g. ".prj")

    Returns:
        list of existing sidecar files
    """
    sidecars: List[str] = []
    for file in inputs:
        # os.path.splitext strips only the final extension, so paths containing
        # a "." in a directory component (e.g. "s3://bucket/v1.2/a.tiff") or
        # multi-dot file names are handled correctly — file.split(".")[0] would
        # truncate at the first dot and produce the wrong base path.
        base = os.path.splitext(file)[0]
        for extension in extensions:
            sidecar = base + extension
            if exists(sidecar):
                sidecars.append(sidecar)
    return sidecars
46 changes: 5 additions & 41 deletions scripts/standardising.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from multiprocessing import Pool
from typing import List, Optional

import ulid
from linz_logger import get_log

from scripts.aws.aws_helper import is_s3
from scripts.cli.cli_helper import TileFiles
from scripts.files.file_tiff import FileTiff, FileTiffType
from scripts.files.fs import exists, read, write
from scripts.files.files_helper import is_tiff
from scripts.files.fs import exists, find_sidecars, read, write, write_all
from scripts.gdal.gdal_bands import get_gdal_band_offset
from scripts.gdal.gdal_helper import get_gdal_version, run_gdal
from scripts.gdal.gdal_preset import (
Expand Down Expand Up @@ -74,44 +74,6 @@ def run_standardising(
return standardized_tiffs


def download_tiffs(files: List[str], target: str) -> List[str]:
    """Download tiff files, and any of their sidecar files that exist, to the target dir.

    Each downloaded file is renamed to a fresh ULID (keeping a ".tiff" suffix)
    so concurrent downloads into the same target cannot collide.

    Args:
        files: list of source tiff file paths
        target: target folder to write to

    Returns:
        list of downloaded tiff file paths

    Example:
        ```
        >>> download_tiffs(["s3://elevation/SN9457_CE16_10k_0502.tif"], "/tmp/")
        ["/tmp/01H60A1KQW6P3GFXEJGB40VQ4B.tiff"]
        ```
    """
    downloaded_files: List[str] = []
    for file in files:
        target_file_path = os.path.join(target, str(ulid.ULID()))
        input_file_path = target_file_path + ".tiff"
        get_log().info("download_tiff", path=file, target_path=input_file_path)

        write(input_file_path, read(file))
        downloaded_files.append(input_file_path)

        base_file_path = os.path.splitext(file)[0]
        # Attempt to download sidecar files too. Missing sidecars are expected,
        # so this is best-effort — but catch Exception (not a bare except) so
        # KeyboardInterrupt/SystemExit still propagate.
        for ext in [".prj", ".tfw"]:
            try:
                write(target_file_path + ext, read(base_file_path + ext))
                get_log().info("download_tiff_sidecar", path=base_file_path + ext, target_path=target_file_path + ext)

            except Exception:  # pylint: disable=broad-except
                pass

    return downloaded_files


def create_vrt(source_tiffs: List[str], target_path: str, add_alpha: bool = False) -> str:
"""Create a VRT from a list of tiffs files

Expand Down Expand Up @@ -167,8 +129,10 @@ def standardising(
# Download any needed file from S3 ["/foo/bar.tiff", "s3://foo"] => "/tmp/bar.tiff", "/tmp/foo.tiff"
with tempfile.TemporaryDirectory() as tmp_path:
standardized_working_path = os.path.join(tmp_path, standardized_file_name)
sidecars = find_sidecars(files.input, [".prj", ".tfw"])
source_files = write_all(files.input + sidecars, tmp_path)
source_tiffs = [file for file in source_files if is_tiff(file)]

source_tiffs = download_tiffs(files.input, tmp_path)
vrt_add_alpha = True

for file in source_tiffs:
Expand Down