method heuristics: Avoid dot product as much as possible #347

Merged 6 commits on Mar 27, 2024
15 changes: 12 additions & 3 deletions asv_bench/benchmarks/cohorts.py
@@ -95,7 +95,7 @@ class ERA5Dataset:
"""ERA5"""

def __init__(self, *args, **kwargs):
self.time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="H"))
self.time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="h"))
self.axis = (-1,)
self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 48))

@@ -143,7 +143,7 @@ class PerfectMonthly(Cohorts):
"""Perfectly chunked for a "cohorts" monthly mean climatology"""

def setup(self, *args, **kwargs):
self.time = pd.Series(pd.date_range("1961-01-01", "2018-12-31 23:59", freq="M"))
self.time = pd.Series(pd.date_range("1961-01-01", "2018-12-31 23:59", freq="ME"))
self.axis = (-1,)
self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 4))
self.by = self.time.dt.month.values - 1
@@ -164,7 +164,7 @@ def rechunk(self):
class ERA5Google(Cohorts):
def setup(self, *args, **kwargs):
TIME = 900 # 92044 in Google ARCO ERA5
self.time = pd.Series(pd.date_range("1959-01-01", freq="6H", periods=TIME))
self.time = pd.Series(pd.date_range("1959-01-01", freq="6h", periods=TIME))
self.axis = (2,)
self.array = dask.array.ones((721, 1440, TIME), chunks=(-1, -1, 1))
self.by = self.time.dt.day.values - 1
@@ -201,3 +201,12 @@ def setup(self, *args, **kwargs):
self.time = pd.Series(index)
self.by = self.time.dt.dayofyear.values - 1
self.expected = pd.RangeIndex(self.by.max() + 1)


class RandomBigArray(Cohorts):
def setup(self, *args, **kwargs):
M, N = 100_000, 20_000
self.array = dask.array.random.normal(size=(M, N), chunks=(10_000, N // 5)).T
self.by = np.random.choice(5_000, size=M)
self.expected = pd.RangeIndex(5000)
self.axis = (1,)
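The frequency strings above change because recent pandas releases deprecate the uppercase offset aliases in favour of `"h"`, `"ME"`, and `"6h"`; the new `RandomBigArray` benchmark additionally exercises the method heuristic on a large dask array whose 5,000 labels are assigned at random. A minimal sketch of the new aliases, assuming pandas 2.2 or later is installed:

```python
import pandas as pd

# Lowercase / month-end aliases accepted by pandas 2.2+; the older uppercase
# forms ("H", "M", "6H") emit deprecation warnings there.
hourly = pd.date_range("2016-01-01", "2016-01-02", freq="h")       # was freq="H"
monthly = pd.date_range("1961-01-31", periods=12, freq="ME")       # was freq="M"
six_hourly = pd.date_range("1959-01-01", periods=4, freq="6h")     # was freq="6H"
print(len(hourly), len(monthly), len(six_hourly))  # 25 12 4
```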
1 change: 1 addition & 0 deletions ci/benchmark.yml
@@ -3,6 +3,7 @@ channels:
- conda-forge
dependencies:
- asv
- build
- cachey
- dask-core
- numpy>=1.22
6 changes: 4 additions & 2 deletions docs/source/implementation.md
@@ -300,8 +300,10 @@ label overlaps with all other labels. The algorithm is as follows.
![cohorts-schematic](/../diagrams/containment.png)

1. To choose between `"map-reduce"` and `"cohorts"`, we need a summary measure of the degree to which the labels overlap with
each other. We use _sparsity_ --- the number of non-zero elements in `C` divided by the number of elements in `C`, `C.nnz/C.size`.
When sparsity > 0.6, we choose `"map-reduce"` since there is decent overlap between (any) cohorts. Otherwise we use `"cohorts"`.
each other. We can use _sparsity_ --- the number of non-zero elements in `C` divided by the number of elements in `C`, `C.nnz/C.size`.
We use sparsity(`S`) as an approximation for sparsity(`C`) to avoid a potentially expensive sparse matrix dot product when `S`
isn't particularly sparse. When sparsity(`S`) > 0.4 (arbitrary), we choose `"map-reduce"` since there is decent overlap between
(any) cohorts. Otherwise we use `"cohorts"`.

Cool, isn't it?!
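As a rough sketch of this decision rule (illustrative values only, not flox's actual `find_group_cohorts` internals; assumes scipy's sparse-array API is available), sparsity(`S`) alone is enough to choose a method:

```python
import math

import numpy as np
from scipy.sparse import csr_array

# Hypothetical bitmask S: 4 chunks x 6 labels, True where a label occurs in that chunk.
S = csr_array(
    np.array(
        [[1, 1, 0, 0, 0, 0],
         [0, 0, 1, 1, 0, 0],
         [0, 0, 1, 1, 0, 0],
         [0, 0, 0, 0, 1, 1]],
        dtype=bool,
    )
)
sparsity_S = S.nnz / math.prod(S.shape)  # nnz/size = 8/24 ~= 0.33
method = "map-reduce" if sparsity_S > 0.4 else "cohorts"
print(method)  # labels barely overlap across chunks -> "cohorts"
```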

46 changes: 32 additions & 14 deletions flox/core.py
@@ -363,37 +363,55 @@ def invert(x) -> tuple[np.ndarray, ...]:
logger.info("find_group_cohorts: cohorts is preferred, chunking is perfect.")
return "cohorts", chunks_cohorts

# Containment = |Q & S| / |Q|
# We'll use containment to measure degree of overlap between labels.
# Containment C = |Q & S| / |Q|
# - |X| is the cardinality of set X
# - Q is the query set being tested
# - S is the existing set
# We'll use containment to measure the degree of overlap between labels. The
# bitmask matrix allows us to calculate this pretty efficiently.
asfloat = bitmask.astype(float)
# Note: While A.T @ A is a symmetric matrix, the division by chunks_per_label
# makes it non-symmetric.
containment = csr_array((asfloat.T @ asfloat) / chunks_per_label)

# The containment matrix is a measure of how much the labels overlap
# with each other. We treat the sparsity = (nnz/size) as a summary measure of the net overlap.
# The bitmask matrix S allows us to calculate this pretty efficiently using a dot product.
# S.T @ S / chunks_per_label
#
# We treat the sparsity(C) = (nnz/size) as a summary measure of the net overlap.
# 1. For high enough sparsity, there is a lot of overlap and we should use "map-reduce".
# 2. When labels are uniformly distributed amongst all chunks
# (and number of labels < chunk size), sparsity is 1.
# 3. Time grouping cohorts (e.g. dayofyear) appear as lines in this matrix.
# 4. When there are no overlaps at all between labels, containment is a block diagonal matrix
# (approximately).
MAX_SPARSITY_FOR_COHORTS = 0.6 # arbitrary
sparsity = containment.nnz / math.prod(containment.shape)
#
# However, computing S.T @ S can still be the slowest step, especially if S
# is not particularly sparse. Empirically, sparsity(S.T @ S) >= min(1, 2 x sparsity(S)),
# so we use sparsity(S) as a shortcut.
MAX_SPARSITY_FOR_COHORTS = 0.4 # arbitrary
sparsity = bitmask.nnz / math.prod(bitmask.shape)
preferred_method: Literal["map-reduce"] | Literal["cohorts"]
logger.debug(
"sparsity of bitmask is {}, threshold is {}".format( # noqa
sparsity, MAX_SPARSITY_FOR_COHORTS
)
)
if sparsity > MAX_SPARSITY_FOR_COHORTS:
logger.info("sparsity is {}".format(sparsity)) # noqa
if not merge:
logger.info("find_group_cohorts: merge=False, choosing 'map-reduce'")
logger.info(
"find_group_cohorts: bitmask sparsity={}, merge=False, choosing 'map-reduce'".format( # noqa
sparsity
)
)
return "map-reduce", {}
preferred_method = "map-reduce"
else:
preferred_method = "cohorts"

# Note: While A.T @ A is a symmetric matrix, the division by chunks_per_label
# makes it non-symmetric.
asfloat = bitmask.astype(float)
containment = csr_array(asfloat.T @ asfloat / chunks_per_label)

logger.debug(
"sparsity of containment matrix is {}".format( # noqa
containment.nnz / math.prod(containment.shape)
)
)
# Use a threshold to force some merging. We do not use the filtered
# containment matrix for estimating "sparsity" because it is a bit
# hard to reason about.
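Taken together, the reordering in this hunk defers the `S.T @ S` product until the cheap bitmask-sparsity check has already failed to select `"map-reduce"`. A condensed, hypothetical sketch of that control flow (illustrative names, not the full `find_group_cohorts` logic; assumes scipy's `csr_array` is available):

```python
import math

import numpy as np
from scipy.sparse import csr_array

MAX_SPARSITY_FOR_COHORTS = 0.4  # same arbitrary threshold as above


def choose_method(bitmask: csr_array, chunks_per_label: np.ndarray) -> str:
    # Cheap check first: nnz/size of the bitmask S itself.
    sparsity = bitmask.nnz / math.prod(bitmask.shape)
    if sparsity > MAX_SPARSITY_FOR_COHORTS:
        # Plenty of overlap; skip the potentially expensive dot product entirely.
        return "map-reduce"
    # Only now pay for S.T @ S; dividing by chunks_per_label turns it into containment.
    asfloat = bitmask.astype(float)
    containment = csr_array(asfloat.T @ asfloat / chunks_per_label)
    print("containment sparsity:", containment.nnz / math.prod(containment.shape))
    return "cohorts"


# Hypothetical bitmask: each label lives in exactly one chunk, so "cohorts" wins.
S = csr_array(np.eye(4, dtype=bool))
print(choose_method(S, chunks_per_label=S.sum(axis=0)))
```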