From fd8272cfd3b4788c313ac6ba4814b79af5ab1e70 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 27 Aug 2024 16:58:08 -0700 Subject: [PATCH] Improve regridding script PiperOrigin-RevId: 668194768 --- dinosaur/pipelines/regrid.py | 95 ++++++++++++++++++++++++------------ 1 file changed, 65 insertions(+), 30 deletions(-) diff --git a/dinosaur/pipelines/regrid.py b/dinosaur/pipelines/regrid.py index 9d581eb..fa9feef 100644 --- a/dinosaur/pipelines/regrid.py +++ b/dinosaur/pipelines/regrid.py @@ -17,6 +17,7 @@ file are supported, but irregular spacing is OK. """ +from concurrent import futures import dataclasses from typing import Any, Callable, Mapping @@ -30,6 +31,7 @@ import xarray import xarray_beam + # pylint: disable=logging-fstring-interpolation HorizontalRegridderFactory = Callable[ @@ -343,6 +345,52 @@ class RegridTarget: vertical_regridder: vertical_interpolation.Regridder | None nan_filler: NaNFiller | None = None output_chunks: dict[str, int] | None = None + zarr_metadata: dict[str, Any] | None = None + + +@dataclasses.dataclass +class _RegridTransform(beam.PTransform): + """PTransform for regridding to a single target grid.""" + + def __init__( + self, + source_template: xarray.Dataset, + input_chunks: dict[str, int], + target: RegridTarget, + io_num_threads: int | None, + setup_executor: futures.ThreadPoolExecutor, + ): + validate_horizontal_regridder(source_template, target.horizontal_regridder) + validate_vertical_regridder(source_template, target.vertical_regridder) + + self.nan_filler = target.nan_filler + self.regrid = get_regrid_func( + target.horizontal_regridder, target.vertical_regridder + ) + template = get_template( + source_template, target.horizontal_regridder, target.vertical_regridder + ) + if target.zarr_metadata: + template.attrs.update(target.zarr_metadata) + self.output_chunks = get_output_chunks( + input_chunks, target.vertical_regridder, target.output_chunks + ) + self.chunks_to_zarr = xarray_beam.ChunksToZarr( + target.output_path, + template, + self.output_chunks, + num_threads=io_num_threads, + setup_executor=setup_executor, + ) + + def expand(self, pcoll: beam.PCollection) -> beam.PCollection: + pcoll = pcoll | beam.MapTuple(self.regrid) + if self.output_chunks is not None: + pcoll |= xarray_beam.ConsolidateChunks(self.output_chunks) + if self.nan_filler is not None: + pcoll |= beam.MapTuple(self.nan_filler) + pcoll |= self.chunks_to_zarr + return pcoll @dataclasses.dataclass @@ -350,9 +398,8 @@ class MultiRegridTransform(beam.PTransform): """PTransform for regridding to multiple target grids. The most expensive part of regridding (to coarser resolutions) is typically - reading the source dataset - from disk, so this transform does so nce and outputs multiple regridding - targets at the same time. + reading the source dataset from disk, so this transform reads source data + once and outputs all regridding targets simultaneously. Parameters: source: specification of how to load the source dataset. @@ -374,32 +421,20 @@ def expand(self, pcoll: beam.PCollection) -> list[beam.PCollection]: num_threads=self.io_num_threads, ) - output_pcollections = [] - for target in self.regrid_targets: - validate_horizontal_regridder(source_ds, target.horizontal_regridder) - validate_vertical_regridder(source_ds, target.vertical_regridder) - - regrid = get_regrid_func( - target.horizontal_regridder, target.vertical_regridder - ) - template = get_template( - source_ds, target.horizontal_regridder, target.vertical_regridder - ) - output_chunks = get_output_chunks( - input_chunks, target.vertical_regridder, target.output_chunks - ) - - pcoll = source_pcoll | beam.MapTuple(regrid) - if output_chunks is not None: - pcoll |= xarray_beam.ConsolidateChunks(output_chunks) - if target.nan_filler is not None: - pcoll |= beam.MapTuple(target.nan_filler) - pcoll |= xarray_beam.ChunksToZarr( - target.output_path, - template, - output_chunks, - num_threads=self.io_num_threads, - ) - output_pcollections.append(pcoll) + # We setup Zarr stores using separate threads, which otherwise takes ~1 + # minute per regridding target. + with futures.ThreadPoolExecutor( + max_workers=len(self.regrid_targets) + ) as executor: + output_pcollections = [] + for i, target in enumerate(self.regrid_targets): + pcoll = source_pcoll | f'Regrid{i}' >> _RegridTransform( + source_template=source_ds, + input_chunks=input_chunks, + target=target, + io_num_threads=self.io_num_threads, + setup_executor=executor, + ) + output_pcollections.append(pcoll) return output_pcollections