Multiple outputs #419

Merged 6 commits on Sep 16, 2024
12 changes: 6 additions & 6 deletions cubed/array_api/manipulation_functions.py
@@ -321,9 +321,9 @@ def key_function(out_key):
         key_function,
         x,
         template,
-        shape=shape,
-        dtype=x.dtype,
-        chunks=outchunks,
+        shapes=[shape],
+        dtypes=[x.dtype],
+        chunkss=[outchunks],
     )


@@ -402,9 +402,9 @@ def key_function(out_key):
         _read_stack_chunk,
         key_function,
         *arrays,
-        shape=shape,
-        dtype=dtype,
-        chunks=chunks,
+        shapes=[shape],
+        dtypes=[dtype],
+        chunkss=[chunks],
         axis=axis,
         fusable=False,
     )
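Both hunks in this file are the same mechanical migration: each single-output call site wraps its shape, dtype, and chunks in one-element lists for the new plural keyword arguments. A minimal before/after sketch of the calling convention (names taken from the first hunk):

# before: singular keywords, one value each
general_blockwise(func, key_function, x, shape=shape, dtype=x.dtype, chunks=outchunks)

# after: plural keywords take lists, one element per output
general_blockwise(func, key_function, x, shapes=[shape], dtypes=[x.dtype], chunkss=[outchunks])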
60 changes: 36 additions & 24 deletions cubed/core/ops.py
@@ -5,7 +5,7 @@
 from itertools import product
 from numbers import Integral, Number
 from operator import add
-from typing import TYPE_CHECKING, Any, Sequence, Union
+from typing import TYPE_CHECKING, Any, Sequence, Tuple, Union
 from warnings import warn

 import ndindex
@@ -333,14 +333,14 @@ def general_blockwise(
     func,
     key_function,
     *arrays,
-    shape,
-    dtype,
-    chunks,
-    target_store=None,
-    target_path=None,
+    shapes,
+    dtypes,
+    chunkss,
+    target_stores=None,
+    target_paths=None,
     extra_func_kwargs=None,
     **kwargs,
-) -> "Array":
+) -> Union["Array", Tuple["Array", ...]]:
     assert len(arrays) > 0

     # replace arrays with zarr arrays
@@ -354,24 +354,33 @@
     num_input_blocks = kwargs.pop("num_input_blocks", None)

-    name = gensym()
     spec = check_array_specs(arrays)
-    if target_store is None:
-        target_store = new_temp_path(name=name, spec=spec)
+    if isinstance(target_stores, list):  # multiple outputs
+        name = [gensym() for _ in range(len(target_stores))]
+        target_stores = [
+            ts if ts is not None else new_temp_path(name=n, spec=spec)
+            for n, ts in zip(name, target_stores)
+        ]
+    else:  # single output
+        name = gensym()
+        if target_stores is None:
+            target_stores = [new_temp_path(name=name, spec=spec)]

     op = primitive_general_blockwise(
         func,
         key_function,
         *zargs,
         allowed_mem=spec.allowed_mem,
         reserved_mem=spec.reserved_mem,
         extra_projected_mem=extra_projected_mem,
-        target_store=target_store,
-        target_path=target_path,
+        target_stores=target_stores,
+        target_paths=target_paths,
         storage_options=spec.storage_options,
         compressor=spec.zarr_compressor,
-        shape=shape,
-        dtype=dtype,
-        chunks=chunks,
+        shapes=shapes,
+        dtypes=dtypes,
+        chunkss=chunkss,
         in_names=in_names,
         extra_func_kwargs=extra_func_kwargs,
         num_input_blocks=num_input_blocks,
@@ -387,7 +396,10 @@
     )
     from cubed.array_api import Array

-    return Array(name, op.target_array, spec, plan)
+    if isinstance(op.target_array, list):  # multiple outputs
+        return tuple(Array(n, ta, spec, plan) for n, ta in zip(name, op.target_array))
+    else:  # single output
+        return Array(name, op.target_array, spec, plan)


 def elemwise(func, *args: "Array", dtype=None) -> "Array":
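Taken together, these changes let one blockwise operation produce several arrays: passing target_stores as a list selects the multiple-outputs path, and the return value becomes a tuple of Arrays. A minimal sketch of the new calling convention; the block function _square_and_cube and the key_function contract shown here (map an output key to a tuple of input keys) are illustrative assumptions, not part of this diff:

import cubed.array_api as xp
from cubed.core.ops import general_blockwise

x = xp.ones((4, 4), chunks=(2, 2))

def _square_and_cube(a):
    # one result block per output array
    return a**2, a**3

def key_function(out_key):
    # each output block depends only on the matching input block
    return ((x.name,) + out_key[1:],)

squares, cubes = general_blockwise(
    _square_and_cube,
    key_function,
    x,
    shapes=[x.shape, x.shape],
    dtypes=[x.dtype, x.dtype],
    chunkss=[x.chunks, x.chunks],
    target_stores=[None, None],  # a list here triggers the multiple-outputs branch
)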
@@ -914,9 +926,9 @@ def key_function(out_key):
         _concatenate2,
         key_function,
         x,
-        shape=x.shape,
-        dtype=x.dtype,
-        chunks=target_chunks,
+        shapes=[x.shape],
+        dtypes=[x.dtype],
+        chunkss=[target_chunks],
         extra_projected_mem=0,
         num_input_blocks=(num_input_blocks,),
         axes=axes,
@@ -1229,12 +1241,12 @@ def partial_reduce(
     axis = tuple(ax for ax in split_every.keys())
     combine_sizes = combine_sizes or {}
     combine_sizes = {k: combine_sizes.get(k, 1) for k in axis}
-    chunks = [
+    chunks = tuple(
         (combine_sizes[i],) * math.ceil(len(c) / split_every[i])
         if i in split_every
         else c
         for (i, c) in enumerate(x.chunks)
-    ]
+    )
     shape = tuple(map(sum, chunks))

     def key_function(out_key):
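Besides the keyword rename below, chunks is now built as a tuple rather than a list, matching the immutable tuple-of-tuples chunks form used elsewhere. A worked example of what the comprehension computes (values hypothetical):

# suppose x.chunks == ((2, 2, 2, 2), (3, 3)) and split_every == {0: 2};
# combine_sizes defaults to 1 for axis 0:
#   axis 0: (1,) * ceil(4 / 2) -> (1, 1)   combined
#   axis 1: (3, 3)                         not in split_every, unchanged
chunks = ((1, 1), (3, 3))
shape = tuple(map(sum, chunks))  # (2, 6)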
@@ -1263,9 +1275,9 @@ def key_function(out_key):
         _partial_reduce,
         key_function,
         x,
-        shape=shape,
-        dtype=dtype,
-        chunks=chunks,
+        shapes=[shape],
+        dtypes=[dtype],
+        chunkss=[chunks],
         extra_projected_mem=extra_projected_mem,
         num_input_blocks=(sum(split_every.values()),),
         reduce_func=func,
21 changes: 19 additions & 2 deletions cubed/core/optimization.py
@@ -31,9 +31,9 @@ def can_fuse(n):
         if "primitive_op" not in nodes[op2]:
             return False

-        # if node (op2) does not have exactly one input then don't fuse
+        # if node (op2) does not have exactly one input and one output then don't fuse
         # (it could have no inputs or multiple inputs)
-        if dag.in_degree(op2) != 1:
+        if dag.in_degree(op2) != 1 or dag.out_degree(op2) != 1:
             return False

         # if input is one of the arrays being computed then don't fuse
@@ -91,6 +91,12 @@ def predecessors_unordered(dag, name):
         yield pre


+def successors_unordered(dag, name):
+    """Return a node's successors in no particular order, with repeats for multiple edges."""
+    for _, succ in dag.out_edges(name):
+        yield succ
+
+
 def predecessor_ops(dag, name):
     """Return an op node's op predecessors in the same order as the input source arrays for the op.

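A quick check of the new helper, assuming the plan DAG is a networkx MultiDiGraph (which is why repeated edges to the same successor are reported once per edge):

import networkx as nx

dag = nx.MultiDiGraph()
dag.add_edge("op", "a")
dag.add_edge("op", "b")
dag.add_edge("op", "b")  # a second edge to the same successor

print(sorted(successors_unordered(dag, "op")))  # ['a', 'b', 'b']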
@@ -183,6 +189,17 @@ def can_fuse_predecessors(
         )
         return False

+    # if any predecessor ops have multiple outputs then don't fuse
+    # TODO: implement "child fusion" (where a multiple-output op fuses its children)
+    if any(
+        len(list(successors_unordered(dag, pre))) > 1
+        for pre in predecessor_ops(dag, name)
+    ):
+        logger.debug(
+            "can't fuse %s since at least one predecessor has multiple outputs", name
+        )
+        return False
+
     # if node is in never_fuse or always_fuse list then it overrides logic below
     if never_fuse is not None and name in never_fuse:
         logger.debug("can't fuse %s since it is in 'never_fuse'", name)
60 changes: 41 additions & 19 deletions cubed/core/plan.py
@@ -73,7 +73,7 @@ class Plan:
     def __init__(self, dag):
         self.dag = dag

-    # args from pipeline onwards are omitted for creation functions when no computation is needed
+    # args from primitive_op onwards are omitted for creation functions when no computation is needed
     @classmethod
     def _new(
         cls,
@@ -110,15 +110,26 @@ def _new(
                 op_display_name=f"{op_name_unique}\n{first_cubed_summary.name}",
                 hidden=hidden,
             )
-            # array (when multiple outputs are supported there could be more than one)
-            dag.add_node(
-                name,
-                name=name,
-                type="array",
-                target=target,
-                hidden=hidden,
-            )
-            dag.add_edge(op_name_unique, name)
+            # array
+            if isinstance(name, list):  # multiple outputs
+                for n, t in zip(name, target):
+                    dag.add_node(
+                        n,
+                        name=n,
+                        type="array",
+                        target=t,
+                        hidden=hidden,
+                    )
+                    dag.add_edge(op_name_unique, n)
+            else:  # single output
+                dag.add_node(
+                    name,
+                    name=name,
+                    type="array",
+                    target=target,
+                    hidden=hidden,
+                )
+                dag.add_edge(op_name_unique, name)
         else:
             # op
             dag.add_node(
@@ -132,15 +143,26 @@
                 primitive_op=primitive_op,
                 pipeline=primitive_op.pipeline,
             )
-            # array (when multiple outputs are supported there could be more than one)
-            dag.add_node(
-                name,
-                name=name,
-                type="array",
-                target=target,
-                hidden=hidden,
-            )
-            dag.add_edge(op_name_unique, name)
+            # array
+            if isinstance(name, list):  # multiple outputs
+                for n, t in zip(name, target):
+                    dag.add_node(
+                        n,
+                        name=n,
+                        type="array",
+                        target=t,
+                        hidden=hidden,
+                    )
+                    dag.add_edge(op_name_unique, n)
+            else:  # single output
+                dag.add_node(
+                    name,
+                    name=name,
+                    type="array",
+                    target=target,
+                    hidden=hidden,
+                )
+                dag.add_edge(op_name_unique, name)
         for x in source_arrays:
             if hasattr(x, "name"):
                 dag.add_edge(x.name, op_name_unique)
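With multiple outputs, _new now fans one op node out to several array nodes, which is exactly the out_degree > 1 shape that optimization.py declines to fuse. A sketch of the resulting graph built with networkx directly (node names and targets hypothetical):

import networkx as nx

dag = nx.MultiDiGraph()
op_name_unique = "op-001"
dag.add_node(op_name_unique, type="op")
for n, t in zip(["out0", "out1"], ["tmp/out0.zarr", "tmp/out1.zarr"]):
    dag.add_node(n, name=n, type="array", target=t)
    dag.add_edge(op_name_unique, n)

assert dag.out_degree(op_name_unique) == 2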