diff --git a/environment.yml b/environment.yml
index b31578c785..c5ba1ff2c8 100644
--- a/environment.yml
+++ b/environment.yml
@@ -10,6 +10,7 @@ dependencies:
   - cftime
   - compilers
   - dask
+  - distributed
   - esgf-pyclient>=0.3.1
   - esmpy!=8.1.0
   - filelock
@@ -18,7 +19,7 @@ dependencies:
   - geopy
   - humanfriendly
   - importlib_resources
-  - iris>=3.4.0
+  - iris>=3.6.0
   - iris-esmf-regrid >=0.6.0  # to work with latest esmpy
   - isodate
   - jinja2
diff --git a/esmvalcore/_provenance.py b/esmvalcore/_provenance.py
index dfe70e7220..9adb57518f 100644
--- a/esmvalcore/_provenance.py
+++ b/esmvalcore/_provenance.py
@@ -3,6 +3,7 @@
 import logging
 import os
 from functools import total_ordering
+from pathlib import Path
 
 from netCDF4 import Dataset
 from PIL import Image
@@ -104,7 +105,7 @@ class TrackedFile:
     """File with provenance tracking."""
 
     def __init__(self,
-                 filename,
+                 filename: Path,
                  attributes=None,
                  ancestors=None,
                  prov_filename=None):
@@ -112,7 +113,7 @@ def __init__(self,
 
         Arguments
         ---------
-        filename: str
+        filename: Path
            Path to the file on disk.
         attributes: dict
            Dictionary with facets describing the file. If set to None, this
diff --git a/esmvalcore/_recipe/recipe.py b/esmvalcore/_recipe/recipe.py
index 4644ef524b..2c0c0f19fd 100644
--- a/esmvalcore/_recipe/recipe.py
+++ b/esmvalcore/_recipe/recipe.py
@@ -227,7 +227,10 @@ def _get_default_settings(dataset):
     settings['remove_supplementary_variables'] = {}
 
     # Configure saving cubes to file
-    settings['save'] = {'compress': session['compress_netcdf']}
+    settings['save'] = {
+        'compress': session['compress_netcdf'],
+        'compute': session['max_parallel_tasks'] != 0,
+    }
     if facets['short_name'] != facets['original_short_name']:
         settings['save']['alias'] = facets['short_name']
 
@@ -537,6 +540,9 @@ def _get_downstream_settings(step, order, products):
         if key in remaining_steps:
             if all(p.settings.get(key, object()) == value for p in products):
                 settings[key] = value
+    save = dict(some_product.settings.get('save', {}))
+    save.pop('filename', None)
+    settings['save'] = save
 
     return settings
 
@@ -1305,7 +1311,7 @@ def run(self):
         if self.session['search_esgf'] != 'never':
             esgf.download(self._download_files, self.session['download_dir'])
 
-        self.tasks.run(max_parallel_tasks=self.session['max_parallel_tasks'])
+        self.tasks.run(self.session)
         logger.info(
             "Wrote recipe with version numbers and wildcards "
             "to:\nfile://%s", filled_recipe)
diff --git a/esmvalcore/_task.py b/esmvalcore/_task.py
index c05491eda0..630783424e 100644
--- a/esmvalcore/_task.py
+++ b/esmvalcore/_task.py
@@ -1,7 +1,10 @@
 """ESMValtool task definition."""
+from __future__ import annotations
+
 import abc
 import contextlib
 import datetime
+import importlib
 import logging
 import numbers
 import os
@@ -15,15 +18,21 @@
 from multiprocessing import Pool
 from pathlib import Path, PosixPath
 from shutil import which
-from typing import Optional
+from typing import TYPE_CHECKING
 
+import dask
+import dask.distributed
 import psutil
 import yaml
 
 from ._citation import _write_citation_files
 from ._provenance import TrackedFile, get_task_provenance
+from .config import Session
 from .config._diagnostics import DIAGNOSTICS, TAGS
 
+if TYPE_CHECKING:
+    from esmvalcore.preprocessor import PreprocessingTask
+
 
 def path_representer(dumper, data):
     """For printing pathlib.Path objects in yaml files."""
@@ -191,7 +200,9 @@ def _ncl_type(value):
     lines = []
 
     # ignore some settings for NCL diagnostic
-    ignore_settings = ['profile_diagnostic', ]
+    ignore_settings = [
+        'profile_diagnostic',
+    ]
     for sett in ignore_settings:
         settings_copy = dict(settings)
         if 'diag_script_info' not in settings_copy:
@@ -414,7 +425,9 @@ def write_settings(self):
         run_dir.mkdir(parents=True, exist_ok=True)
 
         # ignore some settings for diagnostic
-        ignore_settings = ['profile_diagnostic', ]
+        ignore_settings = [
+            'profile_diagnostic',
+        ]
         for sett in ignore_settings:
             settings_copy = dict(self.settings)
             settings_copy.pop(sett, None)
@@ -694,6 +707,54 @@ def __repr__(self):
         return string
 
 
+@contextlib.contextmanager
+def get_distributed_client(session):
+    """Get a Dask distributed client."""
+    dask_args = session.get('dask', {})
+    client_args = dask_args.get('client', {}).copy()
+    cluster_args = dask_args.get('cluster', {}).copy()
+
+    # Start a cluster, if requested
+    if 'address' in client_args:
+        # Use an externally managed cluster.
+        cluster = None
+        if cluster_args:
+            logger.warning(
+                "Not using 'dask: cluster' settings because a cluster "
+                "'address' is already provided in 'dask: client'.")
+    elif cluster_args:
+        # Start cluster.
+        cluster_type = cluster_args.pop(
+            'type',
+            'dask.distributed.LocalCluster',
+        )
+        cluster_module_name, cluster_cls_name = cluster_type.rsplit('.', 1)
+        cluster_module = importlib.import_module(cluster_module_name)
+        cluster_cls = getattr(cluster_module, cluster_cls_name)
+        cluster = cluster_cls(**cluster_args)
+        client_args['address'] = cluster.scheduler_address
+    else:
+        # No cluster configured, use Dask default scheduler, or a LocalCluster
+        # managed through Client.
+        cluster = None
+
+    # Start a client, if requested
+    if dask_args:
+        client = dask.distributed.Client(**client_args)
+        logger.info(f"Dask dashboard: {client.dashboard_link}")
+    else:
+        logger.info("Using the Dask default scheduler.")
+        client = None
+
+    try:
+        yield client
+    finally:
+        if client is not None:
+            client.close()
+        if cluster is not None:
+            cluster.close()
+
+
 class TaskSet(set):
     """Container for tasks."""
 
@@ -710,18 +771,101 @@ def get_independent(self) -> 'TaskSet':
                 independent_tasks.add(task)
         return independent_tasks
 
-    def run(self, max_parallel_tasks: Optional[int] = None) -> None:
+    def run(self, session: Session) -> None:
         """Run tasks.
 
         Parameters
         ----------
-        max_parallel_tasks : int
-            Number of processes to run. If `1`, run the tasks sequentially.
+        session : esmvalcore.config.Session
+            Session.
         """
-        if max_parallel_tasks == 1:
-            self._run_sequential()
-        else:
-            self._run_parallel(max_parallel_tasks)
+        with get_distributed_client(session) as client:
+            if client is None:
+                scheduler_address = None
+            else:
+                scheduler_address = client.scheduler.address
+            for task in self.flatten():
+                if (isinstance(task, DiagnosticTask)
+                        and Path(task.script).suffix.lower() == '.py'):
+                    # Only use the scheduler address if running a
+                    # Python script.
+                    task.settings['scheduler_address'] = scheduler_address
+
+            max_parallel_tasks = session['max_parallel_tasks']
+            if max_parallel_tasks == 0:
+                if client is None:
+                    raise ValueError(
+                        "Unable to run tasks using Dask distributed without a "
+                        "configured dask client. Please edit config-user.yml "
+                        "to configure dask.")
+                self._run_distributed(client)
+            elif max_parallel_tasks == 1:
+                self._run_sequential()
+            else:
+                self._run_parallel(scheduler_address, max_parallel_tasks)
+
+    def _run_distributed(self, client: dask.distributed.Client) -> None:
+        """Run tasks using Dask Distributed."""
+        client.forward_logging()
+        tasks = sorted((t for t in self.flatten()), key=lambda t: t.priority)
+
+        # Create a graph for dask.array operations in PreprocessingTasks
+        preprocessing_tasks = [t for t in tasks if hasattr(t, 'delayeds')]
+
+        futures_to_preproc_tasks: dict[dask.distributed.Future,
+                                       PreprocessingTask] = {}
+        for task in preprocessing_tasks:
+            future = client.submit(_run_preprocessing_task,
+                                   task,
+                                   priority=-task.priority)
+            futures_to_preproc_tasks[future] = task
+
+        for future in dask.distributed.as_completed(futures_to_preproc_tasks):
+            task = futures_to_preproc_tasks[future]
+            _copy_preprocessing_results(task, future)
+
+        # Launch dask.array compute operations for PreprocessingTasks
+        futures_to_files: dict[dask.distributed.Future, Path] = {}
+        for task in preprocessing_tasks:
+            logger.info(f"Computing task {task.name}")
+            futures = client.compute(
+                list(task.delayeds.values()),
+                priority=-task.priority,
+            )
+            futures_to_files.update(zip(futures, task.delayeds))
+
+        # Start computing DiagnosticTasks as soon as the relevant
+        # PreprocessingTasks complete
+        waiting = [t for t in tasks if t not in preprocessing_tasks]
+        futures_to_tasks: dict[dask.distributed.Future, BaseTask] = {}
+        done_files = set()
+        done_tasks = set()
+        iterator = dask.distributed.as_completed(futures_to_files)
+        for future in iterator:
+            if future in futures_to_files:
+                filename = futures_to_files[future]
+                logger.info(f"Wrote (delayed) {filename}")
+                done_files.add(filename)
+                # Check if a PreprocessingTask has finished
+                for preproc_task in preprocessing_tasks:
+                    filenames = set(preproc_task.delayeds)
+                    if filenames.issubset(done_files):
+                        done_tasks.add(preproc_task)
+            elif future in futures_to_tasks:
+                # Check if a ResumeTask or DiagnosticTask has finished
+                task = futures_to_tasks[future]
+                _copy_distributed_results(task, future)
+                done_tasks.add(task)
+
+            # Schedule any new tasks that can be scheduled. Iterate over a
+            # copy of `waiting` so removing a task does not skip the next one.
+            for task in list(waiting):
+                if set(task.ancestors).issubset(done_tasks):
+                    future = client.submit(_run_task,
+                                           task,
+                                           priority=-task.priority)
+                    iterator.add(future)
+                    futures_to_tasks[future] = task
+                    waiting.remove(task)
 
     def _run_sequential(self) -> None:
         """Run tasks sequentially."""
@@ -732,7 +876,7 @@ def _run_sequential(self) -> None:
         for task in sorted(tasks, key=lambda t: t.priority):
             task.run()
 
-    def _run_parallel(self, max_parallel_tasks=None):
+    def _run_parallel(self, scheduler_address, max_parallel_tasks=None):
         """Run tasks in parallel."""
         scheduled = self.flatten()
         running = {}
@@ -757,14 +901,15 @@ def done(task):
                 if len(running) >= max_parallel_tasks:
                     break
                 if all(done(t) for t in task.ancestors):
-                    future = pool.apply_async(_run_task, [task])
+                    future = pool.apply_async(_run_task,
+                                              [task, scheduler_address])
                     running[task] = future
                     scheduled.remove(task)
 
             # Handle completed tasks
             ready = {t for t in running if running[t].ready()}
             for task in ready:
-                _copy_results(task, running[task])
+                _copy_multiprocessing_results(task, running[task])
                 running.pop(task)
 
             # Wait if there are still tasks running
@@ -785,12 +930,31 @@ def done(task):
         pool.join()
 
 
-def _copy_results(task, future):
+def _run_task(task, scheduler_address=None):
+    """Run task and return the result."""
result.""" + if scheduler_address is None: + output_files = task.run() + else: + with dask.distributed.Client(scheduler_address): + output_files = task.run() + return output_files, task.products + + +def _copy_distributed_results(task, future): + """Update task with the results from the dask worker.""" + task.output_files, task.products = future.result() + + +def _copy_multiprocessing_results(task, future): """Update task with the results from the remote process.""" task.output_files, task.products = future.get() -def _run_task(task): - """Run task and return the result.""" +def _run_preprocessing_task(task): output_files = task.run() - return output_files, task.products + return output_files, task.products, task.delayeds + + +def _copy_preprocessing_results(task, future): + """Update task with the results from the dask worker.""" + task.output_files, task.products, task.delayeds = future.result() diff --git a/esmvalcore/config/_config_validators.py b/esmvalcore/config/_config_validators.py index 736a6ba689..29b0213996 100644 --- a/esmvalcore/config/_config_validators.py +++ b/esmvalcore/config/_config_validators.py @@ -282,6 +282,7 @@ def validate_diagnostics( 'auxiliary_data_dir': validate_path, 'compress_netcdf': validate_bool, 'config_developer_file': validate_config_developer, + 'dask': validate_dict, 'download_dir': validate_path, 'drs': validate_drs, 'exit_on_warning': validate_bool, diff --git a/esmvalcore/preprocessor/__init__.py b/esmvalcore/preprocessor/__init__.py index 90b743b0b0..83fec008d1 100644 --- a/esmvalcore/preprocessor/__init__.py +++ b/esmvalcore/preprocessor/__init__.py @@ -8,6 +8,7 @@ from pprint import pformat from typing import Any, Iterable +from dask.delayed import Delayed from iris.cube import Cube from .._provenance import TrackedFile @@ -489,23 +490,29 @@ def cubes(self, value): def save(self): """Save cubes to disk.""" - preprocess(self._cubes, - 'save', - input_files=self._input_files, - **self.settings['save']) + result = save( + self._cubes, + **self.settings['save'], + ) + self.files = [self.settings['save']['filename']] if 'cleanup' in self.settings: - preprocess([], - 'cleanup', - input_files=self._input_files, - **self.settings['cleanup']) + self.files = preprocess( + self.files, + 'cleanup', + input_files=self._input_files, + **self.settings.get('cleanup', {}), + ) + return result def close(self): """Close the file.""" + result = None if self._cubes is not None: self._update_attributes() - self.save() + result = self.save() self._cubes = None self.save_provenance() + return result def _update_attributes(self): """Update product attributes from cube metadata.""" @@ -603,6 +610,7 @@ def __init__( self.order = list(order) self.debug = debug self.write_ncl_interface = write_ncl_interface + self.delayeds: dict[Path, Delayed] = {} def _initialize_product_provenance(self): """Initialize product provenance.""" @@ -647,6 +655,7 @@ def _initialize_products(self, products): def _run(self, _): """Run the preprocessor.""" + self.delayeds.clear() self._initialize_product_provenance() steps = { @@ -670,13 +679,17 @@ def _run(self, _): product.apply(step, self.debug) if block == blocks[-1]: product.cubes # pylint: disable=pointless-statement - product.close() + result = product.close() + if isinstance(result, Delayed): + self.delayeds[product.filename] = result saved.add(product.filename) for product in self.products: if product.filename not in saved: product.cubes # pylint: disable=pointless-statement - product.close() + result = product.close() + if 
+                if isinstance(result, Delayed):
+                    self.delayeds[product.filename] = result
 
         metadata_files = write_metadata(self.products,
                                         self.write_ncl_interface)
diff --git a/esmvalcore/preprocessor/_io.py b/esmvalcore/preprocessor/_io.py
index 564ec89fe3..9b45208577 100644
--- a/esmvalcore/preprocessor/_io.py
+++ b/esmvalcore/preprocessor/_io.py
@@ -249,6 +249,7 @@ def save(cubes,
          filename,
          optimize_access='',
          compress=False,
+         compute=True,
          alias='',
          **kwargs):
     """Save iris cubes to file.
@@ -273,13 +274,17 @@ def save(cubes,
     compress: bool, optional
         Use NetCDF internal compression.
 
+    compute: bool, optional
+        If true, save immediately; otherwise return a dask.delayed.Delayed
+        object that can be used for saving the data later.
+
     alias: str, optional
         Var name to use when saving instead of the one in the cube.
 
     Returns
     -------
-    str
-        filename
+    str or dask.delayed.Delayed
+        filename or delayed
 
     Raises
     ------
@@ -288,6 +293,8 @@ def save(cubes,
     """
     if not cubes:
         raise ValueError(f"Cannot save empty cubes '{cubes}'")
+    if len(cubes) > 1:
+        raise ValueError(f"`save` expects a single cube, got '{cubes}'")
 
     # Rename some arguments
     kwargs['target'] = filename
@@ -330,9 +337,19 @@ def save(cubes,
         logger.debug('Changing var_name from %s to %s', cube.var_name,
                      alias)
         cube.var_name = alias
 
-    iris.save(cubes, **kwargs)
-    return filename
+    cube = cubes[0]
+    if not compute and not cube.has_lazy_data():
+        # What should happen if the data is not lazy and we're asked for a
+        # lazy save?
+        # https://github.com/SciTools/iris/pull/5031#issuecomment-1322166230
+        compute = True
+
+    result = iris.save(cube, compute=compute, **kwargs)
+    if compute:
+        logger.info("Wrote (immediate) %s", filename)
+        return filename
+    return result
 
 
 def _get_debug_filename(filename, step):
diff --git a/setup.py b/setup.py
index 113bda3ebe..d10f045674 100755
--- a/setup.py
+++ b/setup.py
@@ -28,9 +28,8 @@
     # Use with pip install . to install from source
     'install': [
         'cartopy',
-        # see https://github.com/SciTools/cf-units/issues/218
         'cf-units',
-        'dask[array]',
+        'dask[array,distributed]',
         'esgf-pyclient>=0.3.1',
         'esmf-regrid',
         'esmpy!=8.1.0',
@@ -56,8 +55,8 @@
         'pyyaml',
         'requests',
         'scipy>=1.6',
-        'scitools-iris>=3.4.0',
-        'shapely[vectorized]',
+        'scitools-iris>=3.6.0',
+        'shapely',
         'stratify',
         'yamale',
     ],
diff --git a/tests/integration/recipe/test_recipe.py b/tests/integration/recipe/test_recipe.py
index b63a4bd7f7..c8054bfa1c 100644
--- a/tests/integration/recipe/test_recipe.py
+++ b/tests/integration/recipe/test_recipe.py
@@ -112,6 +112,7 @@ def _get_default_settings_for_chl(save_filename):
         'remove_supplementary_variables': {},
         'save': {
             'compress': False,
+            'compute': True,
             'filename': save_filename,
         }
     }
@@ -563,6 +564,7 @@ def test_default_fx_preprocessor(tmp_path, patched_datafinder, session):
         'remove_supplementary_variables': {},
         'save': {
             'compress': False,
+            'compute': True,
             'filename': product.filename,
         }
     }
@@ -3135,6 +3137,7 @@ def test_recipe_run(tmp_path, patched_datafinder, session, mocker):
 
     recipe = get_recipe(tmp_path, content, session)
 
+    os.makedirs(session['output_dir'])
     recipe.tasks.run = mocker.Mock()
     recipe.write_filled_recipe = mocker.Mock()
     recipe.write_html_summary = mocker.Mock()
@@ -3142,8 +3145,8 @@ def test_recipe_run(tmp_path, patched_datafinder, session, mocker):
 
     esmvalcore._recipe.recipe.esgf.download.assert_called_once_with(
         set(), session['download_dir'])
-    recipe.tasks.run.assert_called_once_with(
-        max_parallel_tasks=session['max_parallel_tasks'])
+    session['write_ncl_interface'] = False
+    recipe.tasks.run.assert_called_once_with(session)
     recipe.write_filled_recipe.assert_called_once()
     recipe.write_html_summary.assert_called_once()
diff --git a/tests/integration/test_task.py b/tests/integration/test_task.py
index 42b724e1c9..1d6cd61214 100644
--- a/tests/integration/test_task.py
+++ b/tests/integration/test_task.py
@@ -73,7 +73,10 @@ def test_run_tasks(monkeypatch, tmp_path, max_parallel_tasks, example_tasks,
     """Check that tasks are run correctly."""
     monkeypatch.setattr(esmvalcore._task, 'Pool',
                         multiprocessing.get_context(mpmethod).Pool)
-    example_tasks.run(max_parallel_tasks=max_parallel_tasks)
+    cfg = {
+        'max_parallel_tasks': max_parallel_tasks,
+    }
+    example_tasks.run(cfg)
 
     for task in example_tasks:
         print(task.name, task.output_files)
@@ -82,7 +85,11 @@ def test_run_tasks(monkeypatch, tmp_path, max_parallel_tasks, example_tasks,
 
 @pytest.mark.parametrize('runner', [
     TaskSet._run_sequential,
-    partial(TaskSet._run_parallel, max_parallel_tasks=1),
+    partial(
+        TaskSet._run_parallel,
+        scheduler_address=None,
+        max_parallel_tasks=1,
+    ),
 ])
 def test_runner_uses_priority(monkeypatch, runner, example_tasks):
     """Check that the runner tries to respect task priority."""
diff --git a/tests/unit/preprocessor/test_preprocessor_file.py b/tests/unit/preprocessor/test_preprocessor_file.py
index 3ebb3385d6..f6c9362424 100644
--- a/tests/unit/preprocessor/test_preprocessor_file.py
+++ b/tests/unit/preprocessor/test_preprocessor_file.py
@@ -6,6 +6,7 @@
 import pytest
 from iris.cube import Cube, CubeList
 
+import esmvalcore.preprocessor
 from esmvalcore.preprocessor import PreprocessorFile
 
 ATTRIBUTES = {
@@ -147,36 +148,41 @@ def test_close():
     assert product._cubes is None
 
 
-@mock.patch('esmvalcore.preprocessor.preprocess', autospec=True)
-def test_save_no_cleanup(mock_preprocess):
+def test_save_no_cleanup(mocker):
     """Test ``save``."""
-    product = mock.create_autospec(PreprocessorFile, instance=True)
-    product.settings = {'save': {}}
-    product._cubes = mock.sentinel.cubes
-    product._input_files = mock.sentinel.input_files
+    mocker.patch.object(esmvalcore.preprocessor, 'preprocess', autospec=True)
+    mocker.patch.object(esmvalcore.preprocessor, 'save', autospec=True)
+    product = mocker.create_autospec(PreprocessorFile, instance=True)
+    product.settings = {'save': {'filename': Path('file1.nc')}}
+    product._cubes = [mocker.sentinel.cube]
+    product._input_files = mocker.sentinel.input_files
 
     PreprocessorFile.save(product)
 
-    assert mock_preprocess.mock_calls == [
-        mock.call(
-            mock.sentinel.cubes, 'save', input_files=mock.sentinel.input_files
-        ),
-    ]
+    esmvalcore.preprocessor.save.assert_called_with(
+        [mocker.sentinel.cube],
+        filename=Path('file1.nc'),
+    )
+    esmvalcore.preprocessor.preprocess.assert_not_called()
 
 
-@mock.patch('esmvalcore.preprocessor.preprocess', autospec=True)
-def test_save_cleanup(mock_preprocess):
+def test_save_cleanup(mocker):
     """Test ``save``."""
-    product = mock.create_autospec(PreprocessorFile, instance=True)
-    product.settings = {'save': {}, 'cleanup': {}}
-    product._cubes = mock.sentinel.cubes
-    product._input_files = mock.sentinel.input_files
+    mocker.patch.object(esmvalcore.preprocessor, 'preprocess', autospec=True)
+    mocker.patch.object(esmvalcore.preprocessor, 'save', autospec=True)
+    product = mocker.create_autospec(PreprocessorFile, instance=True)
+    product.settings = {'save': {'filename': Path('file1.nc')}, 'cleanup': {}}
+    product._cubes = [mocker.sentinel.cube]
+    product._input_files = mocker.sentinel.input_files
 
     PreprocessorFile.save(product)
 
-    assert mock_preprocess.mock_calls == [
-        mock.call(
-            mock.sentinel.cubes, 'save', input_files=mock.sentinel.input_files
-        ),
-        mock.call([], 'cleanup', input_files=mock.sentinel.input_files),
-    ]
+    esmvalcore.preprocessor.save.assert_called_with(
+        [mocker.sentinel.cube],
+        filename=Path('file1.nc'),
+    )
+    esmvalcore.preprocessor.preprocess.assert_called_with(
+        [Path('file1.nc')],
+        'cleanup',
+        input_files=mocker.sentinel.input_files,
+    )
diff --git a/tests/unit/recipe/test_recipe.py b/tests/unit/recipe/test_recipe.py
index b830351318..18fdb68dcb 100644
--- a/tests/unit/recipe/test_recipe.py
+++ b/tests/unit/recipe/test_recipe.py
@@ -556,7 +556,11 @@ def test_get_default_settings(mocker):
         return_value=Path('/path/to/file.nc'),
     )
     session = mocker.create_autospec(esmvalcore.config.Session, instance=True)
-    session.__getitem__.return_value = False
+    config = {
+        'compress_netcdf': False,
+        'max_parallel_tasks': 1,
+    }
+    session.__getitem__ = lambda self, key: config[key]
 
     dataset = Dataset(
         short_name='sic',
@@ -570,7 +574,7 @@ def test_get_default_settings(mocker):
 
     assert settings == {
         'load': {'callback': 'default'},
         'remove_supplementary_variables': {},
-        'save': {'compress': False, 'alias': 'sic'},
+        'save': {'compress': False, 'alias': 'sic', 'compute': True},
     }
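
Usage note (editor's sketch, not part of the patch): the snippet below illustrates how the pieces introduced above are meant to fit together, i.e. get_distributed_client reading the new 'dask' configuration section and save(..., compute=False) returning a dask.delayed.Delayed. It is a minimal illustration under stated assumptions: a plain dict stands in for esmvalcore.config.Session (the context manager only calls session.get('dask', {})), and the cluster settings, cube, and output filename are made-up example values rather than defaults defined by this change.

# Minimal sketch of the new Dask plumbing; see assumptions in the note above.
import dask.array as da
from iris.cube import Cube

from esmvalcore._task import get_distributed_client
from esmvalcore.preprocessor import save

# Stand-in for the session; the same keys could live in the 'dask' section of
# config-user.yml. Keys under 'cluster' (except 'type') are passed to the
# cluster class, keys under 'client' are passed to dask.distributed.Client,
# and 'client': {'address': ...} attaches to an externally managed cluster.
SESSION = {
    'dask': {
        'cluster': {
            'type': 'dask.distributed.LocalCluster',
            'n_workers': 2,  # example value, not a default from this patch
            'threads_per_worker': 2,
        },
        'client': {},
    },
    # In TaskSet.run, 0 selects the new _run_distributed code path.
    'max_parallel_tasks': 0,
}


def main():
    cube = Cube(da.zeros((100, 100), chunks=(50, 50)), var_name='tas')
    with get_distributed_client(SESSION) as client:
        # With compute=False the patched save() returns a Delayed that writes
        # the data when computed; PreprocessingTask collects these in its
        # 'delayeds' mapping and _run_distributed computes them on the cluster.
        delayed = save([cube], filename='tas_example.nc', compute=False)
        client.compute(delayed).result()


if __name__ == '__main__':
    main()

In a normal run none of this is called by hand: Recipe.run passes the session to TaskSet.run, which opens the client, forwards the scheduler address to Python diagnostic scripts, and picks the distributed, sequential, or multiprocessing runner based on max_parallel_tasks.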