Skip to content

Commit

Permalink
Merge pull request #556 from dask-contrib/dask-task-usage
Browse files (browse the repository at this point in the history)
fix: adapt to new Task spec in dask, now used in blockwise
  • Loading branch information
lgray authored Dec 16, 2024
2 parents 1d4d4e9 + 5d01fef commit d3f3e7c
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 28 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ classifiers = [
]
dependencies = [
"awkward >=2.5.1",
"dask >=2023.04.0",
"dask >=2024.12.0;python_version>'3.9'",
"dask >=2023.04.0;python_version<'3.10'",
"cachetools",
"typing_extensions >=4.8.0",
]
Expand Down
2 changes: 2 additions & 0 deletions src/dask_awkward/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
ImplementsIOFunction,
ImplementsProjection,
IOFunctionWithMocking,
_dask_uses_tasks,
io_func_implements_projection,
)

Expand All @@ -18,4 +19,5 @@
"ImplementsIOFunction",
"IOFunctionWithMocking",
"io_func_implements_projection",
"_dask_uses_tasks",
)
29 changes: 21 additions & 8 deletions src/dask_awkward/layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,20 @@
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, Union, cast

import dask

_dask_uses_tasks = hasattr(dask, "_task_spec")

from dask.blockwise import Blockwise, BlockwiseDepDict, blockwise_token
from dask.highlevelgraph import MaterializedLayer
from dask.layers import DataFrameTreeReduction
from typing_extensions import TypeAlias

from dask_awkward.utils import LazyInputsDict

if _dask_uses_tasks:
from dask._task_spec import Task, TaskRef

if TYPE_CHECKING:
from awkward import Array as AwkwardArray
from awkward._nplikes.typetracer import TypeTracerReport
Expand Down Expand Up @@ -160,14 +167,20 @@ def __init__(
produces_tasks=self.produces_tasks,
)

super().__init__(
output=self.name,
output_indices="i",
dsk={name: (self.io_func, blockwise_token(0))},
indices=[(io_arg_map, "i")],
numblocks={},
annotations=None,
)
super_kwargs: dict[str, Any] = {
"output": self.name,
"output_indices": "i",
"indices": [(io_arg_map, "i")],
"numblocks": {},
"annotations": None,
}

if _dask_uses_tasks:
super_kwargs["task"] = Task(name, self.io_func, TaskRef(blockwise_token(0)))
else:
super_kwargs["dsk"] = {name: (self.io_func, blockwise_token(0))}

super().__init__(**super_kwargs)

def __repr__(self) -> str:
return f"AwkwardInputLayer<{self.output}>"
Expand Down
14 changes: 12 additions & 2 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@
from dask.utils import OperatorMethodMixin as DaskOperatorMethodMixin
from dask.utils import funcname, is_arraylike, key_split

from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardMaterializedLayer
from dask_awkward.layers import (
AwkwardBlockwiseLayer,
AwkwardMaterializedLayer,
_dask_uses_tasks,
)
from dask_awkward.lib.optimize import all_optimizations
from dask_awkward.utils import (
ConcretizationTypeError,
Expand All @@ -57,6 +61,9 @@
is_empty_slice,
)

if _dask_uses_tasks:
from dask._task_spec import TaskRef

if TYPE_CHECKING:
from awkward.contents.content import Content
from awkward.forms.form import Form
Expand Down Expand Up @@ -1928,7 +1935,10 @@ def partitionwise_layer(
pairs.extend([arg.name, "i"])
numblocks[arg.name] = (1,)
elif isinstance(arg, Delayed):
pairs.extend([arg.key, None])
if _dask_uses_tasks:
pairs.extend([TaskRef(arg.key), None])
else:
pairs.extend([arg.key, None])
elif is_dask_collection(arg):
raise DaskAwkwardNotImplemented(
"Use of Array with other Dask collections is currently unsupported."
Expand Down
70 changes: 55 additions & 15 deletions src/dask_awkward/lib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,17 @@
from dask.highlevelgraph import HighLevelGraph
from dask.local import get_sync

from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardInputLayer
from dask_awkward.layers import (
AwkwardBlockwiseLayer,
AwkwardInputLayer,
_dask_uses_tasks,
)
from dask_awkward.lib.utils import typetracer_nochecks
from dask_awkward.utils import first

if _dask_uses_tasks:
from dask._task_spec import GraphNode, Task, TaskRef

if TYPE_CHECKING:
from awkward._nplikes.typetracer import TypeTracerReport
from dask.typing import Key
Expand Down Expand Up @@ -234,14 +241,23 @@ def _touch_all_data(*args, **kwargs):

def _mock_output(layer):
"""Update a layer to run the _touch_all_data."""
assert len(layer.dsk) == 1
if _dask_uses_tasks:
new_layer = copy.deepcopy(layer)
task = new_layer.task.copy()
# replace the original function with _touch_all_data
# and keep the rest of the task the same
task.func = _touch_all_data
new_layer.task = task
return new_layer
else:
assert len(layer.dsk) == 1

new_layer = copy.deepcopy(layer)
mp = new_layer.dsk.copy()
for k in iter(mp.keys()):
mp[k] = (_touch_all_data,) + mp[k][1:]
new_layer.dsk = mp
return new_layer
new_layer = copy.deepcopy(layer)
mp = new_layer.dsk.copy()
for k in iter(mp.keys()):
mp[k] = (_touch_all_data,) + mp[k][1:]
new_layer.dsk = mp
return new_layer


@no_type_check
Expand Down Expand Up @@ -340,7 +356,10 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG
deps[outkey] = deps[chain[0]]
[deps.pop(ch) for ch in chain[:-1]]

subgraph = layer0.dsk.copy() # mypy: ignore
if _dask_uses_tasks:
all_tasks = [layer0.task]
else:
subgraph = layer0.dsk.copy()
indices = list(layer0.indices)
parent = chain[0]

Expand All @@ -349,14 +368,28 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG
layer = dsk.layers[chain_member]
for k in layer.io_deps: # mypy: ignore
outlayer.io_deps[k] = layer.io_deps[k]
func, *args = layer.dsk[chain_member] # mypy: ignore
args2 = _recursive_replace(args, layer, parent, indices)
subgraph[chain_member] = (func,) + tuple(args2)

if _dask_uses_tasks:
func = layer.task.func
args = [
arg.key if isinstance(arg, GraphNode) else arg
for arg in layer.task.args
]
# how to do this with `.substitute(...)`?
args2 = _recursive_replace(args, layer, parent, indices)
all_tasks.append(Task(chain_member, func, *args2))
else:
func, *args = layer.dsk[chain_member] # mypy: ignore
args2 = _recursive_replace(args, layer, parent, indices)
subgraph[chain_member] = (func,) + tuple(args2)
parent = chain_member
outlayer.numblocks = {
i[0]: (numblocks,) for i in indices if i[1] is not None
} # mypy: ignore
outlayer.dsk = subgraph # mypy: ignore
if _dask_uses_tasks:
outlayer.task = Task.fuse(*all_tasks)
else:
outlayer.dsk = subgraph # mypy: ignore
if hasattr(outlayer, "_dims"):
del outlayer._dims
outlayer.indices = tuple( # mypy: ignore
Expand All @@ -379,11 +412,18 @@ def _recursive_replace(args, layer, parent, indices):
args2.append(layer.indices[ind][0])
elif layer.indices[ind][0] == parent:
# arg refers to output of previous layer
args2.append(parent)
if _dask_uses_tasks:
args2.append(TaskRef(parent))
else:
args2.append(parent)
else:
# arg refers to things defined in io_deps
indices.append(layer.indices[ind])
args2.append(f"__dask_blockwise__{len(indices) - 1}")
arg2 = f"__dask_blockwise__{len(indices) - 1}"
if _dask_uses_tasks:
args2.append(TaskRef(arg2))
else:
args2.append(arg2)
elif isinstance(arg, list):
args2.append(_recursive_replace(arg, layer, parent, indices))
elif isinstance(arg, tuple):
Expand Down
8 changes: 6 additions & 2 deletions tests/test_io_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest

import dask_awkward as dak
from dask_awkward.layers import _dask_uses_tasks
from dask_awkward.lib.core import Array
from dask_awkward.lib.optimize import optimize as dak_optimize
from dask_awkward.lib.testutils import assert_eq
Expand Down Expand Up @@ -94,8 +95,11 @@ def input_layer_array_partition0(collection: Array) -> ak.Array:
optimized_hlg = dak_optimize(collection.dask, collection.keys) # type: ignore
layers = list(optimized_hlg.layers) # type: ignore
layer_name = [name for name in layers if name.startswith("from-json")][0]
sgc, arg = optimized_hlg[(layer_name, 0)]
array = sgc.dsk[layer_name][0](arg)
if _dask_uses_tasks:
array = optimized_hlg[(layer_name, 0)]()
else:
sgc, arg = optimized_hlg[(layer_name, 0)]
array = sgc.dsk[layer_name][0](arg)
return array


Expand Down

0 comments on commit d3f3e7c

Please sign in to comment.