Merge remote-tracking branch 'aiddata/develop' into cru_ts
jacobwhall committed Feb 8, 2023
2 parents 44d88fe + 87babd8 commit 130d625
Showing 46 changed files with 1,829 additions and 881 deletions.
5 changes: 3 additions & 2 deletions dvnl/flow.py
@@ -2,6 +2,7 @@
from pathlib import Path
from datetime import datetime
from configparser import ConfigParser
from typing import List, Literal

from prefect import flow
from prefect.filesystems import GitHub
@@ -20,7 +21,7 @@


@flow
def dvnl(raw_dir, output_dir, years, overwrite_download, overwrite_processing, backend, task_runner, run_parallel, max_workers, log_dir):
def dvnl(raw_dir: str, output_dir: str, years: List[int], overwrite_download: bool, overwrite_processing: bool, backend: Literal["local", "mpi", "prefect"], task_runner: Literal["sequential", "concurrent", "dask", "hpc"], run_parallel: bool, max_workers: int, log_dir: str):

timestamp = datetime.today()
time_str = timestamp.strftime("%Y_%m_%d_%H_%M")
@@ -54,4 +55,4 @@ def dvnl(raw_dir, output_dir, years, overwrite_download, overwrite_processing, b

class_instance = DVNL(raw_dir, output_dir, years, overwrite_download, overwrite_processing)

class_instance.run(backend=backend, task_runner=task_runner, run_parallel=run_parallel, max_workers=max_workers, log_dir=timestamp_log_dir, cluster_kwargs=cluster_kwargs)
class_instance.run(backend=backend, task_runner=task_runner, run_parallel=run_parallel, max_workers=max_workers, log_dir=timestamp_log_dir, cluster_kwargs=cluster_kwargs)
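
For orientation, here is a minimal sketch of a direct call that satisfies the new type hints added to dvnl/flow.py above; all argument values are placeholders for illustration, not values taken from this commit:

dvnl(
    raw_dir="/path/to/raw",
    output_dir="/path/to/output",
    years=[2012, 2013],
    overwrite_download=False,
    overwrite_processing=False,
    backend="prefect",     # Literal: "local", "mpi", or "prefect"
    task_runner="hpc",     # Literal: "sequential", "concurrent", "dask", or "hpc"
    run_parallel=True,
    max_workers=10,
    log_dir="logs",
)
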
5 changes: 4 additions & 1 deletion env.yml
@@ -28,4 +28,7 @@ dependencies:
- pandas==1.5.1
- pillow==9.2.0
- proj==0.2.0
- rasterio==1.3.3
- rasterio==1.3.3
- boxsdk==3.6.1
- dask==2023.01
- distributed==2023.01
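
The file above is a standard conda environment spec, so recreating the updated environment is a one-liner (the environment name comes from the name field in env.yml, which is outside this hunk):

conda env create -f env.yml
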
122 changes: 71 additions & 51 deletions global_scripts/dataset.py
@@ -11,25 +11,7 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextlib import contextmanager
from tempfile import TemporaryDirectory, mkstemp, _get_default_tempdir


def find_tmp_dir():
"""
Find a temporary directory based on environment variables
This function is called if a Dataset is run without a temporary directory specified
"""
try:
tmp_dir = Path("/local/scr") / os.environ["USER"] / "TMPDIR"
if tmp_dir.exists():
return tmp_dir.as_posix()
except:
pass

try:
return _get_default_tempdir()
except FileNotFoundError:
raise FileNotFoundError("Unable to find a suitable temporary directory. Please specify tmp_dir when calling run")
from tempfile import TemporaryDirectory, mkstemp


"""
@@ -107,10 +89,10 @@ def get_logger(self):


@contextmanager
def tmp_to_dst_file(self, final_dst):
def tmp_to_dst_file(self, final_dst, tmp_dir=None):
logger = self.get_logger()
with TemporaryDirectory(dir=self.tmp_dir) as tmp_dir:
tmp_file = mkstemp(dir=self.tmp_dir)[1]
with TemporaryDirectory(dir=tmp_dir) as tmp_sub_dir:
tmp_file = mkstemp(dir=tmp_sub_dir)[1]
logger.debug(f"Created temporary file {tmp_file} with final destination {str(final_dst)}")
yield tmp_file
try:
@@ -126,14 +108,18 @@ def error_wrapper(self, func, args):
This is the wrapper that is used when running individual tasks
It will always return a TaskResult!
"""
logger = self.get_logger()

for try_no in range(self.retries + 1):
try:
return TaskResult(0, "Success", args, func(*args))
except Exception as e:
if try_no < self.retries:
logger.error(f"Task failed with exception (retrying): {repr(e)}")
time.sleep(self.retry_delay)
continue
else:
logger.error(f"Task failed with exception (giving up): {repr(e)}")
return TaskResult(1, repr(e), args, None)


@@ -142,6 +128,8 @@ def run_serial_tasks(self, name, func, input_list):
Run tasks in serial (locally), given a function and list of inputs
This will always return a list of TaskResults!
"""
logger = self.get_logger()
logger.debug(f"run_serial_tasks - input_list: {input_list}")
return [self.error_wrapper(func, i) for i in input_list]


@@ -163,8 +151,9 @@ def run_prefect_tasks(self, name, func, input_list, force_sequential):
"""

from prefect import task
logger = self.get_logger()

task_wrapper = task(func, name=name, retries=self.retries, retry_delay_seconds=self.retry_delay)
task_wrapper = task(func, name=name, retries=self.retries, retry_delay_seconds=self.retry_delay, persist_result=True)

futures = []
for i in input_list:
@@ -173,40 +162,66 @@ def run_prefect_tasks(self, name, func, input_list, force_sequential):

results = []

for inputs, future in futures:
state = future.wait(timeout=None)
if state.is_completed():
results.append(TaskResult(0, "Success", inputs, state.result()))
elif state.is_failed() or state.is_crashed():
try:
msg = repr(state.result(raise_on_failure=False))
except:
msg = "Unable to retrieve error message"
results.append(TaskResult(1, msg, inputs, None))
else:
pass

states = [(i[0], i[1].wait()) for i in futures]

while states:
for ix, (inputs, state) in enumerate(states):
if state.is_completed():
# print('complete', ix, inputs)
logger.info(f'complete - {ix} - {inputs}')

results.append(TaskResult(0, "Success", inputs, state.result()))
elif state.is_failed() or state.is_crashed() or state.is_cancelled():
# print('fail', ix, inputs)
logger.info(f'fail - {ix} - {inputs}')

try:
msg = repr(state.result(raise_on_failure=True))
except Exception as e:
msg = f"Unable to retrieve error message - {e}"
results.append(TaskResult(1, msg, inputs, None))
else:
# print('not ready', ix, inputs)
continue
_ = states.pop(ix)
time.sleep(5)


# for inputs, future in futures:
# state = future.wait(60*60*2)
# if state.is_completed():
# results.append(TaskResult(0, "Success", inputs, state.result()))
# elif state.is_failed() or state.is_crashed():
# try:
# msg = repr(state.result(raise_on_failure=False))
# except:
# msg = "Unable to retrieve error message"
# results.append(TaskResult(1, msg, inputs, None))
# else:
# pass

# while futures:
# for ix, (inputs, future) in enumerate(futures):
# state = future.get_state()
# print(repr(state))
# print(repr(future))
# # print(repr(state))
# # print(repr(future))
# if state.is_completed():
# print('complete', ix, inputs)
# results.append(TaskResult(0, "Success", inputs, future.result()))
# elif state.is_failed() or state.is_crashed() or state.is_cancelled():
# print('fail', ix, inputs)
# try:
# msg = repr(future.result(raise_on_failure=True))
# except:
# msg = "Unable to retrieve error message"
# except Exception as e:
# msg = f"Unable to retrieve error message - {e}"
# results.append(TaskResult(1, msg, inputs, None))
# else:
# print('not ready', ix, inputs)
# # print('not ready', ix, inputs)
# continue
# _ = futures.pop(ix)
# future.release()
# time.sleep(15)
# # future.release()
# time.sleep(5)

return results

@@ -232,7 +247,8 @@ def run_tasks(self,
name: Optional[str]=None,
retries: Optional[int]=3,
retry_delay: Optional[int]=60,
force_sequential: bool=False):
force_sequential: bool=False,
force_serial: bool=False):
"""
Run a bunch of tasks, calling one of the above run_tasks functions
This is the function that should be called most often from self.main()
@@ -259,7 +275,7 @@ def run_tasks(self,
elif not isinstance(name, str):
raise TypeError("Name of task run must be a string")

if self.backend == "serial":
if self.backend == "serial" or force_serial:
results = self.run_serial_tasks(name, func, input_list)
elif self.backend == "concurrent":
results = self.run_concurrent_tasks(name, func, input_list, force_sequential)
@@ -389,10 +405,10 @@ def run(
task_runner: Optional[str]=None,
run_parallel: bool=False,
max_workers: Optional[int]=None,
threads_per_worker: Optional[int]=1,
# cores_per_process: Optional[int]=None,
chunksize: int=1,
log_dir: str="logs",
tmp_dir: Optional[str]=find_tmp_dir(),
logger_level=logging.INFO,
retries: int=3,
retry_delay: int=5,
@@ -406,7 +422,6 @@ def run(
self.init_retries(retries, retry_delay, save_settings=True)

self.log_dir = Path(log_dir)
self.tmp_dir = Path(tmp_dir).as_posix()

self.chunksize = chunksize
os.makedirs(self.log_dir, exist_ok=True)
@@ -431,11 +446,16 @@ def run(
tr = ConcurrentTaskRunner
elif task_runner == "dask":
from prefect_dask import DaskTaskRunner
if "cluster" in kwargs:
del kwargs["cluster"]
if "cluster_kwargs" in kwargs:
del kwargs["cluster_kwargs"]
tr = DaskTaskRunner(**kwargs)
# if "cluster" in kwargs:
# del kwargs["cluster"]
# if "cluster_kwargs" in kwargs:
# del kwargs["cluster_kwargs"]

dask_cluster_kwargs = {
"n_workers": max_workers,
"threads_per_worker": threads_per_worker
}
tr = DaskTaskRunner(cluster_kwargs=dask_cluster_kwargs)
elif task_runner == "hpc":
from hpc import HPCDaskTaskRunner
job_name = "".join(self.name.split())
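
A minimal usage sketch of the updated run() signature, assuming a dataset subclass such as the DVNL class from dvnl/flow.py; paths and worker counts below are placeholders:

class_instance = DVNL("/path/to/raw", "/path/to/output", [2020], False, False)
class_instance.run(
    backend="prefect",
    task_runner="dask",
    run_parallel=True,
    max_workers=4,
    threads_per_worker=1,  # new in this commit, forwarded to DaskTaskRunner cluster_kwargs
    log_dir="logs",
)

Note that run() no longer takes a tmp_dir argument: this commit removes find_tmp_dir() and self.tmp_dir, and instead lets tmp_to_dst_file() accept an optional tmp_dir per call.
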
1 change: 1 addition & 0 deletions gpw/.gitignore
@@ -0,0 +1 @@
sedac_cookie
43 changes: 43 additions & 0 deletions gpw/config.ini
@@ -0,0 +1,43 @@
[main]

# Name of dataset
name = GPWv4

# Paths of input and output directories
raw_dir = /sciclone/aiddata10/REU/pre_geo/raw/gpw/gpw_v4_rev11
output_dir = /sciclone/aiddata10/REU/pre_geo/data/rasters/gpw/gpw_v4_rev11

# Years to process, must be separated by ", "
years = 2000, 2005, 2010, 2015, 2020

sedac_cookie = None

# Overwrite existing files?
overwrite_download = False
overwrite_extract = False
overwrite_processing = False


[run]

backend = prefect
task_runner = hpc
run_parallel = True
max_workers = 10


[github]

repo = https://github.com/aiddata/geo-datasets.git
branch = develop
directory = gpw


[deploy]

deployment_name = GPWv4
version = 1
flow_file_name = flow
flow_name = gpwv4
storage_block = geo-datasets-github
work_queue = geodata
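
For reference, a hedged sketch of how the [main] and [run] values above might be parsed before being handed to the flow; the actual parsing lives in code not shown in this diff, so treat the specifics as assumptions:

from configparser import ConfigParser

config = ConfigParser()
config.read("gpw/config.ini")

# "2000, 2005, 2010, 2015, 2020" -> [2000, 2005, 2010, 2015, 2020]
years = [int(y.strip()) for y in config["main"]["years"].split(",")]
overwrite_download = config["main"].getboolean("overwrite_download")  # False
max_workers = config["run"].getint("max_workers")                     # 10
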
60 changes: 60 additions & 0 deletions gpw/flow.py
@@ -0,0 +1,60 @@
import os
import sys
from pathlib import Path
from datetime import datetime
from configparser import ConfigParser

from prefect import flow
from prefect.filesystems import GitHub


config_file = "gpw/config.ini"
config = ConfigParser()
config.read(config_file)

block_name = config["deploy"]["storage_block"]
GitHub.load(block_name).get_directory("global_scripts")

from main import GPWv4

tmp_dir = Path(os.getcwd()) / config["github"]["directory"]


@flow
def gpwv4(raw_dir, output_dir, years, sedac_cookie, overwrite_download, overwrite_extract, overwrite_processing, backend, task_runner, run_parallel, max_workers, log_dir):

timestamp = datetime.today()
time_str = timestamp.strftime("%Y_%m_%d_%H_%M")
timestamp_log_dir = Path(log_dir) / time_str
timestamp_log_dir.mkdir(parents=True, exist_ok=True)

cluster = "vortex"

cluster_kwargs = {
"shebang": "#!/bin/tcsh",
"resource_spec": "nodes=1:c18a:ppn=12",
"walltime": "2:00:00",
"cores": 5,
"processes": 5,
"memory": "30GB",
"interface": "ib0",
"job_extra_directives": [
"-j oe",
],
"job_script_prologue": [
"source /usr/local/anaconda3-2021.05/etc/profile.d/conda.csh",
"module load anaconda3/2021.05",
"conda activate geodata38",
f"cd {tmp_dir}",
],
"log_directory": str(timestamp_log_dir),
}


class_instance = GPWv4(raw_dir, output_dir, years, sedac_cookie, overwrite_download, overwrite_extract, overwrite_processing)

if task_runner != 'hpc':
os.chdir(tmp_dir)
class_instance.run(backend=backend, task_runner=task_runner, run_parallel=run_parallel, max_workers=max_workers, log_dir=timestamp_log_dir)
else:
class_instance.run(backend=backend, task_runner=task_runner, run_parallel=run_parallel, max_workers=max_workers, log_dir=timestamp_log_dir, cluster=cluster, cluster_kwargs=cluster_kwargs)
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.