Skip to content

Commit

Permalink
added outrange feature to qdm and presrat
Browse files Browse the repository at this point in the history
  • Loading branch information
grantbuster authored and bnb32 committed Aug 16, 2024
1 parent 74492f5 commit e9969da
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
27 changes: 22 additions & 5 deletions sup3r/bias/bias_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ def local_qdm_bc(
no_trend=False,
delta_denom_min=None,
delta_denom_zero=None,
out_range=None,
max_workers=1,
):
"""Bias correction using QDM
Expand Down Expand Up @@ -535,6 +536,8 @@ def local_qdm_bc(
division by a very small number making delta blow up and resulting
in very large output bias corrected values. See equation 4 of
Cannon et al., 2015 for the delta term.
out_range : None | tuple
Option to set floor/ceiling values on the output data.
max_workers: int | None
Max number of workers to use for QDM process pool
Expand Down Expand Up @@ -647,6 +650,10 @@ def local_qdm_bc(
# Position output respecting original time axis sequence
output[:, :, subset_idx] = tmp

if out_range is not None:
output = np.maximum(output, np.min(out_range))
output = np.minimum(output, np.max(out_range))

return output


Expand Down Expand Up @@ -785,13 +792,14 @@ def apply_presrat_bc(data, time_index, base_params, bias_params,
bias_fut_params, bias_tau_fut, k_factor,
time_window_center, dist='empirical', sampling='invlog',
log_base=10, relative=True, no_trend=False,
zero_rate_threshold=1.182033e-07):
zero_rate_threshold=1.182033e-07, out_range=None,
max_workers=1):
"""Run PresRat to bias correct data from input parameters and not from bias
correction file on disk."""

data_unbiased = np.full_like(data, np.nan)
closest_time_idx = abs(time_window_center[:, np.newaxis] -
np.array(time_index.day_of_year))
closest_time_idx = abs(time_window_center[:, np.newaxis]
- np.array(time_index.day_of_year))
closest_time_idx = closest_time_idx.argmin(axis=0)

for nt in set(closest_time_idx):
Expand Down Expand Up @@ -819,7 +827,7 @@ def apply_presrat_bc(data, time_index, base_params, bias_params,
# QDM expects input arr with shape (time, space)
tmp = subset.reshape(-1, subset.shape[-1]).T
# Apply QDM correction
tmp = QDM(tmp)
tmp = QDM(tmp, max_workers=max_workers)
# Reorgnize array back from (time, space)
# to (spatial, spatial, temporal)
subset = tmp.T.reshape(subset.shape)
Expand All @@ -832,6 +840,10 @@ def apply_presrat_bc(data, time_index, base_params, bias_params,

data_unbiased[:, :, subset_idx] = subset

if out_range is not None:
data_unbiased = np.maximum(data_unbiased, np.min(out_range))
data_unbiased = np.minimum(data_unbiased, np.max(out_range))

return data_unbiased


Expand All @@ -845,6 +857,7 @@ def local_presrat_bc(data: np.ndarray,
threshold=0.1,
relative=True,
no_trend=False,
out_range=None,
max_workers=1,
):
"""Bias correction using PresRat
Expand Down Expand Up @@ -899,6 +912,8 @@ def local_presrat_bc(data: np.ndarray,
:class:`rex.utilities.bc_utils.QuantileDeltaMapping`. Note that this
assumes that params_mh is the data distribution representative for the
target data.
out_range : None | tuple
Option to set floor/ceiling values on the output data.
max_workers : int | None
Max number of workers to use for QDM process pool
"""
Expand Down Expand Up @@ -936,6 +951,8 @@ def local_presrat_bc(data: np.ndarray,
time_window_center, dist=dist,
sampling=sampling, log_base=log_base,
relative=relative, no_trend=no_trend,
zero_rate_threshold=zero_rate_threshold)
zero_rate_threshold=zero_rate_threshold,
out_range=out_range,
max_workers=max_workers)

return data_unbiased
7 changes: 2 additions & 5 deletions tests/rasterizers/test_rasterizer_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,8 @@ def test_data_caching(input_files, ext, shape, target, features):
rasterizer, cache_kwargs={'cache_pattern': cache_pattern}
)

assert rasterizer.shape[:3] == (
shape[0],
shape[1],
rasterizer.shape[2],
)
good_shape = (shape[0], shape[1], rasterizer.shape[2])
assert rasterizer.shape[:3] == good_shape
assert rasterizer.data.dtype == np.dtype(np.float32)
loader = Loader(cacher.out_files)
assert np.array_equal(
Expand Down

0 comments on commit e9969da

Please sign in to comment.