Skip to content

Commit

Permalink
Improve regridding script
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 668194768
  • Loading branch information
shoyer authored and Dinosaur authors committed Sep 13, 2024
1 parent 2b57713 commit fd8272c
Showing 1 changed file with 65 additions and 30 deletions.
95 changes: 65 additions & 30 deletions dinosaur/pipelines/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
file are supported, but irregular spacing is OK.
"""

from concurrent import futures
import dataclasses
from typing import Any, Callable, Mapping

Expand All @@ -30,6 +31,7 @@
import xarray
import xarray_beam


# pylint: disable=logging-fstring-interpolation

HorizontalRegridderFactory = Callable[
Expand Down Expand Up @@ -343,16 +345,61 @@ class RegridTarget:
vertical_regridder: vertical_interpolation.Regridder | None
nan_filler: NaNFiller | None = None
output_chunks: dict[str, int] | None = None
zarr_metadata: dict[str, Any] | None = None


@dataclasses.dataclass
class _RegridTransform(beam.PTransform):
"""PTransform for regridding to a single target grid."""

def __init__(
self,
source_template: xarray.Dataset,
input_chunks: dict[str, int],
target: RegridTarget,
io_num_threads: int | None,
setup_executor: futures.ThreadPoolExecutor,
):
validate_horizontal_regridder(source_template, target.horizontal_regridder)
validate_vertical_regridder(source_template, target.vertical_regridder)

self.nan_filler = target.nan_filler
self.regrid = get_regrid_func(
target.horizontal_regridder, target.vertical_regridder
)
template = get_template(
source_template, target.horizontal_regridder, target.vertical_regridder
)
if target.zarr_metadata:
template.attrs.update(target.zarr_metadata)
self.output_chunks = get_output_chunks(
input_chunks, target.vertical_regridder, target.output_chunks
)
self.chunks_to_zarr = xarray_beam.ChunksToZarr(
target.output_path,
template,
self.output_chunks,
num_threads=io_num_threads,
setup_executor=setup_executor,
)

def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
pcoll = pcoll | beam.MapTuple(self.regrid)
if self.output_chunks is not None:
pcoll |= xarray_beam.ConsolidateChunks(self.output_chunks)
if self.nan_filler is not None:
pcoll |= beam.MapTuple(self.nan_filler)
pcoll |= self.chunks_to_zarr
return pcoll


@dataclasses.dataclass
class MultiRegridTransform(beam.PTransform):
"""PTransform for regridding to multiple target grids.
The most expensive part of regridding (to coarser resolutions) is typically
reading the source dataset
from disk, so this transform does so nce and outputs multiple regridding
targets at the same time.
reading the source dataset from disk, so this transform reads source data
once and outputs all regridding targets simultaneously.
Parameters:
source: specification of how to load the source dataset.
Expand All @@ -374,32 +421,20 @@ def expand(self, pcoll: beam.PCollection) -> list[beam.PCollection]:
num_threads=self.io_num_threads,
)

output_pcollections = []
for target in self.regrid_targets:
validate_horizontal_regridder(source_ds, target.horizontal_regridder)
validate_vertical_regridder(source_ds, target.vertical_regridder)

regrid = get_regrid_func(
target.horizontal_regridder, target.vertical_regridder
)
template = get_template(
source_ds, target.horizontal_regridder, target.vertical_regridder
)
output_chunks = get_output_chunks(
input_chunks, target.vertical_regridder, target.output_chunks
)

pcoll = source_pcoll | beam.MapTuple(regrid)
if output_chunks is not None:
pcoll |= xarray_beam.ConsolidateChunks(output_chunks)
if target.nan_filler is not None:
pcoll |= beam.MapTuple(target.nan_filler)
pcoll |= xarray_beam.ChunksToZarr(
target.output_path,
template,
output_chunks,
num_threads=self.io_num_threads,
)
output_pcollections.append(pcoll)
# We setup Zarr stores using separate threads, which otherwise takes ~1
# minute per regridding target.
with futures.ThreadPoolExecutor(
max_workers=len(self.regrid_targets)
) as executor:
output_pcollections = []
for i, target in enumerate(self.regrid_targets):
pcoll = source_pcoll | f'Regrid{i}' >> _RegridTransform(
source_template=source_ds,
input_chunks=input_chunks,
target=target,
io_num_threads=self.io_num_threads,
setup_executor=executor,
)
output_pcollections.append(pcoll)

return output_pcollections

0 comments on commit fd8272c

Please sign in to comment.