Preserve laziness in fuse_multiple to avoid materializing array blocks in _partial_reduce #386

Merged
merged 2 commits on Feb 19, 2024
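In outline: when several operations are fused by `fuse_multiple`, the fused functions previously collected every input block into a Python list before `_partial_reduce` ran, so all blocks were resident in memory at once. This PR makes the fused functions return iterators instead, so `_partial_reduce` can consume one block at a time. A minimal standalone sketch of the intended streaming behavior (illustrative names only, not cubed's API):

import numpy as np

def read_blocks(n):
    # stand-in for reading chunks from storage: one block per request
    for i in range(n):
        yield np.full((100, 100), float(i))

def fused_negative(blocks):
    # a fused elementwise step that stays lazy: map defers the work
    return map(np.negative, blocks)

def partial_reduce(blocks):
    # consume one block at a time, keeping only the running result
    result = None
    for block in blocks:
        r = block.sum(axis=0, keepdims=True)
        result = r if result is None else result + r
    return result

# the whole pipeline holds roughly one block in memory at any moment
out = partial_reduce(fused_negative(read_blocks(10)))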
.github/workflows/dask-tests.yml (2 changes: 1 addition & 1 deletion)

@@ -33,7 +33,7 @@ jobs:
           architecture: x64

       - name: Setup Graphviz
-        uses: ts-graphviz/setup-graphviz@v1
+        uses: ts-graphviz/setup-graphviz@v2

       - name: Install
        run: |
.github/workflows/jax-tests.yml (2 changes: 1 addition & 1 deletion)

@@ -32,7 +32,7 @@ jobs:
           architecture: x64

       - name: Setup Graphviz
-        uses: ts-graphviz/setup-graphviz@v1
+        uses: ts-graphviz/setup-graphviz@v2

       - name: Install
        run: |
.github/workflows/tests.yml (2 changes: 1 addition & 1 deletion)

@@ -33,7 +33,7 @@ jobs:
           architecture: x64

       - name: Setup Graphviz
-        uses: ts-graphviz/setup-graphviz@v1
+        uses: ts-graphviz/setup-graphviz@v2
         with:
           macos-skip-brew-update: 'true'
cubed/core/ops.py (3 changes: 3 additions & 0 deletions)

@@ -1130,6 +1130,9 @@ def block_function(out_key):

 def _partial_reduce(arrays, reduce_func=None, initial_func=None, axis=None):
     # reduce each array in turn, accumulating in result
+    assert not isinstance(
+        arrays, list
+    ), "partial reduce expects an iterator of array blocks, not a list"
     result = None
     for array in arrays:
         if initial_func is not None:
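The new assertion makes a regression back to eager lists fail fast instead of surfacing much later as an out-of-memory error. The memory difference is easy to demonstrate with tracemalloc in a standalone toy reduce (illustrative code, not cubed's; peak figures depend on the machine):

import tracemalloc

import numpy as np

def blocks(n):
    for _ in range(n):
        yield np.ones((500, 500))  # ~2 MB per block

def reduce_blocks(arrays):
    result = None
    for a in arrays:
        r = a.sum(axis=0, keepdims=True)
        result = r if result is None else result + r
    return result

tracemalloc.start()
reduce_blocks(blocks(20))  # lazy: one block live at a time
lazy_peak = tracemalloc.get_traced_memory()[1]
tracemalloc.stop()

tracemalloc.start()
reduce_blocks(list(blocks(20)))  # eager: the list pins all twenty blocks
eager_peak = tracemalloc.get_traced_memory()[1]
tracemalloc.stop()

print(f"lazy peak:  {lazy_peak / 1e6:.1f} MB")   # roughly one block
print(f"eager peak: {eager_peak / 1e6:.1f} MB")  # roughly twenty blocks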
cubed/primitive/blockwise.py (22 changes: 16 additions & 6 deletions)

@@ -488,10 +488,18 @@ def apply_pipeline_block_func(pipeline, n_input_blocks, arg):
     else:
         # more than one input block is being read from arg
         assert isinstance(arg, (list, Iterator))
-        return tuple(
-            list(item)
-            for item in zip(*(pipeline.config.block_function(a) for a in arg))
-        )
+        if isinstance(arg, list):
+            return tuple(
+                list(item)
+                for item in zip(*(pipeline.config.block_function(a) for a in arg))
+            )
+        else:
+            # Return iterators to avoid materializing all array blocks at
+            # once.
+            return tuple(
+                iter(list(item))
+                for item in zip(*(pipeline.config.block_function(a) for a in arg))
+            )

 def fused_blockwise_func(out_key):
     # this will change when multiple outputs are supported

@@ -512,8 +520,10 @@ def apply_pipeline_func(pipeline, n_input_blocks, *args):
     if n_input_blocks == 1:
         ret = pipeline.config.function(*args)
     else:
-        # more than one input block is being read from this group of args to primitive op
-        ret = [pipeline.config.function(*item) for item in list(zip(*args))]
+        # More than one input block is being read from this group of args to primitive op.
+        # Note that it is important that a list is not returned to avoid materializing all
+        # array blocks at once.
+        ret = map(lambda item: pipeline.config.function(*item), zip(*args))
     return ret

 def fused_func(*args):
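For multiple inputs, `zip(*args)` over iterators is itself lazy, so `map(..., zip(*args))` yields one tuple of corresponding blocks at a time; the old `[... for item in list(zip(*args))]` forced every tuple of blocks into memory up front. A small sketch of the lazy fan-out pattern (hypothetical inputs, not cubed's internals):

import numpy as np

def apply_lazily(function, *args):
    # pair corresponding blocks from each input, deferring the call
    # until the consumer asks for the next output block
    return map(lambda item: function(*item), zip(*args))

u_blocks = (np.full((2, 2), float(i)) for i in range(3))
v_blocks = (np.full((2, 2), 2.0) for _ in range(3))

for block in apply_lazily(np.multiply, u_blocks, v_blocks):
    print(block.sum())  # output blocks are produced pairwise, on demand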
cubed/tests/test_core.py (23 changes: 23 additions & 0 deletions)

@@ -1,5 +1,6 @@
 import platform
 import random
+from functools import partial

 import dill
 import numpy as np

@@ -578,3 +579,25 @@ def test_quad_means(tmp_path, t_length=50):
     res1 = zarr.open_array(tmp_path / "result1")

     assert_array_equal(res0[:], res1[:])
+
+
+def test_quad_means_zarr(tmp_path, t_length=50):
+    # write inputs to Zarr first to test a more realistic usage pattern
+    spec = cubed.Spec(tmp_path, allowed_mem="2GB", reserved_mem="100MB")
+    u = cubed.random.random((t_length, 1, 987, 1920), chunks=(10, 1, -1, -1), spec=spec)
+    v = cubed.random.random((t_length, 1, 987, 1920), chunks=(10, 1, -1, -1), spec=spec)
+
+    arrays = [u, v]
+    paths = [f"{tmp_path}/u_{t_length}.zarr", f"{tmp_path}/v_{t_length}.zarr"]
+    cubed.store(arrays, paths)
+
+    u = cubed.from_zarr(f"{tmp_path}/u_{t_length}.zarr", spec=spec)
+    v = cubed.from_zarr(f"{tmp_path}/v_{t_length}.zarr", spec=spec)
+    uv = u * v
+    m = xp.mean(uv, axis=0, use_new_impl=True, split_every=10)
+
+    opt_fn = partial(multiple_inputs_optimize_dag, max_total_num_input_blocks=40)
+
+    m.visualize(filename=tmp_path / "quad_means", optimize_function=opt_fn)
+
+    cubed.to_zarr(m, store=tmp_path / "result", optimize_function=opt_fn)
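A rough sizing note on this test: each chunk is 10 × 987 × 1920 float64 values, about 150 MB, so a fused task permitted to read up to 40 input blocks (max_total_num_input_blocks=40) would need around 6 GB if the blocks were materialized as a list, well over the 2 GB allowed_mem; streaming them through the reduce is what keeps the computation in budget.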
cubed/tests/test_optimization.py (54 changes: 53 additions & 1 deletion)

@@ -8,7 +8,7 @@
 import cubed
 import cubed.array_api as xp
 from cubed.backend_array_api import namespace as nxp
-from cubed.core.ops import elemwise, merge_chunks_new
+from cubed.core.ops import elemwise, merge_chunks_new, partial_reduce
 from cubed.core.optimization import (
     fuse_all_optimize_dag,
     fuse_only_optimize_dag,

@@ -830,6 +830,58 @@ def test_fuse_merge_chunks_binary(spec):
     assert_array_equal(result, 2 * np.ones((3, 2)))


+# like test_fuse_merge_chunks_unary, except uses partial_reduce
+def test_fuse_partial_reduce_unary(spec):
+    a = xp.ones((3, 2), chunks=(1, 2), spec=spec)
+    b = xp.negative(a)
+    c = partial_reduce(b, np.sum, split_every={0: 3})
+
+    # specify max_total_num_input_blocks to force c to fuse
+    opt_fn = fuse_multiple_levels(max_total_num_input_blocks=3)
+
+    c.visualize(optimize_function=opt_fn)
+
+    # check structure of optimized dag
+    expected_fused_dag = create_dag()
+    add_placeholder_op(expected_fused_dag, (), (a,))
+    add_placeholder_op(expected_fused_dag, (a,), (c,))
+    optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag
+    assert structurally_equivalent(optimized_dag, expected_fused_dag)
+    assert get_num_input_blocks(b.plan.dag, b.name) == (1,)
+    assert get_num_input_blocks(c.plan.dag, c.name) == (3,)
+    assert get_num_input_blocks(optimized_dag, c.name) == (3,)
+
+    result = c.compute(optimize_function=opt_fn)
+    assert_array_equal(result, -3 * np.ones((1, 2)))
+
+
+# like test_fuse_merge_chunks_binary, except uses partial_reduce
+def test_fuse_partial_reduce_binary(spec):
+    a = xp.ones((3, 2), chunks=(1, 2), spec=spec)
+    b = xp.ones((3, 2), chunks=(1, 2), spec=spec)
+    c = xp.add(a, b)
+    d = partial_reduce(c, np.sum, split_every={0: 3})
+
+    # specify max_total_num_input_blocks to force d to fuse
+    opt_fn = fuse_multiple_levels(max_total_num_input_blocks=6)
+
+    d.visualize(optimize_function=opt_fn)
+
+    # check structure of optimized dag
+    expected_fused_dag = create_dag()
+    add_placeholder_op(expected_fused_dag, (), (a,))
+    add_placeholder_op(expected_fused_dag, (), (b,))
+    add_placeholder_op(expected_fused_dag, (a, b), (d,))
+    optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag
+    assert structurally_equivalent(optimized_dag, expected_fused_dag)
+    assert get_num_input_blocks(c.plan.dag, c.name) == (1, 1)
+    assert get_num_input_blocks(d.plan.dag, d.name) == (3,)
+    assert get_num_input_blocks(optimized_dag, d.name) == (3, 3)
+
+    result = d.compute(optimize_function=opt_fn)
+    assert_array_equal(result, 6 * np.ones((1, 2)))
+
+
 def test_fuse_only_optimize_dag(spec):
     a = xp.ones((2, 2), chunks=(2, 2), spec=spec)
     b = xp.negative(a)
Expand Down
Loading