-
Notifications
You must be signed in to change notification settings - Fork 59
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Speed up quantiles with sorting (#1513)
<!--Please ensure the PR fulfills the following requirements! --> <!-- If this is your first PR, make sure to add your details to the AUTHORS.rst! --> ### Pull Request Checklist: - [x] This PR addresses an already opened issue (for bug fixes / features) - This PR will help #1255 - [x] Tests for the changes have been added (for bug fixes / features) - [ ] (If applicable) Documentation has been added / updated (for bug fixes / features) - [x] CHANGES.rst has been updated (with summary of main changes) - [x] Link to issue (:issue:`number`) and pull request (:pull:`number`) has been added ### What kind of change does this PR introduce? * `nbutils.quantile` has a speed-up of more than 2.5x by a combination of changes in `nbutils.quantile` and `nbutils._quantile` * This does not cover `nbutils.vec_quantiles` (used for `adapt_freq`) but similar principles could be used * It adds the possibility of using `fastnanquantile` module which is very fast ### Does this PR introduce a breaking change? No ### Other information: * The new low-level function to compute quantiles `nbutils._quantile` is a 1d jitted version of `xclim.core.utils._nan_quantile` * Manual benchmarking can be performed in the notebook `benchmarks/sdba_quantile.ipynb`, attached to this PR.
- Loading branch information
Showing
9 changed files
with
418 additions
and
60 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from __future__ import annotations\n", | ||
"\n", | ||
"import time\n", | ||
"\n", | ||
"import numpy as np\n", | ||
"\n", | ||
"import xclim\n", | ||
"from xclim import sdba\n", | ||
"from xclim.testing import open_dataset\n", | ||
"\n", | ||
"ds = open_dataset(\"sdba/CanESM2_1950-2100.nc\")\n", | ||
"tx = ds.sel(time=slice(\"1950\", \"1980\")).tasmax\n", | ||
"kws = {\"dim\": \"time\", \"q\": np.linspace(0, 1, 50)}" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Tests with %%timeit (full 30 years)\n", | ||
"\n", | ||
"Here `fastnanquantile` is the best algorithm out of \n", | ||
"* `xr.DataArray.quantile`\n", | ||
"* `nbutils.quantile`, using: \n", | ||
" * `xclim.core.utils.nan_quantile`\n", | ||
" * `fastnanquantile`\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%%timeit\n", | ||
"tx.quantile(**kws).compute()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%%timeit\n", | ||
"sdba.nbutils.USE_FASTNANQUANTILE = False\n", | ||
"sdba.nbutils.quantile(tx, **kws).compute()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"! pip install fastnanquantile" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%%timeit\n", | ||
"sdba.nbutils.USE_FASTNANQUANTILE = True\n", | ||
"sdba.nbutils.quantile(tx, **kws).compute()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Test computation time as a function of number of points\n", | ||
"\n", | ||
"For a smaller number of time steps <=2000, `_sortquantile` is the best algorithm in general" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import time\n", | ||
"\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import xarray as xr\n", | ||
"\n", | ||
"num_tests = 500\n", | ||
"timed = {}\n", | ||
"# fastnanquantile has nothing to do with sortquantile\n", | ||
"# I just added a third step using this variable\n", | ||
"\n", | ||
"for use_fnq in [True, False]:\n", | ||
" sdba.nbutils.USE_FASTNANQUANTILE = use_fnq\n", | ||
" # heat-up the jit\n", | ||
" sdba.nbutils.quantile(\n", | ||
" xr.DataArray(np.array([0, 1.5])), dim=\"dim_0\", q=np.array([0.5])\n", | ||
" )\n", | ||
" for size in np.arange(250, 2000 + 250, 250):\n", | ||
" da = tx.isel(time=slice(0, size))\n", | ||
" t0 = time.time()\n", | ||
" for ii in range(num_tests):\n", | ||
" sdba.nbutils.quantile(da, **kws).compute()\n", | ||
" timed[use_fnq].append([size, time.time() - t0])\n", | ||
"\n", | ||
"for k, lab in zip([True, False], [\"xclim.core.utils.nan_quantile\", \"fastnanquantile\"]):\n", | ||
" arr = np.array(timed[k])\n", | ||
" plt.plot(arr[:, 0], arr[:, 1] / num_tests, label=lab)\n", | ||
"plt.legend()\n", | ||
"plt.title(\"Quantile computation, average time vs array size, for 50 quantiles\")\n", | ||
"plt.xlabel(\"Number of time steps in the distribution\")\n", | ||
"plt.ylabel(\"Computation time (s)\")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from __future__ import annotations | ||
|
||
import numpy as np | ||
import pytest | ||
import xarray as xr | ||
|
||
from xclim.sdba import nbutils as nbu | ||
|
||
|
||
class TestQuantiles: | ||
@pytest.mark.parametrize("uses_dask", [True, False]) | ||
def test_quantile(self, open_dataset, uses_dask): | ||
da = ( | ||
open_dataset("sdba/CanESM2_1950-2100.nc").sel(time=slice("1950", "1955")).pr | ||
).load() | ||
if uses_dask: | ||
da = da.chunk({"location": 1}) | ||
else: | ||
da = da.load() | ||
q = np.linspace(0.1, 0.99, 50) | ||
out_nbu = nbu.quantile(da, q, dim="time").transpose("location", ...) | ||
out_xr = da.quantile(q=q, dim="time").transpose("location", ...) | ||
np.testing.assert_array_almost_equal(out_nbu.values, out_xr.values) | ||
|
||
def test_edge_cases(self, open_dataset): | ||
q = np.linspace(0.1, 0.99, 50) | ||
|
||
# only 1 non-null value | ||
da = xr.DataArray([1] + [np.nan] * 100, dims="dim_0") | ||
out_nbu = nbu.quantile(da, q, dim="dim_0") | ||
np.testing.assert_array_equal(out_nbu.values, np.full_like(q, 1)) | ||
|
||
# only NANs | ||
da = xr.DataArray([np.nan] * 100, dims="dim_0") | ||
out_nbu = nbu.quantile(da, q, dim="dim_0") | ||
np.testing.assert_array_equal(out_nbu.values, np.full_like(q, np.nan)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.