diff --git a/pyproject.toml b/pyproject.toml index d5514159..2e4adb1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,8 @@ classifiers = [ ] dependencies = [ "awkward >=2.5.1", - "dask >=2023.04.0", + "dask >=2024.12.0;python_version>'3.9'", + "dask >=2023.04.0;python_version<'3.10'", "cachetools", "typing_extensions >=4.8.0", ] diff --git a/src/dask_awkward/layers/__init__.py b/src/dask_awkward/layers/__init__.py index d4ba4c5e..098bbf2e 100644 --- a/src/dask_awkward/layers/__init__.py +++ b/src/dask_awkward/layers/__init__.py @@ -6,6 +6,7 @@ ImplementsIOFunction, ImplementsProjection, IOFunctionWithMocking, + _dask_uses_tasks, io_func_implements_projection, ) @@ -18,4 +19,5 @@ "ImplementsIOFunction", "IOFunctionWithMocking", "io_func_implements_projection", + "_dask_uses_tasks", ) diff --git a/src/dask_awkward/layers/layers.py b/src/dask_awkward/layers/layers.py index 92441443..b5c54523 100644 --- a/src/dask_awkward/layers/layers.py +++ b/src/dask_awkward/layers/layers.py @@ -4,6 +4,10 @@ from collections.abc import Callable, Mapping from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, Union, cast +import dask + +_dask_uses_tasks = hasattr(dask, "_task_spec") + from dask.blockwise import Blockwise, BlockwiseDepDict, blockwise_token from dask.highlevelgraph import MaterializedLayer from dask.layers import DataFrameTreeReduction @@ -11,6 +15,9 @@ from dask_awkward.utils import LazyInputsDict +if _dask_uses_tasks: + from dask._task_spec import Task, TaskRef + if TYPE_CHECKING: from awkward import Array as AwkwardArray from awkward._nplikes.typetracer import TypeTracerReport @@ -160,14 +167,20 @@ def __init__( produces_tasks=self.produces_tasks, ) - super().__init__( - output=self.name, - output_indices="i", - dsk={name: (self.io_func, blockwise_token(0))}, - indices=[(io_arg_map, "i")], - numblocks={}, - annotations=None, - ) + super_kwargs: dict[str, Any] = { + "output": self.name, + "output_indices": "i", + "indices": [(io_arg_map, "i")], + "numblocks": {}, + "annotations": None, + } + + if _dask_uses_tasks: + super_kwargs["task"] = Task(name, self.io_func, TaskRef(blockwise_token(0))) + else: + super_kwargs["dsk"] = {name: (self.io_func, blockwise_token(0))} + + super().__init__(**super_kwargs) def __repr__(self) -> str: return f"AwkwardInputLayer<{self.output}>" diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index d7c1a4e0..786ff2bb 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -45,7 +45,11 @@ from dask.utils import OperatorMethodMixin as DaskOperatorMethodMixin from dask.utils import funcname, is_arraylike, key_split -from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardMaterializedLayer +from dask_awkward.layers import ( + AwkwardBlockwiseLayer, + AwkwardMaterializedLayer, + _dask_uses_tasks, +) from dask_awkward.lib.optimize import all_optimizations from dask_awkward.utils import ( ConcretizationTypeError, @@ -57,6 +61,9 @@ is_empty_slice, ) +if _dask_uses_tasks: + from dask._task_spec import TaskRef + if TYPE_CHECKING: from awkward.contents.content import Content from awkward.forms.form import Form @@ -1928,7 +1935,10 @@ def partitionwise_layer( pairs.extend([arg.name, "i"]) numblocks[arg.name] = (1,) elif isinstance(arg, Delayed): - pairs.extend([arg.key, None]) + if _dask_uses_tasks: + pairs.extend([TaskRef(arg.key), None]) + else: + pairs.extend([arg.key, None]) elif is_dask_collection(arg): raise DaskAwkwardNotImplemented( "Use of Array with other Dask collections is currently unsupported." diff --git a/src/dask_awkward/lib/optimize.py b/src/dask_awkward/lib/optimize.py index 6ad2e132..4b9dd6cf 100644 --- a/src/dask_awkward/lib/optimize.py +++ b/src/dask_awkward/lib/optimize.py @@ -13,10 +13,17 @@ from dask.highlevelgraph import HighLevelGraph from dask.local import get_sync -from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardInputLayer +from dask_awkward.layers import ( + AwkwardBlockwiseLayer, + AwkwardInputLayer, + _dask_uses_tasks, +) from dask_awkward.lib.utils import typetracer_nochecks from dask_awkward.utils import first +if _dask_uses_tasks: + from dask._task_spec import GraphNode, Task, TaskRef + if TYPE_CHECKING: from awkward._nplikes.typetracer import TypeTracerReport from dask.typing import Key @@ -234,14 +241,23 @@ def _touch_all_data(*args, **kwargs): def _mock_output(layer): """Update a layer to run the _touch_all_data.""" - assert len(layer.dsk) == 1 + if _dask_uses_tasks: + new_layer = copy.deepcopy(layer) + task = new_layer.task.copy() + # replace the original function with _touch_all_data + # and keep the rest of the task the same + task.func = _touch_all_data + new_layer.task = task + return new_layer + else: + assert len(layer.dsk) == 1 - new_layer = copy.deepcopy(layer) - mp = new_layer.dsk.copy() - for k in iter(mp.keys()): - mp[k] = (_touch_all_data,) + mp[k][1:] - new_layer.dsk = mp - return new_layer + new_layer = copy.deepcopy(layer) + mp = new_layer.dsk.copy() + for k in iter(mp.keys()): + mp[k] = (_touch_all_data,) + mp[k][1:] + new_layer.dsk = mp + return new_layer @no_type_check @@ -340,7 +356,10 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG deps[outkey] = deps[chain[0]] [deps.pop(ch) for ch in chain[:-1]] - subgraph = layer0.dsk.copy() # mypy: ignore + if _dask_uses_tasks: + all_tasks = [layer0.task] + else: + subgraph = layer0.dsk.copy() indices = list(layer0.indices) parent = chain[0] @@ -349,14 +368,28 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG layer = dsk.layers[chain_member] for k in layer.io_deps: # mypy: ignore outlayer.io_deps[k] = layer.io_deps[k] - func, *args = layer.dsk[chain_member] # mypy: ignore - args2 = _recursive_replace(args, layer, parent, indices) - subgraph[chain_member] = (func,) + tuple(args2) + + if _dask_uses_tasks: + func = layer.task.func + args = [ + arg.key if isinstance(arg, GraphNode) else arg + for arg in layer.task.args + ] + # how to do this with `.substitute(...)`? + args2 = _recursive_replace(args, layer, parent, indices) + all_tasks.append(Task(chain_member, func, *args2)) + else: + func, *args = layer.dsk[chain_member] # mypy: ignore + args2 = _recursive_replace(args, layer, parent, indices) + subgraph[chain_member] = (func,) + tuple(args2) parent = chain_member outlayer.numblocks = { i[0]: (numblocks,) for i in indices if i[1] is not None } # mypy: ignore - outlayer.dsk = subgraph # mypy: ignore + if _dask_uses_tasks: + outlayer.task = Task.fuse(*all_tasks) + else: + outlayer.dsk = subgraph # mypy: ignore if hasattr(outlayer, "_dims"): del outlayer._dims outlayer.indices = tuple( # mypy: ignore @@ -379,11 +412,18 @@ def _recursive_replace(args, layer, parent, indices): args2.append(layer.indices[ind][0]) elif layer.indices[ind][0] == parent: # arg refers to output of previous layer - args2.append(parent) + if _dask_uses_tasks: + args2.append(TaskRef(parent)) + else: + args2.append(parent) else: # arg refers to things defined in io_deps indices.append(layer.indices[ind]) - args2.append(f"__dask_blockwise__{len(indices) - 1}") + arg2 = f"__dask_blockwise__{len(indices) - 1}" + if _dask_uses_tasks: + args2.append(TaskRef(arg2)) + else: + args2.append(arg2) elif isinstance(arg, list): args2.append(_recursive_replace(arg, layer, parent, indices)) elif isinstance(arg, tuple): diff --git a/tests/test_io_json.py b/tests/test_io_json.py index 688fb550..f7192aae 100644 --- a/tests/test_io_json.py +++ b/tests/test_io_json.py @@ -10,6 +10,7 @@ import pytest import dask_awkward as dak +from dask_awkward.layers import _dask_uses_tasks from dask_awkward.lib.core import Array from dask_awkward.lib.optimize import optimize as dak_optimize from dask_awkward.lib.testutils import assert_eq @@ -94,8 +95,11 @@ def input_layer_array_partition0(collection: Array) -> ak.Array: optimized_hlg = dak_optimize(collection.dask, collection.keys) # type: ignore layers = list(optimized_hlg.layers) # type: ignore layer_name = [name for name in layers if name.startswith("from-json")][0] - sgc, arg = optimized_hlg[(layer_name, 0)] - array = sgc.dsk[layer_name][0](arg) + if _dask_uses_tasks: + array = optimized_hlg[(layer_name, 0)]() + else: + sgc, arg = optimized_hlg[(layer_name, 0)] + array = sgc.dsk[layer_name][0](arg) return array