Skip to content

Commit

Permalink
improve speed of bootstrap tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasloveday committed Feb 6, 2025
1 parent e7301ce commit 933b0a3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
5 changes: 3 additions & 2 deletions src/scores/processing/block_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
"""

import math
import os
from collections import OrderedDict
from itertools import chain, cycle, islice
from typing import Dict, Iterable, List, Tuple, Union
from typing import Dict, List, Tuple, Union

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -212,12 +213,12 @@ def _block_bootstrap( # pylint: disable=too-many-locals
"arrays containing lists of dimensions to exclude for each array"
)
renames = []

for i, (obj, exclude) in enumerate(zip(array_list, exclude_dims)):
array_list[i] = obj.rename(
{d: f"dim{ii}" for ii, d in enumerate(exclude)},
)
renames.append({f"dim{ii}": d for ii, d in enumerate(exclude)})

dim = list(blocks.keys())

# Ensure bootstrapped dimensions have consistent sizes across arrays_list
Expand Down
14 changes: 7 additions & 7 deletions tests/processing/test_block_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,32 +286,32 @@ def test_block_bootstrap(objects, blocks, n_iteration, exclude_dims, circular, e
),
# Dask arrays to meet block_size < 1
(
[xr.DataArray(da.random.random((1000, 1000, 30), chunks=dict(dim1=-1)), dims=["dim1", "dim2", "dim3"])],
[xr.DataArray(da.random.random((110, 110, 30), chunks=dict(dim1=-1)), dims=["dim1", "dim2", "dim3"])],
{"dim1": 2, "dim2": 2},
2,
None,
True,
(30, 1000, 1000, 2),
(30, 110, 110, 2),
),
# Dask arrays for a case with leftover != 0
(
[xr.DataArray(da.random.random((1000, 1000, 10), chunks=dict(dim1=-1)), dims=["dim1", "dim2", "dim3"])],
[xr.DataArray(da.random.random((110, 110, 10), chunks=dict(dim1=-1)), dims=["dim1", "dim2", "dim3"])],
{"dim1": 2, "dim2": 2},
3,
None,
True,
(10, 1000, 1000, 3),
(10, 110, 110, 3),
),
# Dataset with dask arrays
(
[
xr.Dataset(
{
"var1": xr.DataArray(
da.random.random((1000, 1000, 30), chunks=dict(dim1=-1)), dims=["dim1", "dim2", "dim3"]
da.random.random((110, 110, 30), chunks=dict(dim1=-1)), dims=["dim1", "dim2", "dim3"]
),
"var2": xr.DataArray(
da.random.random((1000, 1000, 30), chunks=dict(dim1=-1)), dims=["dim1", "dim2", "dim3"]
da.random.random((110, 110, 30), chunks=dict(dim1=-1)), dims=["dim1", "dim2", "dim3"]
),
}
)
Expand All @@ -320,7 +320,7 @@ def test_block_bootstrap(objects, blocks, n_iteration, exclude_dims, circular, e
3,
None,
True,
(30, 1000, 1000, 3),
(30, 110, 110, 3),
),
],
)
Expand Down

0 comments on commit 933b0a3

Please sign in to comment.