Fuse operations with different numbers of tasks (#368)
* Add num_input_blocks to primitive operation for bookkeeping
when fusing operations with different numbers of tasks

* Implement merge_chunks_new using general_blockwise

* Add failing test for merge_chunks_new

* Don't assert that num tasks must be the same in fuse_multiple

* Fuse primitive ops with different numbers of tasks

* Improve quad means test

* Move 'num_input_blocks' to blockwise spec

Add tests to check num_input_blocks

* Add another test, and improve code formatting
tomwhite authored Feb 5, 2024
1 parent 1762d3c commit 8831b94
Showing 7 changed files with 417 additions and 70 deletions.
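The bookkeeping added here can be summarized with plain arithmetic (illustrative values only, not cubed API): when an operation is fused with a predecessor, each fused task reads the predecessor's inputs once for every block of the predecessor's output it consumes, so the per-input block counts multiply. A minimal sketch, assuming made-up counts:

# op1: e.g. a chunk merge that reads 4 blocks per task
op1_num_input_blocks = (4,)
# op2: e.g. a partial reduce that reads 3 blocks of op1's output per task
op2_num_input_blocks = (3,)

# fusing op2 with op1 multiplies the counts (see fuse() in cubed/primitive/blockwise.py below)
fused_num_input_blocks = tuple(
    n * op2_num_input_blocks[0] for n in op1_num_input_blocks
)
assert fused_num_input_blocks == (12,)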
3 changes: 2 additions & 1 deletion cubed/array_api/statistical_functions.py
@@ -25,7 +25,7 @@ def max(x, /, *, axis=None, keepdims=False):
return reduction(x, nxp.max, axis=axis, dtype=x.dtype, keepdims=keepdims)


def mean(x, /, *, axis=None, keepdims=False, use_new_impl=False):
def mean(x, /, *, axis=None, keepdims=False, use_new_impl=False, split_every=None):
if x.dtype not in _real_floating_dtypes:
raise TypeError("Only real floating-point dtypes are allowed in mean")
# This implementation uses NumPy and Zarr's structured arrays to store a
@@ -47,6 +47,7 @@ def mean(x, /, *, axis=None, keepdims=False, use_new_impl=False):
dtype=dtype,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,
extra_func_kwargs=extra_func_kwargs,
)

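A hedged usage sketch of the new split_every keyword (array values illustrative, default spec assumed): it is simply forwarded to reduction(), bounding how many input blocks each partial-reduction task combines when use_new_impl=True.

import cubed.array_api as xp

a = xp.asarray([[1.0, 2.0], [3.0, 4.0]], chunks=(1, 2))
# each partial reduction along axis 0 combines at most 2 input blocks
m = xp.mean(a, axis=0, use_new_impl=True, split_every=2)
print(m.compute())  # [2. 3.]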
67 changes: 65 additions & 2 deletions cubed/core/ops.py
@@ -27,10 +27,16 @@
from cubed.primitive.blockwise import blockwise as primitive_blockwise
from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise
from cubed.primitive.rechunk import rechunk as primitive_rechunk
from cubed.utils import chunk_memory, get_item, offset_to_block_id, to_chunksize
from cubed.utils import (
_concatenate2,
chunk_memory,
get_item,
offset_to_block_id,
to_chunksize,
)
from cubed.vendor.dask.array.core import common_blockdim, normalize_chunks
from cubed.vendor.dask.array.utils import validate_axis
from cubed.vendor.dask.blockwise import broadcast_dimensions
from cubed.vendor.dask.blockwise import broadcast_dimensions, lol_product
from cubed.vendor.dask.utils import has_keyword

if TYPE_CHECKING:
@@ -266,6 +272,7 @@ def blockwise(
extra_projected_mem = kwargs.pop("extra_projected_mem", 0)

fusable = kwargs.pop("fusable", True)
num_input_blocks = kwargs.pop("num_input_blocks", None)

name = gensym()
spec = check_array_specs(arrays)
@@ -287,6 +294,7 @@
out_name=name,
extra_func_kwargs=extra_func_kwargs,
fusable=fusable,
num_input_blocks=num_input_blocks,
**kwargs,
)
plan = Plan._new(
@@ -324,6 +332,8 @@ def general_blockwise(

extra_projected_mem = kwargs.pop("extra_projected_mem", 0)

num_input_blocks = kwargs.pop("num_input_blocks", None)

name = gensym()
spec = check_array_specs(arrays)
if target_store is None:
@@ -341,6 +351,7 @@
chunks=chunks,
in_names=in_names,
extra_func_kwargs=extra_func_kwargs,
num_input_blocks=num_input_blocks,
**kwargs,
)
plan = Plan._new(
@@ -759,6 +770,7 @@ def rechunk(x, chunks, target_store=None):


def merge_chunks(x, chunks):
"""Merge multiple chunks into one."""
target_chunksize = chunks
if len(target_chunksize) != x.ndim:
raise ValueError(
@@ -787,6 +799,56 @@ def _copy_chunk(e, x, target_chunks=None, block_id=None):
return out


def merge_chunks_new(x, chunks):
# new implementation that uses general_blockwise rather than map_direct
target_chunksize = chunks
if len(target_chunksize) != x.ndim:
raise ValueError(
f"Chunks {target_chunksize} must have same number of dimensions as array ({x.ndim})"
)
if not all(c1 % c0 == 0 for c0, c1 in zip(x.chunksize, target_chunksize)):
raise ValueError(
f"Chunks {target_chunksize} must be a multiple of array's chunks {x.chunksize}"
)

target_chunks = normalize_chunks(chunks, x.shape, dtype=x.dtype)
axes = [
i for (i, (c0, c1)) in enumerate(zip(x.chunksize, target_chunksize)) if c0 != c1
]

def block_function(out_key):
out_coords = out_key[1:]

in_keys = []
for i, (c0, c1) in enumerate(zip(x.chunksize, target_chunksize)):
k = c1 // c0 # number of blocks to merge in axis i
if k == 1:
in_keys.append(out_coords[i])
else:
start = out_coords[i] * k
stop = min(start + k, x.numblocks[i])
in_keys.append(list(range(start, stop)))

# return a tuple with a single item that is the list of input keys to be merged
return (lol_product((x.name,), in_keys),)

num_input_blocks = int(
np.prod([c1 // c0 for (c0, c1) in zip(x.chunksize, target_chunksize)])
)

return general_blockwise(
_concatenate2,
block_function,
x,
shape=x.shape,
dtype=x.dtype,
chunks=target_chunks,
extra_projected_mem=0,
num_input_blocks=(num_input_blocks,),
axes=axes,
)


def reduction(
x: "Array",
func,
@@ -1059,6 +1121,7 @@ def block_function(out_key):
dtype=dtype,
chunks=chunks,
extra_projected_mem=extra_projected_mem,
num_input_blocks=(sum(split_every.values()),),
reduce_func=func,
initial_func=initial_func,
axis=axis,
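A hedged usage sketch of merge_chunks_new (array sizes illustrative, default spec assumed): merging (2, 2) chunks into (4, 4) chunks means each output block concatenates 2 * 2 = 4 input blocks, which is recorded as num_input_blocks=(4,) on the operation.

import cubed.array_api as xp
from cubed.core.ops import merge_chunks_new

a = xp.ones((8, 8), chunks=(2, 2))
b = merge_chunks_new(a, chunks=(4, 4))  # each output chunk is built from 4 input chunks
assert b.chunksize == (4, 4)
assert b.compute().shape == (8, 8)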
5 changes: 5 additions & 0 deletions cubed/core/plan.py
@@ -330,6 +330,11 @@ def visualize(

# remove pipeline attribute since it is a long string that causes graphviz to fail
if "pipeline" in d:
pipeline = d["pipeline"]
if pipeline.config is not None:
tooltip += (
f"\nnum input blocks: {pipeline.config.num_input_blocks}"
)
del d["pipeline"]

if "stack_summaries" in d and d["stack_summaries"] is not None:
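Because num_input_blocks now lives on the pipeline config, it can also be inspected programmatically rather than from the visualize() tooltip. A minimal sketch, assuming the plan of a cubed array m exposes its DAG as m.plan.dag and that operation nodes carry the same "pipeline" attribute used above:

# print each operation's num_input_blocks for a cubed array `m`
for name, d in m.plan.dag.nodes(data=True):
    pipeline = d.get("pipeline")
    if pipeline is not None and pipeline.config is not None:
        print(name, pipeline.config.num_input_blocks)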
93 changes: 74 additions & 19 deletions cubed/primitive/blockwise.py
@@ -1,5 +1,6 @@
import itertools
import math
from collections.abc import Iterator
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
@@ -45,6 +46,8 @@ class BlockwiseSpec:
A function that maps input chunks to an output chunk.
function_nargs: int
The number of array arguments that ``function`` takes.
num_input_blocks: Tuple[int, ...]
The number of input blocks read from each input array.
reads_map : Dict[str, CubedArrayProxy]
Read proxy dictionary keyed by array name.
write : CubedArrayProxy
@@ -54,6 +57,7 @@
block_function: Callable[..., Any]
function: Callable[..., Any]
function_nargs: int
num_input_blocks: Tuple[int, ...]
reads_map: Dict[str, CubedArrayProxy]
write: CubedArrayProxy

@@ -119,6 +123,7 @@ def blockwise(
extra_projected_mem: int = 0,
extra_func_kwargs: Optional[Dict[str, Any]] = None,
fusable: bool = True,
num_input_blocks: Optional[Tuple[int, ...]] = None,
**kwargs,
):
"""Apply a function to multiple blocks from multiple inputs, expressed using concise indexing rules.
@@ -201,6 +206,7 @@
extra_projected_mem=extra_projected_mem,
extra_func_kwargs=extra_func_kwargs,
fusable=fusable,
num_input_blocks=num_input_blocks,
**kwargs,
)

@@ -219,6 +225,7 @@ def general_blockwise(
extra_projected_mem: int = 0,
extra_func_kwargs: Optional[Dict[str, Any]] = None,
fusable: bool = True,
num_input_blocks: Optional[Tuple[int, ...]] = None,
**kwargs,
):
"""A more general form of ``blockwise`` that uses a function to specify the block
@@ -271,12 +278,18 @@

func_kwargs = extra_func_kwargs or {}
func_with_kwargs = partial(func, **{**kwargs, **func_kwargs})
num_input_blocks = num_input_blocks or (1,) * len(arrays)
read_proxies = {
name: CubedArrayProxy(array, array.chunks) for name, array in array_map.items()
}
write_proxy = CubedArrayProxy(target_array, chunksize)
spec = BlockwiseSpec(
block_function, func_with_kwargs, len(arrays), read_proxies, write_proxy
block_function,
func_with_kwargs,
len(arrays),
num_input_blocks,
read_proxies,
write_proxy,
)

# calculate projected memory
@@ -344,10 +357,17 @@ def can_fuse_multiple_primitive_ops(
if is_fuse_candidate(primitive_op) and all(
is_fuse_candidate(p) for p in predecessor_primitive_ops
):
# if the peak projected memory for running all the predecessor ops in order is
# larger than allowed_mem then we can't fuse
# If the peak projected memory for running all the predecessor ops in
# order is larger than allowed_mem then we can't fuse.
if peak_projected_mem(predecessor_primitive_ops) > primitive_op.allowed_mem:
return False
# If the number of input blocks for each input is not uniform, then we
# can't fuse. (This should never happen since all operations are
# currently uniform, and fused operations are too if fuse is applied in
# topological order.)
num_input_blocks = primitive_op.pipeline.config.num_input_blocks
if not all(num_input_blocks[0] == n for n in num_input_blocks):
return False
return all(
primitive_op.num_tasks == p.num_tasks for p in predecessor_primitive_ops
)
@@ -390,8 +410,17 @@ def fused_func(*args):
function_nargs = pipeline1.config.function_nargs
read_proxies = pipeline1.config.reads_map
write_proxy = pipeline2.config.write
num_input_blocks = tuple(
n * pipeline2.config.num_input_blocks[0]
for n in pipeline1.config.num_input_blocks
)
spec = BlockwiseSpec(
fused_blockwise_func, fused_func, function_nargs, read_proxies, write_proxy
fused_blockwise_func,
fused_func,
function_nargs,
num_input_blocks,
read_proxies,
write_proxy,
)

target_array = primitive_op2.target_array
@@ -424,12 +453,6 @@ def fuse_multiple(
Fuse a blockwise operation and its predecessors into a single operation, avoiding writing to (or reading from) the targets of the predecessor operations.
"""

assert all(
primitive_op.num_tasks == p.num_tasks
for p in predecessor_primitive_ops
if p is not None
)

pipeline = primitive_op.pipeline
predecessor_pipelines = [
primitive_op.pipeline if primitive_op is not None else None
@@ -444,42 +467,74 @@

mappable = pipeline.mappable

def apply_pipeline_block_func(pipeline, arg):
def apply_pipeline_block_func(pipeline, n_input_blocks, arg):
if pipeline is None:
return (arg,)
return pipeline.config.block_function(arg)
if n_input_blocks == 1:
assert isinstance(arg, tuple)
return pipeline.config.block_function(arg)
else:
# more than one input block is being read from arg
assert isinstance(arg, (list, Iterator))
return tuple(
list(item)
for item in zip(*(pipeline.config.block_function(a) for a in arg))
)

def fused_blockwise_func(out_key):
# this will change when multiple outputs are supported
args = pipeline.config.block_function(out_key)
# split all args to the fused function into groups, one for each predecessor function
func_args = tuple(
item
for p, a in zip(predecessor_pipelines, args)
for item in apply_pipeline_block_func(p, a)
for i, (p, a) in enumerate(zip(predecessor_pipelines, args))
for item in apply_pipeline_block_func(
p, pipeline.config.num_input_blocks[i], a
)
)
return split_into(func_args, predecessor_funcs_nargs)

def apply_pipeline_func(pipeline, *args):
def apply_pipeline_func(pipeline, n_input_blocks, *args):
if pipeline is None:
return args[0]
return pipeline.config.function(*args)
if n_input_blocks == 1:
ret = pipeline.config.function(*args)
else:
# more than one input block is being read from this group of args to primitive op
ret = [pipeline.config.function(*item) for item in list(zip(*args))]
return ret

def fused_func(*args):
# args are grouped appropriately so they can be called by each predecessor function
func_args = [
apply_pipeline_func(p, *a) for p, a in zip(predecessor_pipelines, args)
apply_pipeline_func(p, pipeline.config.num_input_blocks[i], *a)
for i, (p, a) in enumerate(zip(predecessor_pipelines, args))
]
return pipeline.config.function(*func_args)

function_nargs = pipeline.config.function_nargs
fused_function_nargs = pipeline.config.function_nargs
# ok to get num_input_blocks[0] since it is uniform (see check in can_fuse_multiple_primitive_ops)
fused_num_input_blocks = tuple(
pipeline.config.num_input_blocks[0] * n
for n in itertools.chain(
*(
p.pipeline.config.num_input_blocks if p is not None else (1,)
for p in predecessor_primitive_ops
)
)
)
read_proxies = dict(pipeline.config.reads_map)
for p in predecessor_pipelines:
if p is not None:
read_proxies.update(p.config.reads_map)
write_proxy = pipeline.config.write
spec = BlockwiseSpec(
fused_blockwise_func, fused_func, function_nargs, read_proxies, write_proxy
fused_blockwise_func,
fused_func,
fused_function_nargs,
fused_num_input_blocks,
read_proxies,
write_proxy,
)

target_array = primitive_op.target_array
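The regrouping performed by apply_pipeline_block_func when a successor reads several blocks of a predecessor's output can be shown with plain Python (hypothetical block keys, not cubed API): the predecessor's block function is applied to each consumed output key and the results are transposed so the fused function receives one list of keys per input array.

def predecessor_block_function(out_key):
    # hypothetical: each output block of the predecessor reads one block from
    # each of its two inputs, "a" and "b", at the same block coordinate
    _, i = out_key
    return (("a", i), ("b", i))

# the successor reads 3 blocks of the predecessor's output "p"
successor_reads = [("p", 0), ("p", 1), ("p", 2)]
grouped = tuple(
    list(item)
    for item in zip(*(predecessor_block_function(k) for k in successor_reads))
)
assert grouped == ([("a", 0), ("a", 1), ("a", 2)], [("b", 0), ("b", 1), ("b", 2)])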
15 changes: 12 additions & 3 deletions cubed/tests/test_core.py
@@ -12,7 +12,7 @@
import cubed.random
from cubed.backend_array_api import namespace as nxp
from cubed.core.ops import merge_chunks, partial_reduce, tree_reduce
from cubed.core.optimization import fuse_all_optimize_dag
from cubed.core.optimization import fuse_all_optimize_dag, multiple_inputs_optimize_dag
from cubed.tests.utils import (
ALL_EXECUTORS,
MAIN_EXECUTORS,
@@ -531,10 +531,19 @@ def test_plan_quad_means(tmp_path, t_length):
u = cubed.random.random((t_length, 1, 987, 1920), chunks=(10, 1, -1, -1), spec=spec)
v = cubed.random.random((t_length, 1, 987, 1920), chunks=(10, 1, -1, -1), spec=spec)
uv = u * v
m = xp.mean(uv, axis=0)
m = xp.mean(uv, axis=0, split_every=10, use_new_impl=True)

assert m.plan.num_tasks() > 0
m.visualize(filename=tmp_path / "quad_means")
m.visualize(
filename=tmp_path / "quad_means_unoptimized",
optimize_graph=False,
show_hidden=True,
)
m.visualize(
filename=tmp_path / "quad_means",
optimize_function=multiple_inputs_optimize_dag,
show_hidden=True,
)


def quad_means(tmp_path, t_length):